# scDataset: Scalable Data Loading for Deep Learning on Large-Scale Single-Cell Omics
scDataset is a flexible and efficient PyTorch IterableDataset for large-scale single-cell omics datasets. It supports a variety of data formats (e.g., AnnData, HuggingFace Datasets, NumPy arrays) and is designed for high-throughput deep learning workflows. While optimized for single-cell data, it is general-purpose and can be used with any dataset.
## Features

- **Flexible Data Source Support**: Integrates seamlessly with AnnData, HuggingFace Datasets, NumPy arrays, PyTorch Datasets, and more.
- **Scalable**: Handles datasets with billions of samples without loading everything into memory.
- **Efficient Data Loading**: Block sampling and batched fetching optimize random access for large datasets.
- **Dynamic Splitting**: Split datasets into train/validation/test on the fly, without duplicating data or rewriting files.
- **Custom Hooks**: Apply transformations at fetch or batch time via user-defined callbacks.
## Installation

Install the latest release from PyPI:

```bash
pip install scDataset
```

Or install the latest development version from GitHub:

```bash
pip install git+https://github.com/scDataset/scDataset.git
```

## Quick Start

scDataset v0.2.0 uses a strategy-based approach for flexible data sampling:
```python
from scdataset import scDataset, Streaming
from torch.utils.data import DataLoader

# Create dataset with a streaming strategy
data = my_data_collection  # Any indexable object (numpy array, AnnData, etc.)
strategy = Streaming()
dataset = scDataset(data, strategy, batch_size=64)
loader = DataLoader(dataset, batch_size=None)  # scDataset handles batching internally
```

**Note:** Set `batch_size=None` in the DataLoader to delegate batching to `scDataset`.
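To see the whole loop end to end, here is a minimal sketch assuming `data` is an in-memory NumPy array (any indexable collection works the same way; the array shape is made up for illustration):

```python
import numpy as np
from torch.utils.data import DataLoader
from scdataset import scDataset, Streaming

data = np.random.rand(1000, 50)  # hypothetical: 1,000 cells x 50 genes
dataset = scDataset(data, Streaming(), batch_size=64)
loader = DataLoader(dataset, batch_size=None)

for batch in loader:
    # Each iteration yields one batch of up to 64 rows
    print(batch.shape)
    break
```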
## Sampling Strategies

### Streaming

```python
from scdataset import Streaming

# Simple sequential access
strategy = Streaming()
dataset = scDataset(data, strategy, batch_size=64)

# Sequential access with buffer-level shuffling (similar to Ray Data or WebDataset).
# The shuffle buffer holds batch_size * fetch_factor samples
# (fetch_factor is set in the scDataset constructor).
strategy = Streaming(shuffle=True)
dataset = scDataset(data, strategy, batch_size=64, fetch_factor=8)

# Use only a subset of indices
train_indices = [0, 2, 4, 6, 8, ...]  # Your training indices
strategy = Streaming(indices=train_indices)
dataset = scDataset(data, strategy, batch_size=64)
```
### BlockShuffling

```python
from scdataset import BlockShuffling

# Shuffle blocks while maintaining some data locality
strategy = BlockShuffling(block_size=8)
dataset = scDataset(data, strategy, batch_size=64)

# With a subset of indices
strategy = BlockShuffling(block_size=8, indices=train_indices)
dataset = scDataset(data, strategy, batch_size=64)
```
### BlockWeightedSampling

```python
from scdataset import BlockWeightedSampling

# Uniform weighted sampling
strategy = BlockWeightedSampling(total_size=10000, block_size=16)
dataset = scDataset(data, strategy, batch_size=64)

# Custom weights (e.g., for imbalanced data)
sample_weights = compute_weights(data)  # Your weight computation
strategy = BlockWeightedSampling(
    weights=sample_weights,
    total_size=5000,
    block_size=16
)
dataset = scDataset(data, strategy, batch_size=64)
```
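The `compute_weights` call above is a placeholder. A common concrete choice is inverse class frequency, sketched below; `cell_type_labels` and the weighting scheme are illustrative assumptions, not part of the library:

```python
import numpy as np

def compute_weights(labels):
    # Inverse-frequency weighting: samples from rare classes
    # receive proportionally larger sampling weights.
    labels = np.asarray(labels)
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    return weights / weights.sum()

sample_weights = compute_weights(cell_type_labels)  # one label per cell
```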
### ClassBalancedSampling

```python
from scdataset import ClassBalancedSampling

# Automatically balance classes from labels
cell_types = ['T_cell', 'B_cell', 'NK_cell', ...]  # Your class labels
strategy = ClassBalancedSampling(cell_types, total_size=8000)
dataset = scDataset(data, strategy, batch_size=64)
```
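In practice, the labels often come from an AnnData `obs` column. A minimal sketch, assuming an `.h5ad` file with a `cell_type` annotation (the file name and column name are hypothetical):

```python
import anndata as ad
from scdataset import scDataset, ClassBalancedSampling

adata = ad.read_h5ad("pbmc.h5ad")  # hypothetical dataset
cell_types = adata.obs["cell_type"].tolist()

strategy = ClassBalancedSampling(cell_types, total_size=8000)
dataset = scDataset(adata, strategy, batch_size=64)
```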
## Multi-Modal Data

Handle multiple related data modalities that need to be indexed together:

```python
from scdataset import MultiIndexable, Streaming

# Group multiple data modalities
multi_data = MultiIndexable(
    genes=gene_expression_data,  # Shape: (n_cells, n_genes)
    proteins=protein_data,       # Shape: (n_cells, n_proteins)
    metadata=cell_metadata       # Shape: (n_cells, n_features)
)

# Use with any sampling strategy
strategy = Streaming()
dataset = scDataset(multi_data, strategy, batch_size=64)

for batch in dataset:
    genes = batch['genes']        # Gene expression for this batch
    proteins = batch['proteins']  # Corresponding protein data
    metadata = batch['metadata']  # Corresponding metadata
```
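As a concrete sketch, the modalities could come from a CITE-seq AnnData object, with gene counts in `X` and protein counts in `obsm`; the file name and the `obsm`/`obs` keys below are assumptions for illustration:

```python
import anndata as ad
from scdataset import MultiIndexable

adata = ad.read_h5ad("citeseq.h5ad")  # hypothetical CITE-seq dataset

multi_data = MultiIndexable(
    genes=adata.X,                            # (n_cells, n_genes)
    proteins=adata.obsm["protein_counts"],    # (n_cells, n_proteins)
    metadata=adata.obs[["cell_type"]].values  # (n_cells, 1)
)
```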
## Performance Tuning

Configure `fetch_factor` to fetch multiple batches' worth of data at once:

```python
strategy = BlockShuffling(block_size=16)
dataset = scDataset(
    data,
    strategy,
    batch_size=64,
    fetch_factor=8  # Fetch 8 * 64 = 512 samples at once
)
loader = DataLoader(
    dataset,
    batch_size=None,
    num_workers=4,
    prefetch_factor=9  # fetch_factor + 1
)
```

We recommend setting `prefetch_factor` to `fetch_factor + 1` for efficient data loading. For parameter details, see the original paper.
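A quick way to sanity-check a given `fetch_factor`/`num_workers` combination is to time a fixed number of batches; this snippet is a generic benchmark, not part of scDataset:

```python
import time

start = time.perf_counter()
n_samples = 0
for i, batch in enumerate(loader):
    n_samples += len(batch)
    if i == 99:  # time 100 batches
        break
elapsed = time.perf_counter() - start
print(f"~{n_samples / elapsed:.0f} samples/sec")
```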
## Custom Hooks

Apply custom transformations at fetch or batch time using the new callback system:

- `fetch_callback(collection, indices)`: Customizes how samples are fetched from the underlying data collection. Use this if your collection does not support batched indexing or requires special access logic.
  - Input: the data collection and an array of indices
  - Output: the fetched data
- `fetch_transform(fetched_data)`: Transforms each fetched chunk (e.g., sparse-to-dense conversion, normalization).
  - Input: the fetched data
  - Output: the transformed data
- `batch_callback(fetched_data, batch_indices)`: Selects or arranges a minibatch from the fetched/transformed data.
  - Input: the fetched/transformed data and a list of batch indices within the chunk
  - Output: the batch to yield
- `batch_transform(batch)`: Applies final processing to each batch before yielding (e.g., collation, augmentation).
  - Input: the batch
  - Output: the processed batch
```python
from scdataset import scDataset, Streaming

def fetch_transform(chunk):
    # Example: convert sparse to dense, normalization, etc.
    # Applied to the entire fetched chunk
    return chunk.toarray() if hasattr(chunk, 'toarray') else chunk

def batch_transform(batch):
    # Example: batch-level augmentation or tensor conversion
    import torch
    return torch.from_numpy(batch).float()

strategy = Streaming()
dataset = scDataset(
    data,
    strategy,
    batch_size=64,
    fetch_transform=fetch_transform,
    batch_transform=batch_transform
)
```
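The example above covers the two transforms. For the two callbacks, here is a sketch for a collection that needs special access logic, e.g. an HDF5 dataset whose fancy indexing requires sorted indices; the h5py use case and passing the callbacks as constructor arguments are assumptions for illustration:

```python
import numpy as np
from scdataset import scDataset, BlockShuffling

def fetch_callback(collection, indices):
    # h5py datasets require increasing indices for fancy indexing:
    # read in sorted order, then restore the requested order.
    order = np.argsort(indices)
    rows = collection[np.sort(indices)]
    return rows[np.argsort(order)]

def batch_callback(fetched, batch_indices):
    # Select one minibatch from the fetched chunk.
    return fetched[batch_indices]

dataset = scDataset(
    data,
    BlockShuffling(block_size=16),
    batch_size=64,
    fetch_callback=fetch_callback,
    batch_callback=batch_callback
)
```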
## Complete Example

```python
from scdataset import scDataset, BlockShuffling, Streaming
from torch.utils.data import DataLoader
import numpy as np
# Your data
data = my_data_collection
train_indices = np.arange(0, 8000)
val_indices = np.arange(8000, 10000)
# Training with block shuffling
train_strategy = BlockShuffling(block_size=32, indices=train_indices)
train_dataset = scDataset(
    data,
    train_strategy,
    batch_size=64,
    fetch_factor=8
)
train_loader = DataLoader(
    train_dataset,
    batch_size=None,
    num_workers=4,
    prefetch_factor=9
)
# Validation with streaming (deterministic)
val_strategy = Streaming(indices=val_indices)
val_dataset = scDataset(
    data,
    val_strategy,
    batch_size=64,
    fetch_factor=8
)
val_loader = DataLoader(
    val_dataset,
    batch_size=None,
    num_workers=4,
    prefetch_factor=9
)
# Training loop
num_epochs = 10  # e.g., number of training epochs
for epoch in range(num_epochs):
    # Training
    for batch in train_loader:
        # Training code here
        pass

    # Validation
    for batch in val_loader:
        # Validation code here
        pass
```

## Citation

If you use scDataset in your research, please cite the following paper:
```bibtex
@article{scdataset2025,
  title={scDataset: Scalable Data Loading for Deep Learning on Large-Scale Single-Cell Omics},
  author={D'Ascenzo, Davide and Cultrera di Montesano, Sebastiano},
  journal={arXiv preprint arXiv:2506.01883},
  year={2025}
}
```

## Migration Guide

scDataset v0.2.0 introduces breaking changes with a new strategy-based API. Here's how to migrate your code:
```python
# v0.1.x - No longer supported
from scdataset import scDataset

dataset = scDataset(data, batch_size=64, block_size=8, fetch_factor=4)
dataset.subset(train_indices)
dataset.set_mode('train')
```

```python
# v0.2.0 - Strategy-based approach
from scdataset import scDataset, BlockShuffling, Streaming
# Training with shuffling
train_strategy = BlockShuffling(block_size=8, indices=train_indices)
train_dataset = scDataset(data, train_strategy, batch_size=64, fetch_factor=4)
# Evaluation with streaming
val_strategy = Streaming(indices=val_indices)
val_dataset = scDataset(data, val_strategy, batch_size=64, fetch_factor=4)
```

**Key Changes:**
- **Required strategy parameter**: You must now provide a `SamplingStrategy` instance.
- **No more `subset()` and `set_mode()`**: Use the strategy's `indices` parameter and different strategy types instead.
- **Create separate datasets**: Build one dataset per split instead of modifying a single instance.
- **New imports**: Import the specific strategies you need, such as `Streaming` and `BlockShuffling`.
## License

This project is licensed under the MIT License.

## Contributing

Contributions are welcome! Please open issues or pull requests on GitHub.
