dirac.dataprep
- class sodirac.dataprep.GraphDS(*args: Any, **kwargs: Any)
Bases: Dataset
PyTorch Dataset for single-cell/spatial profiles with optional labels and domains.
- Parameters:
counts (np.ndarray or sparse.csr_matrix) – Shape [cells, genes]. Expression/count matrix.
labels (np.ndarray or sparse.csr_matrix, optional) – Shape [cells,]. Integer cell-type labels.
domains (np.ndarray or sparse.csr_matrix, optional) – Shape [cells,]. Integer domain labels.
transform (Callable, optional) – Callable applied to each sample dict.
num_domains (int, optional) – Total number of domains for one-hot encoding of domains. Default: -1.
- Return type:
None
Notes
Dense copies are created for input arrays when needed.
One-hot encodings are produced for labels/domains when provided.
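The per-sample behaviour described above (dense storage, optional labels, a transform applied to each sample dict) can be illustrated with a minimal NumPy-only sketch. `ToyProfileDataset`, `log1p_transform`, and the dict keys `"input"`, `"output"`, and `"idx"` are hypothetical names chosen for illustration; they are not taken from `sodirac`, and the real `GraphDS` may structure its samples differently.

```python
import numpy as np


class ToyProfileDataset:
    """Hypothetical stand-in for a map-style dataset of expression profiles."""

    def __init__(self, counts, labels=None, transform=None):
        # Keep a dense float copy of the [cells, genes] matrix.
        self.counts = np.asarray(counts, dtype=np.float32)
        self.labels = None if labels is None else np.asarray(labels, dtype=np.int64)
        self.transform = transform

    def __len__(self):
        # One sample per cell (row).
        return self.counts.shape[0]

    def __getitem__(self, idx):
        # Assumed sample layout: a plain dict keyed by role.
        sample = {"input": self.counts[idx], "idx": idx}
        if self.labels is not None:
            sample["output"] = self.labels[idx]
        if self.transform is not None:
            # The transform receives and returns the whole sample dict.
            sample = self.transform(sample)
        return sample


def log1p_transform(sample):
    # Example transform: log-scale the profile, leave the other keys untouched.
    out = dict(sample)
    out["input"] = np.log1p(sample["input"])
    return out


counts = np.array([[0, 3, 1], [2, 0, 0]])
ds = ToyProfileDataset(counts, labels=[1, 0], transform=log1p_transform)
first = ds[0]
```

Passing the whole dict to the transform, rather than only the expression vector, lets a transform read or modify labels and indices alongside the profile.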
- __init__(counts: scipy.sparse.csr.csr_matrix | numpy.ndarray, labels: scipy.sparse.csr.csr_matrix | numpy.ndarray | None = None, domains: scipy.sparse.csr.csr_matrix | numpy.ndarray | None = None, transform: Callable | None = None, num_domains: int = -1) -> None
- _process_labels(labels: numpy.ndarray | scipy.sparse.csr_matrix | None) -> tuple
Convert labels to torch tensors and one-hot encodings.
- Parameters:
labels (np.ndarray or sparse.csr_matrix, optional) – Shape [cells,]. Integer labels.
- Returns:
(labels_tensor, one_hot) – Dense label tensor and one-hot tensor, or (None, None) if labels is None.
- Return type:
Tuple[torch.LongTensor or None, torch.FloatTensor or None]
Notes
One-hot dimension equals the number of unique labels in the batch.
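As a rough illustration of a one-hot encoding whose width equals the number of unique labels, the following NumPy sketch remaps each label to its rank among the unique values. `one_hot_by_unique` is a hypothetical helper, and the remapping step is an assumption; the library itself returns torch tensors and may require labels that are already contiguous.

```python
import numpy as np


def one_hot_by_unique(labels):
    """Hypothetical helper: one-hot encode with one column per unique label."""
    labels = np.asarray(labels).ravel()
    # `inverse` maps every label to its position in the sorted unique values.
    classes, inverse = np.unique(labels, return_inverse=True)
    one_hot = np.zeros((labels.size, classes.size), dtype=np.float32)
    one_hot[np.arange(labels.size), inverse] = 1.0
    return classes, one_hot


classes, encoded = one_hot_by_unique([2, 0, 2, 5])
```

Here the three distinct labels (0, 2, 5) give a matrix with three columns, even though the largest label value is 5.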
- _process_domains(domains: numpy.ndarray | scipy.sparse.csr_matrix | None, num_domains: int) -> tuple
Convert domain labels to torch tensors and one-hot encodings.
- Parameters:
domains (np.ndarray or sparse.csr_matrix, optional) – Shape [cells,]. Integer domain labels.
num_domains (int) – Number of domain categories for one-hot encoding.
- Returns:
(domains_tensor, one_hot) – Dense domain tensor and one-hot tensor, or (None, None) if domains is None.
- Return type:
Tuple[torch.LongTensor or None, torch.FloatTensor or None]
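Unlike the label encoding, the domain encoding takes its width from `num_domains`. A minimal NumPy sketch of a fixed-width one-hot encoding is shown below; `one_hot_fixed` is a hypothetical helper, and the range check is an added assumption rather than documented library behaviour.

```python
import numpy as np


def one_hot_fixed(domains, num_domains):
    """Hypothetical helper: one-hot encode with a caller-supplied width."""
    domains = np.asarray(domains, dtype=np.int64).ravel()
    # Assumed validation: every domain id must index a valid column.
    if domains.size and (domains.min() < 0 or domains.max() >= num_domains):
        raise ValueError("domain ids must lie in [0, num_domains)")
    one_hot = np.zeros((domains.size, num_domains), dtype=np.float32)
    one_hot[np.arange(domains.size), domains] = 1.0
    return one_hot


encoded_domains = one_hot_fixed([0, 2, 0], num_domains=4)
```

Fixing the width up front keeps the encoding shape stable across datasets that each contain only a subset of the domains.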
- sodirac.dataprep.balance_classes(y: numpy.ndarray, class_min: int = 256, random_state: int | None = None) -> numpy.ndarray
Balance class indices by undersampling majorities and oversampling minorities.
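- Parameters:
y (np.ndarray) – Class label for each sample.
class_min (int, optional) – Lower bound on the per-class sample count. Default: 256.
random_state (int, optional) – Seed for the random sampling. Default: None.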
- Returns:
balanced_idx – Balanced indices (with replacement for minority classes).
- Return type:
np.ndarray
Notes
The smallest effective class count used is max(min_count, class_min).
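One way to read the description above is sketched below: every class is resampled to the same target size, drawing without replacement when the class is large enough and with replacement otherwise. `balance_classes_sketch` is a hypothetical re-implementation for illustration only. It assumes that `min_count` means the size of the smallest class, and it is not the library's source.

```python
import numpy as np


def balance_classes_sketch(y, class_min=256, random_state=None):
    """Hypothetical sketch: return indices with equal counts per class."""
    y = np.asarray(y).ravel()
    rng = np.random.default_rng(random_state)
    classes, counts = np.unique(y, return_counts=True)
    # Assumption: min_count is the size of the smallest class.
    target = max(int(counts.min()), class_min)
    picked = []
    for cls, n in zip(classes, counts):
        idx = np.flatnonzero(y == cls)
        # Undersample large classes; oversample (with replacement) small ones.
        need_replacement = bool(n < target)
        picked.append(rng.choice(idx, size=target, replace=need_replacement))
    return np.concatenate(picked)


y = np.array([0] * 10 + [1] * 3)
balanced_idx = balance_classes_sketch(y, class_min=5, random_state=0)
```

With 10 samples of class 0, 3 samples of class 1, and `class_min=5`, the target is 5 per class: class 0 is undersampled and class 1 is oversampled with repeats.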
- class sodirac.dataprep.GraphDataset(*args: Any, **kwargs: Any)
Bases: InMemoryDataset
In-memory PyG dataset for a paired graph with features, batches, domains, and labels.
- Parameters:
data (np.ndarray) – Shape [num_nodes, num_features]. Node features.
batch (np.ndarray) – Shape [num_nodes]. Batch assignment per node.
domain (np.ndarray) – Shape [num_nodes]. Domain labels per node.
edge_index (torch.Tensor) – Shape [2, num_edges]. Edge index.
label (np.ndarray, optional) – Shape [num_nodes]. Node labels. Default: None.
transform (callable, optional) – A callable that takes and returns a torch_geometric.data.Data object.
- graph_data
Graph data object with fields: data_0 (FloatTensor), batch_0 (LongTensor), domain_0 (LongTensor), edge_index (Tensor), idx (LongTensor), label (LongTensor or None), num_nodes (int).
- Type:
torch_geometric.data.Data
Notes
This dataset contains a single graph (length = 1).
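The shape contract listed in the parameters can be checked before a dataset is constructed. The sketch below uses a plain dict as a stand-in for `torch_geometric.data.Data` so that it runs with NumPy alone. `build_graph_fields` is a hypothetical helper, and filling `idx` with `0..num_nodes-1` is an assumption about what that field holds.

```python
import numpy as np


def build_graph_fields(data, batch, domain, edge_index, label=None):
    """Hypothetical helper: validate shapes and collect the documented fields."""
    data = np.asarray(data, dtype=np.float32)
    num_nodes = data.shape[0]

    edge_index = np.asarray(edge_index, dtype=np.int64)
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape [2, num_edges]")
    if edge_index.size and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        raise ValueError("edge_index refers to a node outside [0, num_nodes)")

    per_node = {"batch": batch, "domain": domain}
    if label is not None:
        per_node["label"] = label
    checked = {}
    for name, values in per_node.items():
        arr = np.asarray(values, dtype=np.int64)
        # Every per-node array needs exactly one entry per node.
        if arr.shape != (num_nodes,):
            raise ValueError(f"{name} must have shape [num_nodes]")
        checked[name] = arr

    return {
        "data_0": data,
        "batch_0": checked["batch"],
        "domain_0": checked["domain"],
        "edge_index": edge_index,
        "idx": np.arange(num_nodes),  # assumed: running node index
        "label": checked.get("label"),
        "num_nodes": num_nodes,
    }


fields = build_graph_fields(
    data=np.random.default_rng(0).random((3, 4)),
    batch=[0, 0, 1],
    domain=[0, 1, 1],
    edge_index=[[0, 1, 2], [1, 2, 0]],
)
```

In real use the arrays would be converted to torch tensors and wrapped in a `torch_geometric.data.Data` object; the dict only mirrors the field names documented above.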
- class sodirac.dataprep.GraphDataset_unpaired(*args: Any, **kwargs: Any)
Bases: InMemoryDataset
In-memory PyG dataset for an unpaired graph with features, domains, and labels.
- Parameters:
data (np.ndarray) – Shape [num_nodes, num_features]. Node features.
domain (np.ndarray) – Shape [num_nodes]. Domain labels per node.
edge_index (torch.Tensor) – Shape [2, num_edges]. Edge index.
label (np.ndarray, optional) – Shape [num_nodes]. Node labels. Default: None.
transform (callable, optional) – A callable that takes and returns a torch_geometric.data.Data object.
- graph_data
Graph data object with fields: data (FloatTensor), domain (LongTensor), edge_index (Tensor), idx (LongTensor), label (LongTensor or None), num_nodes (int).
- Type:
torch_geometric.data.Data
Notes
This dataset contains a single graph (length = 1).