Source code for sodirac.main

import os
import time
import random
from typing import Callable, Iterable, Union, List, Tuple, Dict, Any
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc 
import anndata

import torch
from torch_geometric.loader import DataLoader, ClusterData, ClusterLoader
from torch_geometric.utils import to_undirected
from torchvision import transforms
import torchvision
from torch_geometric.data import InMemoryDataset, Data

from .dataprep import GraphDS, GraphDataset, GraphDataset_unpaired
from .model import integrate_model, annotate_model
from .trainer import train_integrate, train_annotate
from .hyper import *


#########################################################
# Dirac's integration and annotation app
#########################################################

class integrate_app:
    """High-level API for multi-omics graph **integration**.

    This class prepares data (optionally with subgraph sampling), builds an
    integration model, trains it in an unsupervised manner, and returns
    embeddings/reconstructions.
    """

    def __init__(
        self,
        save_path: str = './Results/',
        subgraph: bool = True,
        use_gpu: bool = True,
        **kwargs,
    ) -> None:
        """Initialize the integration app.

        Parameters
        ----------
        save_path : str, default './Results/'
            Directory handed to the trainers for outputs (figures,
            checkpoints, etc.). Must be writable.
        subgraph : bool, default True
            If ``True``, ``_get_data`` uses ``ClusterData``/``ClusterLoader``
            for sampling. If ``False``, it uses a full-batch ``DataLoader``.
        use_gpu : bool, default True
            If ``True``, selects ``cuda`` when available; otherwise CPU.
        **kwargs : Any
            Forwarded unchanged to the next ``__init__`` in the MRO. With
            plain ``object`` as the base, any extra keyword raises
            ``TypeError``.

        Side Effects
        ------------
        Sets ``self.device``, ``self.subgraph``, and ``self.save_path``.
        """
        super(integrate_app, self).__init__(**kwargs)
        # Always store a ``torch.device`` (never a bare string) so that
        # ``self.device`` has the same type on every code path.
        if use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.subgraph = subgraph
        self.save_path = save_path

    @staticmethod
    def _encode_batch(batch, n_nodes: int):
        """Convert user-supplied batch labels into one int64 code vector.

        ``None`` yields zeros of length ``n_nodes``; objects exposing
        ``.values`` (e.g. a pandas Series) and non-numeric arrays/lists are
        categorical-encoded; numeric arrays/lists are cast to int64 as-is.
        Raises ``ValueError`` for any other type.
        """
        if batch is None:
            return np.zeros(n_nodes, dtype=np.int64)
        if hasattr(batch, 'values'):
            # e.g. adata.obs['batch']: always re-coded via Categorical.
            return pd.Categorical(batch.values).codes.astype(np.int64)
        if isinstance(batch, (np.ndarray, list)):
            batch_array = np.asarray(batch)
            if np.issubdtype(batch_array.dtype, np.number):
                return batch_array.astype(np.int64)
            return pd.Categorical(batch_array).codes.astype(np.int64)
        raise ValueError(f"Unsupported batch type: {type(batch)}")

    def _get_data(
        self,
        dataset_list: list,
        edge_index,
        domain_list=None,
        batch=None,
        num_parts: int = 10,
        num_workers: int = 1,
        batch_size: int = 1,
    ):
        """Process multi-omics node features and construct a graph dataset.

        Parameters
        ----------
        dataset_list : list of (ndarray | torch.Tensor)
            Feature matrices, one per modality/layer. Each element must be
            shaped ``(n_nodes, n_features_i)`` (rows = nodes).
        edge_index : torch.LongTensor
            Graph connectivity in COO format; passed through
            ``to_undirected`` before use.
        domain_list : list[np.ndarray] | None, optional
            Per-modality integer domain labels of length ``n_nodes``. If
            ``None``, modality ``i`` is assigned the constant domain ``i``.
        batch : None | pandas.Series | np.ndarray | list, optional
            Per-node batch labels of length ``n_nodes``; non-numeric labels
            are categorical-encoded. If ``None``, a zero vector is used.
            The same labels are attached to every modality.
        num_parts : int, default 10
            Number of partitions for ``ClusterData`` when
            ``self.subgraph`` is true.
        num_workers : int, default 1
            Number of workers for ``ClusterLoader``.
        batch_size : int, default 1
            Batch size for ``ClusterLoader`` when ``self.subgraph`` is true.

        Returns
        -------
        dict
            - ``graph_ds``: ``GraphDataset(...).graph_data`` extended with
              ``data_i`` / ``domain_i`` / ``batch_i`` for modalities i >= 1.
            - ``graph_dl``: ``ClusterLoader`` if ``self.subgraph`` else a
              ``DataLoader`` over the single full graph.
            - ``n_samples``: number of input datasets.
            - ``n_inputs_list``: feature dimension of each dataset.
            - ``n_domains``: number of modalities when ``domain_list`` is
              ``None``; otherwise the largest domain label + 1.

        Raises
        ------
        ValueError
            If ``dataset_list`` is empty, node counts differ across
            datasets, ``batch`` length mismatches the data, or ``batch``
            has an unsupported type.

        Notes
        -----
        Sets ``self.n_samples``, ``self.n_inputs_list`` and
        ``self.num_domains``; prints the number of domains.
        """
        if len(dataset_list) == 0:
            raise ValueError("dataset_list must contain at least one feature matrix.")
        self.n_samples = len(dataset_list)

        def _n_nodes(x):
            return x.shape[0] if hasattr(x, "shape") else len(x)

        # All modalities describe the same nodes, so row counts must agree.
        n_nodes = _n_nodes(dataset_list[0])
        for idx, data in enumerate(dataset_list):
            if _n_nodes(data) != n_nodes:
                raise ValueError(
                    f"All datasets must have the same number of rows (nodes). "
                    f"dataset_list[0] has {n_nodes}, but dataset_list[{idx}] has {_n_nodes(data)}."
                )

        if domain_list is None:
            # No labels given: each modality is its own domain (0..n-1).
            domain_list = [
                np.full(n_nodes, i, dtype=np.int64) for i in range(self.n_samples)
            ]
            self.num_domains = self.n_samples
        else:
            # Domain count is (largest label + 1), assuming 0-based labels.
            # NOTE(review): ``None`` entries are skipped here but would still
            # fail below when converted to tensors -- confirm intended use.
            domains_max = [int(domain.max()) for domain in domain_list if domain is not None]
            self.num_domains = max(domains_max) + 1 if domains_max else 1
        print(f"Found {self.num_domains} unique domains.")

        # Encode batch labels once. The node count is kept in ``n_nodes`` so
        # the ``batch_size`` argument is NOT overwritten (it previously was
        # when ``batch`` is None, silently changing the loader batch size).
        batch_codes = self._encode_batch(batch, n_nodes)
        if len(batch_codes) != n_nodes:
            raise ValueError("Batch length does not match data length")
        batch_list = [batch_codes.copy() for _ in range(self.n_samples)]

        self.n_inputs_list = []
        graph_data = {}
        for i, data in enumerate(dataset_list):
            self.n_inputs_list.append(data.shape[1])
            if i == 0:
                # The first modality defines the graph structure.
                graph_ds = GraphDataset(
                    data=data,
                    domain=domain_list[i],
                    batch=batch_list[i],
                    edge_index=to_undirected(edge_index),
                )
                graph_data = graph_ds.graph_data
            else:
                # Further modalities ride along as extra node attributes.
                graph_data[f"data_{i}"] = torch.FloatTensor(data)
                graph_data[f"domain_{i}"] = torch.LongTensor(domain_list[i])
                graph_data[f"batch_{i}"] = torch.LongTensor(batch_list[i])

        if self.subgraph:
            # Partition the graph and iterate over clusters.
            graph_dataset = ClusterData(graph_data, num_parts=num_parts, recursive=False)
            graph_dl = ClusterLoader(
                graph_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
            )
        else:
            # Full-batch loading: a loader over the single graph.
            graph_dl = DataLoader([graph_data])

        return {
            "graph_ds": graph_data,
            "graph_dl": graph_dl,
            "n_samples": self.n_samples,
            "n_inputs_list": self.n_inputs_list,
            "n_domains": self.num_domains,
        }

    def _get_model(
        self,
        samples,
        n_hiddens: int = 128,
        n_outputs: int = 64,
        opt_GNN="GAT",
        dropout_rate=0.1,
        use_skip_connections=True,
        use_attention=True,
        n_attention_heads=4,
        use_layer_scale=False,
        layer_scale_init=1e-2,
        use_stochastic_depth=False,
        stochastic_depth_rate=0.1,
        combine_method='concat',  # 'concat', 'sum', 'attention'
    ):
        """Build the integration model with the provided hyperparameters.

        Parameters
        ----------
        samples : dict
            Output of ``_get_data``; ``n_inputs_list`` and ``n_domains``
            are read from it.
        n_hiddens : int, default 128
            Hidden dimension.
        n_outputs : int, default 64
            Output/embedding dimension.
        opt_GNN : str, default 'GAT'
            GNN backbone option consumed by ``integrate_model``.
        dropout_rate : float, default 0.1
        use_skip_connections : bool, default True
        use_attention : bool, default True
        n_attention_heads : int, default 4
        use_layer_scale : bool, default False
        layer_scale_init : float, default 1e-2
        use_stochastic_depth : bool, default False
        stochastic_depth_rate : float, default 0.1
        combine_method : {'concat', 'sum', 'attention'}, default 'concat'
            All of the above are forwarded verbatim to ``integrate_model``,
            which defines their exact semantics.

        Returns
        -------
        models : Any
            The instance returned by ``integrate_model(...)``.
        """
        # Build the multi-modal integration network.
        models = integrate_model(
            n_inputs_list=samples["n_inputs_list"],
            n_domains=samples["n_domains"],
            n_hiddens=n_hiddens,
            n_outputs=n_outputs,
            opt_GNN=opt_GNN,
            dropout_rate=dropout_rate,
            use_skip_connections=use_skip_connections,
            use_attention=use_attention,
            n_attention_heads=n_attention_heads,
            use_layer_scale=use_layer_scale,
            layer_scale_init=layer_scale_init,
            use_stochastic_depth=use_stochastic_depth,
            stochastic_depth_rate=stochastic_depth_rate,
            combine_method=combine_method,
        )
        return models

    def _train_dirac_integrate(
        self,
        samples,
        models,
        epochs: int = 500,
        optimizer_name: str = "adam",
        lr: float = 1e-3,
        tau: float = 0.9,
        wd: float = 5e-2,
        scheduler: bool = True,
        lamb: float = 5e-4,
        scale_loss: float = 0.025,
    ):
        """Train the integration model and evaluate embeddings/reconstructions.

        Parameters
        ----------
        samples : dict
            Output of ``_get_data``.
        models : Any
            Model returned by ``_get_model`` / ``integrate_model``.
        epochs : int, default 500
            Training epochs.
        optimizer_name : str, default 'adam'
            Optimizer identifier consumed by the trainer.
        lr, tau, wd, scheduler
            Packed into ``unsuper_hyperparams(...)``; their semantics are
            defined there and by the trainer.
        lamb : float, default 5e-4
        scale_loss : float, default 0.025
            Loss coefficients forwarded to ``train_integrate._train``.

        Returns
        -------
        data_z, combine_recon
            The pair returned by ``train_integrate.evaluate(samples=...)``
            (embeddings and reconstruction(s), per the trainer).
        """
        hyperparams = unsuper_hyperparams(lr=lr, tau=tau, wd=wd, scheduler=scheduler)
        un_dirac = train_integrate(
            minemodel=models,
            save_path=self.save_path,
            device=self.device,
        )
        un_dirac._train(
            samples=samples,
            epochs=epochs,
            hyperparams=hyperparams,
            optimizer_name=optimizer_name,
            lamb=lamb,
            scale_loss=scale_loss,
        )
        data_z, combine_recon = un_dirac.evaluate(samples=samples)
        return data_z, combine_recon
class annotate_app(integrate_app):
    """High-level API for **annotation / domain adaptation** on graphs.

    Prepares labeled source (and unlabeled target) graphs, builds an
    annotation model, supports semi-supervised training, optional
    novel-class discovery, and evaluation on source/target/test.
    """

    def _get_data(
        self,
        source_data,
        source_label,
        source_edge_index,
        target_data,
        target_edge_index,
        source_domain=None,
        target_domain=None,
        test_data=None,
        test_edge_index=None,
        weighted_classes=False,
        split_list=None,
        num_workers: int = 1,
        batch_size: int = 1,
        num_parts_source: int = 1,
        num_parts_target: int = 1,
    ):
        """Process labeled source and (optional) unlabeled target into loaders.

        Parameters
        ----------
        source_data : ndarray | torch.Tensor
            Source node features, ``(n_source_nodes, n_features)``.
        source_label : array-like
            Source labels. Non-numeric labels are encoded to integer codes
            and the mapping ``{code: original_label}`` is stored in
            ``self.pairs``; for numeric labels ``self.pairs`` is ``None``.
        source_edge_index : torch.LongTensor
            COO connectivity of the source graph; made undirected.
        target_data : ndarray | torch.Tensor | None
            Target node features, or ``None`` for no target graph.
        target_edge_index : torch.LongTensor | None
            COO connectivity of the target graph; made undirected. Only
            used when ``target_data`` is given.
        source_domain : array-like[int] | None, default None
            Per-node domain labels for source. Defaults to zeros.
        target_domain : array-like[int] | None, default None
            Per-node domain labels for target. Defaults to ones when
            ``target_data`` is given.
        test_data : ndarray | torch.Tensor | None, default None
            Optional test node features.
        test_edge_index : torch.LongTensor | None, default None
            Test connectivity; made undirected like source/target. The test
            graph is only built when both ``test_data`` and
            ``test_edge_index`` are given.
        weighted_classes : bool, default False
            If true, compute inverse-frequency class weights (normalized so
            the smallest weight is 1) from the source labels.
        split_list : list[tuple[int, int]] | None, default None
            Optional feature splits for multi-modal inputs, e.g.
            ``[(0, 1000), (1000, 1500)]``; echoed back in the result.
        num_workers : int, default 1
            Workers for the ``ClusterLoader`` instances.
        batch_size : int, default 1
            Batch size for the ``ClusterLoader`` instances.
        num_parts_source : int, default 1
            ``ClusterData`` partitions for the source graph.
        num_parts_target : int, default 1
            ``ClusterData`` partitions for the target graph.

        Returns
        -------
        dict
            Keys: ``source_graph_ds``, ``source_graph_dl``,
            ``target_graph_ds`` / ``target_graph_dl`` (``None`` without a
            target), ``test_graph_ds`` (``Data`` or ``None``),
            ``class_weight`` (FloatTensor or ``None``), ``n_labels``,
            ``n_inputs``, ``n_domains`` (largest domain label + 1) and
            ``split_list``.

        Notes
        -----
        Sets ``self.pairs``, ``self.n_labels``, ``self.n_inputs`` and
        ``self.n_domains``; prints the number of domains.
        """
        # Encode labels; remember the code -> name mapping when categorical.
        if not pd.api.types.is_numeric_dtype(source_label):
            categorical = pd.Categorical(source_label)
            source_label = np.asarray(categorical.codes, dtype=np.int64)
            self.pairs = dict(enumerate(categorical.categories))
        else:
            source_label = np.asarray(source_label, dtype=np.int64)
            self.pairs = None

        self.n_labels = len(np.unique(source_label))
        self.n_inputs = source_data.shape[1]

        # Default domains: source = 0, target = 1. User-supplied labels are
        # coerced to arrays so plain lists work with ``.max()`` below.
        if source_domain is None:
            source_domain = np.zeros(source_data.shape[0], dtype=np.int64)
        else:
            source_domain = np.asarray(source_domain, dtype=np.int64)
        if target_data is not None:
            if target_domain is None:
                target_domain = np.ones(target_data.shape[0], dtype=np.int64)
            else:
                target_domain = np.asarray(target_domain, dtype=np.int64)

        # Domain count is (largest label + 1). Source labels are honoured
        # even without a target, so custom source domains are not truncated
        # to a single domain (default zeros still give 1).
        domain_max = int(source_domain.max())
        if target_data is not None:
            domain_max = max(domain_max, int(target_domain.max()))
        self.n_domains = domain_max + 1
        print(f"Identified {self.n_domains} unique domains.")

        # Inverse-frequency class weights, scaled so the minimum is 1.
        if weighted_classes:
            _, counts = np.unique(source_label, return_counts=True)
            inv_freq = 1.0 / (counts / counts.sum())
            class_weight = torch.from_numpy(inv_freq / inv_freq.min()).float()
        else:
            class_weight = None

        source_graph = GraphDataset_unpaired(
            data=source_data,
            domain=source_domain,
            edge_index=to_undirected(source_edge_index),
            label=source_label,
        )
        source_clusters = ClusterData(
            source_graph.graph_data,
            num_parts=num_parts_source,
            recursive=False,
        )
        source_loader = ClusterLoader(
            source_clusters,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )

        target_graph_data = None
        target_loader = None
        if target_data is not None:
            target_graph = GraphDataset_unpaired(
                data=target_data,
                domain=target_domain,
                edge_index=to_undirected(target_edge_index),
                label=None,  # target domain is unlabeled
            )
            target_graph_data = target_graph.graph_data
            target_clusters = ClusterData(
                target_graph_data,
                num_parts=num_parts_target,
                recursive=False,
            )
            target_loader = ClusterLoader(
                target_clusters,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
            )

        test_graph = None
        if test_data is not None and test_edge_index is not None:
            # Symmetrize like source/target so the test graph has the same
            # edge convention as the graphs the model is trained on.
            test_graph = Data(
                data=torch.FloatTensor(test_data),
                edge_index=to_undirected(test_edge_index),
            )

        return {
            "source_graph_ds": source_graph.graph_data,
            "source_graph_dl": source_loader,
            "target_graph_ds": target_graph_data,
            "target_graph_dl": target_loader,
            "test_graph_ds": test_graph,
            "class_weight": class_weight,
            "n_labels": self.n_labels,
            "n_inputs": self.n_inputs,
            "n_domains": self.n_domains,
            "split_list": split_list,
        }

    def _get_model(
        self,
        samples,
        n_hiddens: int = 128,
        n_outputs: int = 64,
        opt_GNN: str = "SAGE",
        s: int = 32,
        m: float = 0.10,
        easy_margin: bool = False,
        dropout_rate: float = 0.1,
        use_skip_connections: bool = False,
        use_attention: bool = True,
        n_attention_heads: int = 2,
        use_layer_scale: bool = False,
        layer_scale_init: float = 1e-2,
        use_stochastic_depth: bool = False,
        stochastic_depth_rate: float = 0.1,
        combine_method: str = 'concat',  # 'concat', 'sum', 'attention'
    ):
        """Build the annotation model (classifier/domain-adaptation).

        Parameters
        ----------
        samples : dict
            Output of ``annotate_app._get_data``. ``n_domains``,
            ``n_labels``, ``split_list`` and (single-modality case)
            ``n_inputs`` are read from it.
        n_hiddens : int, default 128
            Hidden dimension.
        n_outputs : int, default 64
            Embedding dimension.
        opt_GNN : str, default 'SAGE'
            GNN backbone identifier consumed by ``annotate_model``.
        s : int, default 32
        m : float, default 0.10
        easy_margin : bool, default False
            Head parameters forwarded to ``annotate_model``.
        dropout_rate, use_skip_connections, use_attention,
        n_attention_heads, use_layer_scale, use_stochastic_depth,
        stochastic_depth_rate, combine_method
            Forwarded verbatim to ``annotate_model``.
        layer_scale_init : float, default 1e-2
            Accepted for signature parity with ``integrate_app._get_model``
            but currently not forwarded (see note in the body).

        Returns
        -------
        models : Any
            Instance returned by ``annotate_model(...)``.

        Notes
        -----
        Stores ``n_outputs``, ``opt_GNN`` and ``n_hiddens`` on ``self`` for
        later use by the training methods.
        """
        split_list = samples["split_list"]
        if split_list is not None:
            # Multi-modal: one input width per (start, end) feature slice.
            n_inputs = [end - start for start, end in split_list]
        else:
            n_inputs = samples["n_inputs"]

        # NOTE(review): ``layer_scale_init`` is not passed on. Whether
        # ``annotate_model`` accepts it is not visible from this file --
        # confirm and forward it if supported.
        models = annotate_model(
            n_inputs=n_inputs,
            n_domains=samples["n_domains"],
            n_labels=samples["n_labels"],
            n_hiddens=n_hiddens,
            n_outputs=n_outputs,
            opt_GNN=opt_GNN,
            s=s,
            m=m,
            easy_margin=easy_margin,
            dropout_rate=dropout_rate,
            use_skip_connections=use_skip_connections,
            use_attention=use_attention,
            n_attention_heads=n_attention_heads,
            use_layer_scale=use_layer_scale,
            use_stochastic_depth=use_stochastic_depth,
            stochastic_depth_rate=stochastic_depth_rate,
            combine_method=combine_method,
            input_split=split_list,
        )
        self.n_outputs = n_outputs
        self.opt_GNN = opt_GNN
        self.n_hiddens = n_hiddens
        return models

    def _train_dirac_annotate(
        self,
        samples,
        models,
        n_epochs: int = 200,
        optimizer_name: str = "adam",
        lr: float = 1e-3,
        wd: float = 5e-3,
        scheduler: bool = True,
        filter_low_confidence: bool = True,
        confidence_threshold: float = 0.5,
    ):
        """Train the annotation model and evaluate source/target/test.

        Parameters
        ----------
        samples : dict
            Output of ``_get_data``. ``samples["n_outputs"]`` is set here
            (the dict is mutated in place).
        models : Any
            Model returned by ``_get_model`` / ``annotate_model``.
        n_epochs : int, default 200
            Number of training epochs.
        optimizer_name : str, default 'adam'
            Optimizer identifier consumed by the trainer.
        lr, wd, scheduler
            Packed into ``unsuper_hyperparams(...)``.
        filter_low_confidence : bool, default True
            If true, predictions whose confidence is below
            ``confidence_threshold`` are replaced by ``"unassigned"`` in
            ``target_pred_filtered`` / ``test_pred_filtered``.
        confidence_threshold : float, default 0.5
            Confidence cut-off.

        Returns
        -------
        dict
            Keys: ``source_feat``, ``target_feat``, ``target_output``,
            ``target_prob``, ``target_pred``, ``target_pred_filtered``,
            ``target_confs``, ``target_mean_uncert``, ``test_feat``,
            ``test_output``, ``test_prob``, ``test_pred``,
            ``test_pred_filtered``, ``test_confs``, ``test_mean_uncert``,
            ``pairs``, ``pairs_filter`` and ``low_confidence_threshold``.
            ``target_*`` / ``test_*`` entries are ``None`` when the
            corresponding graph is absent. ``pairs_filter`` is ``None``
            when filtering is off or when labels were numeric (no
            ``self.pairs`` mapping exists).
        """

        def _filter_predictions_by_confidence(preds, confs):
            """Return 'unassigned' where confidence is below the threshold."""
            return np.where(confs < confidence_threshold, "unassigned", preds)

        samples["n_outputs"] = self.n_outputs
        hyperparams = unsuper_hyperparams(lr=lr, wd=wd, scheduler=scheduler)
        semi_dirac = train_annotate(
            minemodel=models,
            save_path=self.save_path,
            device=self.device,
        )
        semi_dirac._train(
            samples=samples,
            epochs=n_epochs,
            hyperparams=hyperparams,
            optimizer_name=optimizer_name,
        )

        _, source_feat, _, _ = semi_dirac.evaluate_source(
            graph_dl=samples["source_graph_ds"],
            return_lists_roc=True,
        )

        # ``_get_data`` yields ``target_graph_dl=None`` without a target;
        # skip evaluation then instead of handing ``None`` to the trainer.
        target_dl = samples.get("target_graph_dl")
        if target_dl is not None:
            (target_feat, target_output, target_prob, target_pred,
             target_confs, target_mean_uncert) = semi_dirac.evaluate_novel_target(
                graph_dl=target_dl,
                return_lists_roc=True,
            )
            target_pred_filtered = (
                _filter_predictions_by_confidence(target_pred, target_confs)
                if filter_low_confidence else None
            )
        else:
            target_feat = target_output = target_prob = target_pred = None
            target_confs = target_mean_uncert = target_pred_filtered = None

        if samples.get("test_graph_ds") is not None:
            (test_feat, test_output, test_prob, test_pred,
             test_confs, test_mean_uncert) = semi_dirac.evaluate_target(
                graph_dl=samples["test_graph_ds"],
                return_lists_roc=True,
            )
            test_pred_filtered = (
                _filter_predictions_by_confidence(test_pred, test_confs)
                if filter_low_confidence else None
            )
        else:
            test_feat = test_output = test_prob = test_pred = None
            test_confs = test_mean_uncert = test_pred_filtered = None

        # Filtered predictions are strings, so the mapping is keyed by
        # str(code). ``self.pairs`` is None for numeric labels; calling
        # ``.items()`` on it used to raise AttributeError.
        if filter_low_confidence and self.pairs is not None:
            pairs_filter = {str(k): v for k, v in self.pairs.items()}
            pairs_filter["unassigned"] = "unassigned"
        else:
            pairs_filter = None

        results = {
            "source_feat": source_feat,
            "target_feat": target_feat,
            "target_output": target_output,
            "target_prob": target_prob,
            "target_pred": target_pred,
            "target_pred_filtered": target_pred_filtered,
            "target_confs": target_confs,
            "target_mean_uncert": target_mean_uncert,
            "test_feat": test_feat,
            "test_output": test_output,
            "test_prob": test_prob,
            "test_pred": test_pred,
            "test_pred_filtered": test_pred_filtered,
            "test_confs": test_confs,
            "test_mean_uncert": test_mean_uncert,
            "pairs": self.pairs,
            "pairs_filter": pairs_filter,
            "low_confidence_threshold": confidence_threshold if filter_low_confidence else None,
        }
        return results

    def _train_dirac_novel(
        self,
        samples,
        minemodel,
        num_novel_class: int = 3,
        pre_epochs: int = 100,
        n_epochs: int = 200,
        num_parts: int = 30,
        resolution: float = 1,
        s: int = 64,
        m: float = 0.1,
        weights: dict = None,
    ):
        """Discover novel target classes and retrain with expanded label space.

        Parameters
        ----------
        samples : dict
            Output of ``_get_data``; a target graph is required. The dict
            is mutated in place (``n_outputs``, ``opt_GNN``, ``n_hiddens``,
            ``target_graph_dl``, ``n_novel_labels`` and, when present,
            ``class_weight`` are written).
        minemodel : Any
            Initial annotation model (from ``_get_model``).
        num_novel_class : int, default 3
            Number of novel classes to discover in target.
        pre_epochs : int, default 100
            Supervised pretraining epochs on source.
        n_epochs : int, default 200
            Training epochs for the novel phase.
        num_parts : int, default 30
            Partitions for the rebuilt target ``ClusterData``.
        resolution : float, default 1
            Louvain resolution.
        s : int, default 64
        m : float, default 0.1
            Head parameters passed to the rebuilt ``annotate_model``.
        weights : dict | None, default None
            Loss weights consumed by ``_train_novel``. ``None`` means
            ``{"alpha1": 1, ..., "alpha8": 1}``.

        Returns
        -------
        dict
            Keys: ``source_feat``, ``target_feat``, ``target_output``,
            ``target_prob``, ``target_pred``, ``target_confs``,
            ``target_mean_uncert``, ``test_feat``, ``test_pred``
            (``test_*`` are ``None`` without a test graph).

        Raises
        ------
        ValueError
            If ``samples`` carries no target graph.
        """
        if samples.get("target_graph_ds") is None:
            raise ValueError("Novel-class discovery requires a target graph in `samples`.")
        if weights is None:
            # Fresh dict per call; avoids a shared mutable default argument.
            weights = {f"alpha{i}": 1 for i in range(1, 9)}

        samples["n_outputs"] = self.n_outputs
        samples["opt_GNN"] = self.opt_GNN
        samples["n_hiddens"] = self.n_hiddens

        # Cluster the unlabeled target features with Louvain.
        unlabel_x = samples["target_graph_ds"].data
        print("Performing louvain...")
        adata = anndata.AnnData(unlabel_x.numpy())
        if adata.shape[1] > 100:
            sc.tl.pca(adata)
            sc.pp.neighbors(adata)
        else:
            sc.pp.neighbors(adata, use_rep="X")
        sc.tl.louvain(adata, resolution=resolution, key_added='louvain')
        clusters = adata.obs["louvain"].values.astype(int)
        print("Louvain finished")

        # Supervised pretraining on source, then seed estimation on target.
        semi_dirac = train_annotate(
            minemodel=minemodel,
            save_path=self.save_path,
            device=self.device,
        )
        pre_model = semi_dirac._train_supervised(
            samples=samples,
            graph_dl_source=samples["source_graph_dl"],
            epochs=pre_epochs,
            class_weight=samples["class_weight"],
        )
        novel_label, entrs = semi_dirac._est_seeds(
            source_graph=samples["source_graph_ds"],
            target_graph=samples["target_graph_dl"],
            clusters=clusters,
            num_novel_class=num_novel_class,
        )

        now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
        adata.obs["novel_cell_type"] = pd.Categorical(novel_label)
        adata.obs["entrs"] = entrs
        sc.tl.umap(adata)
        # show=False keeps the figure current so that savefig writes the
        # UMAP rather than a fresh empty figure after an implicit show().
        sc.pl.umap(
            adata,
            color=["louvain", "novel_cell_type", "entrs"],
            cmap="CMRmap_r",
            size=20,
            show=False,
        )
        os.makedirs(self.save_path, exist_ok=True)
        plt.savefig(
            os.path.join(self.save_path, f"UMAP_clusters_{now}.pdf"),
            bbox_inches='tight',
            dpi=300,
        )

        # Attach the estimated labels and rebuild the target loader.
        samples["target_graph_ds"].label = torch.tensor(novel_label)
        unlabeled_data = ClusterData(samples["target_graph_ds"], num_parts=num_parts, recursive=False)
        unlabeled_loader = ClusterLoader(unlabeled_data, batch_size=1, shuffle=True, num_workers=1)
        samples["target_graph_dl"] = unlabeled_loader
        samples["n_novel_labels"] = num_novel_class + samples["n_labels"]
        if samples["class_weight"] is not None:
            # Novel classes get weight 1 appended after the known classes.
            samples["class_weight"] = torch.cat(
                [samples["class_weight"], torch.ones(num_novel_class)], dim=0
            )

        # Rebuild the model for the expanded label space. ``s``/``m`` are
        # now actually applied to the new head (they were silently unused).
        # NOTE(review): unlike ``_get_model`` this ignores ``split_list``
        # and the other architecture options -- confirm for multi-modal use.
        minemodel = annotate_model(
            n_inputs=samples["n_inputs"],
            n_domains=samples["n_domains"],
            n_labels=samples["n_novel_labels"],
            n_hiddens=samples["n_hiddens"],
            n_outputs=samples["n_outputs"],
            opt_GNN=samples["opt_GNN"],
            s=s,
            m=m,
        )
        semi_dirac = train_annotate(
            minemodel=minemodel,
            save_path=self.save_path,
            device=self.device,
        )
        hyperparams = unsuper_hyperparams()
        semi_dirac._train_novel(
            pre_model=pre_model,
            samples=samples,
            epochs=n_epochs,
            hyperparams=hyperparams,
            weights=weights,
        )

        _, source_feat, _, _ = semi_dirac.evaluate_source(
            graph_dl=samples["source_graph_ds"],
            return_lists_roc=True,
        )
        (target_feat, target_output, target_prob, target_pred,
         target_confs, target_mean_uncert) = semi_dirac.evaluate_novel_target(
            graph_dl=samples["target_graph_dl"],
            return_lists_roc=True,
        )

        test_feat = None
        test_pred = None
        if samples["test_graph_ds"] is not None:
            test_eval = semi_dirac.evaluate_target(
                graph_dl=samples["test_graph_ds"],
                return_lists_roc=True,
            )
            # NOTE(review): ``_train_dirac_annotate`` unpacks six values
            # (feat, output, prob, pred, confs, uncert) from this call while
            # this method used to unpack three. Accept both so neither
            # arity raises on unpacking -- confirm against the trainer.
            if len(test_eval) == 3:
                test_feat, _, test_pred = test_eval
            else:
                test_feat, _, _, test_pred, *_ = test_eval

        results = {
            "source_feat": source_feat,
            "target_feat": target_feat,
            "target_output": target_output,
            "target_prob": target_prob,
            "target_pred": target_pred,
            "target_confs": target_confs,
            "target_mean_uncert": target_mean_uncert,
            "test_feat": test_feat,
            "test_pred": test_pred,
        }
        return results