#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 5/17/23 2:58 PM
# @Author : Chang Xu
# @File : dataprep.py
# @Email : changxu@nus.edu.sg
import copy
import logging
import operator
from typing import Union, Callable, Any, Iterable, List, Optional, Dict

import numpy as np
import torch
from scipy import sparse
from torch.utils.data import Dataset
from torch_geometric.data import InMemoryDataset, Data
class GraphDS(Dataset):
    """
    PyTorch Dataset for single-cell/spatial profiles with optional labels and domains.

    Parameters
    ----------
    counts : np.ndarray or sparse.csr_matrix
        Shape [cells, genes]. Expression/count matrix.
    labels : np.ndarray or sparse.csr_matrix, optional
        Shape [cells,]. Integer cell-type labels (non-negative).
    domains : np.ndarray or sparse.csr_matrix, optional
        Shape [cells,]. Integer domain labels (non-negative).
    transform : Callable, optional
        Callable applied to each sample dict.
    num_domains : int, optional
        Total number of domains for one-hot encoding of `domains`. Default: -1,
        which lets ``torch.nn.functional.one_hot`` use ``max(domains) + 1``.

    Attributes
    ----------
    counts : torch.FloatTensor
        Float tensor built from ``counts`` (densified first if sparse).
    labels : tuple
        ``(label_tensor, label_one_hot)``, or ``(None, None)`` without labels.
    domains : tuple
        ``(domain_tensor, domain_one_hot)``, or ``(None, None)`` without domains.
    transform : Callable or None
        Applied to each sample dict in ``__getitem__``.
    indexes : torch.LongTensor
        ``arange(n_cells)``; returned as the ``"idx"`` entry of each sample.

    Notes
    -----
    - Sparse inputs are densified; sparse label/domain vectors are flattened
      to 1-D so they match the dense ``[cells,]`` layout.
    - One-hot encodings are produced for labels/domains when provided.
    """

    def __init__(
        self,
        counts: Union[sparse.csr_matrix, np.ndarray],
        labels: Optional[Union[sparse.csr_matrix, np.ndarray]] = None,
        domains: Optional[Union[sparse.csr_matrix, np.ndarray]] = None,
        transform: Optional[Callable] = None,
        num_domains: int = -1,
    ) -> None:
        super().__init__()
        # isinstance (not an exact type match) so ndarray subclasses such as
        # np.memmap are accepted, consistent with the label/domain checks.
        if not isinstance(counts, (np.ndarray, sparse.csr_matrix)):
            raise TypeError(
                f"Counts is type {type(counts)}, must be `np.ndarray` or `sparse.csr_matrix`"
            )
        dense_counts = counts.toarray() if sparse.issparse(counts) else counts
        self.counts = torch.FloatTensor(dense_counts)
        self.labels = self._process_labels(labels)
        self.domains = self._process_domains(domains, num_domains)
        self.transform = transform
        self.indexes = torch.arange(self.counts.shape[0]).long()

    @staticmethod
    def _to_long_tensor(
        values: Union[np.ndarray, sparse.csr_matrix], name: str
    ) -> torch.Tensor:
        """
        Validate an integer label vector and convert it to a LongTensor.

        Parameters
        ----------
        values : np.ndarray or sparse.csr_matrix
            Integer labels, one per cell.
        name : str
            Capitalised argument name used in the error message.

        Returns
        -------
        torch.LongTensor
            The labels as a long tensor.

        Raises
        ------
        TypeError
            If `values` is neither an ndarray nor a CSR matrix.
        """
        if not isinstance(values, (np.ndarray, sparse.csr_matrix)):
            raise TypeError(
                f"{name} is type {type(values)}, must be `np.ndarray` or `sparse.csr_matrix`"
            )
        if sparse.issparse(values):
            # A sparse "vector" densifies to a 2-D (1, n) or (n, 1) matrix;
            # flatten it to the documented [cells,] shape.
            values = values.toarray().ravel()
        # torch.from_numpy rejects arrays with negative strides (e.g. y[::-1]).
        return torch.from_numpy(np.ascontiguousarray(values)).long()

    def _process_labels(
        self, labels: Optional[Union[np.ndarray, sparse.csr_matrix]]
    ) -> tuple:
        """
        Convert labels to torch tensors and one-hot encodings.

        Parameters
        ----------
        labels : np.ndarray or sparse.csr_matrix, optional
            Shape [cells,]. Integer labels.

        Returns
        -------
        (labels_tensor, one_hot) : Tuple[torch.LongTensor or None, torch.FloatTensor or None]
            Dense label tensor and one-hot tensor, or (None, None) if `labels` is None.

        Notes
        -----
        One-hot dimension equals ``max(labels) + 1``.
        """
        if labels is None:
            return None, None
        label_tensor = self._to_long_tensor(labels, "Labels")
        # num_classes=-1 -> width max(label) + 1, so non-contiguous ids such
        # as {0, 2} still fit; for ids 0..K-1 this equals the number of classes.
        one_hot = torch.nn.functional.one_hot(label_tensor, num_classes=-1).float()
        return label_tensor, one_hot

    def _process_domains(
        self, domains: Optional[Union[np.ndarray, sparse.csr_matrix]], num_domains: int
    ) -> tuple:
        """
        Convert domain labels to torch tensors and one-hot encodings.

        Parameters
        ----------
        domains : np.ndarray or sparse.csr_matrix, optional
            Shape [cells,]. Integer domain labels.
        num_domains : int
            Number of domain categories for one-hot encoding (-1 infers max + 1).

        Returns
        -------
        (domains_tensor, one_hot) : Tuple[torch.LongTensor or None, torch.FloatTensor or None]
            Dense domain tensor and one-hot tensor, or (None, None) if `domains` is None.
        """
        if domains is None:
            return None, None
        domain_tensor = self._to_long_tensor(domains, "Domains")
        one_hot = torch.nn.functional.one_hot(domain_tensor, num_classes=num_domains).float()
        return domain_tensor, one_hot

    def __len__(self) -> int:
        """
        Number of examples in the dataset.

        Returns
        -------
        n : int
            Dataset length (number of rows of `counts`).
        """
        return self.counts.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Retrieve a single sample with optional labels/domains.

        Parameters
        ----------
        idx : int
            Index in `range(len(self))`. Any integer-like value (Python int,
            NumPy integer, 0-d integer tensor) is accepted.

        Returns
        -------
        sample : dict
            {
                "input": torch.FloatTensor,          # feature vector
                "idx": torch.LongTensor,             # original index
                "output": torch.LongTensor,          # only if labels were given
                "output_one_hot": torch.FloatTensor, # only if labels were given
                "domain": torch.LongTensor,          # only if domains were given
                "domain_one_hot": torch.FloatTensor, # only if domains were given
            }

        Raises
        ------
        TypeError
            If `idx` is not integer-like.
        ValueError
            If `idx` is outside `[0, len(self))`.

        Notes
        -----
        Applies `self.transform(sample)` if a transform is provided.
        """
        try:
            # Samplers over NumPy index arrays yield np.int64, not int.
            idx = operator.index(idx)
        except TypeError:
            raise TypeError(f"indices must be int, you passed {type(idx)}, {idx}") from None
        if idx < 0 or idx >= len(self):
            raise ValueError(f"idx {idx} is invalid for dataset with {len(self)} examples.")

        sample: Dict[str, torch.Tensor] = {
            "input": self.counts[idx, ...],
            "idx": self.indexes[idx],
        }
        # _process_labels/_process_domains return (None, None) when absent, so
        # test the tensor itself rather than the tuple.
        if self.labels is not None and self.labels[0] is not None:
            sample["output"] = self.labels[0][idx]
            sample["output_one_hot"] = self.labels[1][idx]
        if self.domains is not None and self.domains[0] is not None:
            sample["domain"] = self.domains[0][idx]
            sample["domain_one_hot"] = self.domains[1][idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
def balance_classes(
    y: np.ndarray,
    class_min: int = 256,
    random_state: Optional[int] = None,
) -> np.ndarray:
    """
    Balance class indices by undersampling majorities and oversampling minorities.

    Parameters
    ----------
    y : np.ndarray
        Shape [N,]. Class labels.
    class_min : int, default 256
        Lower bound on the per-class sample count.
    random_state : int, optional
        Seed for a private ``np.random.RandomState``; NumPy's global RNG is
        left untouched. If None, the global legacy RNG is used.

    Returns
    -------
    balanced_idx : np.ndarray
        Indices into `y`, concatenated class by class in sorted class order
        (not shuffled). Classes smaller than the target are sampled with
        replacement; larger ones without replacement. Empty if `y` is empty.

    Raises
    ------
    TypeError
        If `y` is not a NumPy array.
    ValueError
        If `class_min` is not a positive integer.

    Notes
    -----
    Every class contributes exactly ``max(min_count, class_min)`` indices,
    where ``min_count`` is the size of the smallest class.
    """
    if not isinstance(y, np.ndarray):
        raise TypeError(f"y should be a numpy array, but got {type(y)}")
    if not isinstance(class_min, int) or class_min <= 0:
        raise ValueError(f"class_min should be a positive integer, but got {class_min}")
    if y.size == 0:
        # np.min over zero classes would raise; there is nothing to balance.
        return np.empty(0, dtype=int)

    # RandomState(seed) reproduces exactly the draws of np.random.seed(seed)
    # followed by np.random.choice, without reseeding the caller's global RNG.
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    classes, counts = np.unique(y, return_counts=True)
    min_count = max(np.min(counts), class_min)
    balanced_idx: List[np.ndarray] = []
    for cls, count in zip(classes, counts):
        class_idx = np.where(y == cls)[0].astype(int)
        oversample = count < min_count
        if oversample:
            print(f"Class {cls} has {count} samples. Oversampling to {min_count} samples.")
        balanced_idx.append(rng.choice(class_idx, size=min_count, replace=oversample))
    return np.concatenate(balanced_idx).astype(int)
class GraphDataset(InMemoryDataset):
    """
    In-memory PyG dataset for a *paired* graph with features, batches, domains, and labels.

    Parameters
    ----------
    data : np.ndarray
        Shape [num_nodes, num_features]. Node features.
    batch : np.ndarray
        Shape [num_nodes]. Batch assignment per node.
    domain : np.ndarray
        Shape [num_nodes]. Domain labels per node.
    edge_index : torch.Tensor
        Shape [2, num_edges]. Edge index.
    label : np.ndarray, optional
        Shape [num_nodes]. Node labels. Default: None.
    transform : callable, optional
        A callable that takes and returns a `torch_geometric.data.Data` object.
        Applied on every ``dataset[0]`` access to a shallow copy of the graph.

    Attributes
    ----------
    graph_data : torch_geometric.data.Data
        Graph data object with fields:
        - data_0 (FloatTensor), batch_0 (LongTensor), domain_0 (LongTensor),
          edge_index (Tensor), idx (LongTensor), label (LongTensor or None),
          num_nodes (int).

    Notes
    -----
    This dataset contains a single graph (length = 1).
    """

    def __init__(
        self,
        data: np.ndarray,
        batch: np.ndarray,
        domain: np.ndarray,
        edge_index: torch.Tensor,
        label: Optional[np.ndarray] = None,
        transform: Optional[Callable] = None,
    ):
        self.root = "."  # customizable
        super().__init__(self.root, transform)
        # type checks
        if not isinstance(data, np.ndarray):
            raise TypeError(f"data should be of type np.ndarray, but got {type(data)}")
        if not isinstance(batch, np.ndarray):
            raise TypeError(f"batch should be of type np.ndarray, but got {type(batch)}")
        if not isinstance(domain, np.ndarray):
            raise TypeError(f"domain should be of type np.ndarray, but got {type(domain)}")
        if not isinstance(edge_index, torch.Tensor):
            raise TypeError(f"edge_index should be of type torch.Tensor, but got {type(edge_index)}")
        if label is not None and not isinstance(label, np.ndarray):
            raise TypeError(f"label should be of type np.ndarray, but got {type(label)}")
        self.graph_data = Data(
            data_0=torch.FloatTensor(data.copy()),
            batch_0=torch.LongTensor(batch.copy()),
            edge_index=edge_index,
            idx=torch.LongTensor(np.arange(data.shape[0])),
            domain_0=torch.LongTensor(domain.copy()),
            # .copy() for consistency with the other array-backed fields.
            label=None if label is None else torch.LongTensor(label.copy()),
            num_nodes=data.shape[0],
        )

    def __len__(self) -> int:
        """
        Number of graphs in the dataset.

        Returns
        -------
        n : int
            Always 1 for `InMemoryDataset` here.
        """
        return 1

    def __getitem__(self, idx: int) -> Data:
        """
        Retrieve the single stored graph.

        Parameters
        ----------
        idx : int
            Graph index (must be 0).

        Returns
        -------
        graph_data : torch_geometric.data.Data
            The stored graph object, or ``transform`` applied to a shallow
            copy of it when a transform was given.

        Raises
        ------
        IndexError
            If `idx != 0`.
        """
        if idx != 0:
            raise IndexError("Index out of range. This dataset contains only one graph.")
        if self.transform is None:
            return self.graph_data
        # Overriding __getitem__ bypasses PyG's Dataset.__getitem__, which is
        # where `transform` is normally applied, so apply it here. Use a shallow
        # copy so a transform that reassigns attributes does not accumulate
        # changes on the stored graph across calls.
        return self.transform(copy.copy(self.graph_data))
class GraphDataset_unpaired(InMemoryDataset):
    """
    In-memory PyG dataset for an *unpaired* graph with features, domains, and labels.

    Parameters
    ----------
    data : np.ndarray
        Shape [num_nodes, num_features]. Node features.
    domain : np.ndarray
        Shape [num_nodes]. Domain labels per node.
    edge_index : torch.Tensor
        Shape [2, num_edges]. Edge index.
    label : np.ndarray, optional
        Shape [num_nodes]. Node labels. Default: None.
    transform : callable, optional
        A callable that takes and returns a `torch_geometric.data.Data` object.
        Applied on every ``dataset[0]`` access to a shallow copy of the graph.

    Attributes
    ----------
    graph_data : torch_geometric.data.Data
        Graph data object with fields:
        - data (FloatTensor), domain (LongTensor),
          edge_index (Tensor), idx (LongTensor), label (LongTensor or None),
          num_nodes (int).

    Notes
    -----
    This dataset contains a single graph (length = 1).
    """

    def __init__(
        self,
        data: np.ndarray,
        domain: np.ndarray,
        edge_index: torch.Tensor,
        label: Optional[np.ndarray] = None,
        transform: Optional[Callable] = None,
    ):
        self.root = "."  # customizable
        super().__init__(self.root, transform)
        # type checks
        if not isinstance(data, np.ndarray):
            raise TypeError(f"data should be of type np.ndarray, but got {type(data)}")
        if not isinstance(domain, np.ndarray):
            raise TypeError(f"domain should be of type np.ndarray, but got {type(domain)}")
        if not isinstance(edge_index, torch.Tensor):
            raise TypeError(f"edge_index should be of type torch.Tensor, but got {type(edge_index)}")
        if label is not None and not isinstance(label, np.ndarray):
            raise TypeError(f"label should be of type np.ndarray, but got {type(label)}")
        self.graph_data = Data(
            data=torch.FloatTensor(data.copy()),
            edge_index=edge_index,
            idx=torch.LongTensor(np.arange(data.shape[0])),
            domain=torch.LongTensor(domain.copy()),
            # .copy() for consistency with the other array-backed fields.
            label=None if label is None else torch.LongTensor(label.copy()),
            num_nodes=data.shape[0],
        )

    def __len__(self) -> int:
        """
        Number of graphs in the dataset.

        Returns
        -------
        n : int
            Always 1 for `InMemoryDataset` here.
        """
        return 1

    def __getitem__(self, idx: int) -> Data:
        """
        Retrieve the single stored graph.

        Parameters
        ----------
        idx : int
            Graph index (must be 0).

        Returns
        -------
        graph_data : torch_geometric.data.Data
            The stored graph object, or ``transform`` applied to a shallow
            copy of it when a transform was given.

        Raises
        ------
        IndexError
            If `idx != 0`.
        """
        if idx != 0:
            raise IndexError("Index out of range. This dataset contains only one graph.")
        if self.transform is None:
            return self.graph_data
        # Overriding __getitem__ bypasses PyG's Dataset.__getitem__, which is
        # where `transform` is normally applied, so apply it here. Use a shallow
        # copy so a transform that reassigns attributes does not accumulate
        # changes on the stored graph across calls.
        return self.transform(copy.copy(self.graph_data))