# Source code for dance.modules.multi_modality.joint_embedding.scmogcn

"""Official release of scMoGNN method.

Reference
---------
Wen, Hongzhi, et al. "Graph Neural Networks for Multimodal Single-Cell Data Integration." arXiv:2203.01884 (2022).

"""
import os
from copy import deepcopy

import dgl.nn.pytorch as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from torch.utils.data import DataLoader

from dance import logger
from dance.utils import SimpleIndexDataset
from dance.utils.metrics import *


def propagation_layer_combination(X, idx, wt, from_logits=True):
    """Blend per-layer features of the selected rows into one weighted sum.

    Parameters
    ----------
    X
        Indexable collection of per-layer feature tensors; ``X[k][idx]`` selects the rows of layer ``k``.
    idx
        Row selector applied to every layer.
    wt
        One coefficient per layer along its first dimension.
    from_logits
        When true, ``wt`` is passed through a softmax over its last dimension before being used.

    Returns
    -------
    The accumulated weighted sum, or the integer ``0`` when ``wt`` has no entries.

    """
    coefficients = torch.softmax(wt, -1) if from_logits else wt

    combined = 0
    for layer in range(coefficients.shape[0]):
        # Accumulate in place, layer by layer.
        combined += coefficients[layer] * X[layer][idx]
    return combined


def _standardize(x):
    """Shift ``x`` to zero mean and scale it to unit standard deviation over all of its entries.

    A tensor whose standard deviation is exactly zero (e.g. all zeros or any other constant) is only mean-shifted,
    so the division can never produce NaN/inf from a zero denominator.

    """
    std = x.std()
    return (x - x.mean()) / (std if std != 0 else 1)


def cell_feature_propagation(g, alpha: float = 0.5, beta: float = 0.5, cell_init: str = None, feature_init: str = 'id',
                             device: str = 'cuda', layers: int = 3):
    """Propagate representations back and forth over a cell-feature graph without trainable weights.

    Each layer runs one weighted sum aggregation over the ``cell2feature`` and ``rev_cell2feature`` edge types,
    standardizes the aggregated result, mixes it with the previous representation and standardizes again.

    Parameters
    ----------
    g
        Heterogeneous graph with node types ``'cell'`` and ``'feature'`` and edge types ``'cell2feature'`` and
        ``'rev_cell2feature'``, both carrying a ``'weight'`` edge feature.
    alpha
        Fraction of the previous feature representation kept at each layer; the aggregated one gets ``1 - alpha``.
    beta
        Fraction of the previous cell representation kept at each layer; the aggregated one gets ``1 - beta``.
    cell_init
        Name of the node data field used as the initial cell representation. ``None`` starts the cells from zeros.
    feature_init
        ``'id'`` initializes features with a one-hot encoding of their ``'id'`` node data; ``None`` starts them
        from zeros (this requires ``cell_init`` to determine the width).
    device
        Device on which the graph and all representations are placed.
    layers
        Number of propagation layers to run.

    Returns
    -------
    list of torch.Tensor
        Cell representations of every layer except the first one, i.e. ``layers - 1`` tensors.

    Raises
    ------
    ValueError
        If both ``cell_init`` and ``feature_init`` are ``None``, since neither side then defines a width.
    NotImplementedError
        If ``feature_init`` is neither ``None`` nor ``'id'``.

    """
    if cell_init is None and feature_init is None:
        raise ValueError('At least one of cell_init and feature_init must be specified.')

    g = g.to(device)
    # Weight-free, bias-free and un-normalized graph convolutions: a plain edge-weighted sum of neighbors.
    gconv = dglnn.HeteroGraphConv(
        {
            'cell2feature': dglnn.GraphConv(in_feats=0, out_feats=0, norm='none', weight=False, bias=False),
            'rev_cell2feature': dglnn.GraphConv(in_feats=0, out_feats=0, norm='none', weight=False, bias=False),
        }, aggregate='sum')

    if feature_init is None:
        feature_X = torch.zeros((g.nodes('feature').shape[0], g.srcdata[cell_init]['cell'].shape[1])).float().to(device)
    elif feature_init == 'id':
        feature_X = F.one_hot(g.srcdata['id']['feature']).float().to(device)
    else:
        raise NotImplementedError(f'Not implemented feature init feature {feature_init}.')

    if cell_init is None:
        cell_X = torch.zeros(g.nodes('cell').shape[0], feature_X.shape[1]).float().to(device)
    else:
        cell_X = g.srcdata[cell_init]['cell'].float().to(device)

    # The edge weights do not change between layers, so build the keyword arguments once.
    mod_kwargs = {
        'cell2feature': {
            'edge_weight': g.edges['cell2feature'].data['weight'].float()
        },
        'rev_cell2feature': {
            'edge_weight': g.edges['rev_cell2feature'].data['weight'].float()
        }
    }

    h = {'feature': feature_X, 'cell': cell_X}
    hcell = []
    for i in range(layers):
        h1 = gconv(g, h, mod_kwargs=mod_kwargs)
        logger.debug(f"{i} cell {h['cell'].abs().mean()} {h1['cell'].abs().mean()}")
        logger.debug(f"{i} feature {h['feature'].abs().mean()} {h1['feature'].abs().mean()}")

        # Standardize the aggregated messages before mixing. The guard is on the standard deviation itself,
        # so an all-zero side (e.g. messages coming from a zero-initialized side) stays at zero.
        h1['feature'] = _standardize(h1['feature'])
        h1['cell'] = _standardize(h1['cell'])

        h = {
            'feature': _standardize(h['feature'] * alpha + h1['feature'] * (1 - alpha)),
            'cell': _standardize(h['cell'] * beta + h1['cell'] * (1 - beta)),
        }

        hcell.append(h['cell'])

    logger.debug(f"{hcell[-1].abs().mean()=}")

    # The first layer's output is dropped; callers receive ``layers - 1`` tensors.
    return hcell[1:]


class ScMoGCNWrapper:
    """Training and inference wrapper around the ``ScMoGCN`` network.

    Parameters
    ----------
    args : argparse.Namespace
        Run configuration. This class reads ``args.device``, ``args.layers``, ``args.batch_size`` and ``args.seed``.
    num_celL_types : int
        Number of leading embedding dimensions that the network exposes as cell type logits.
    num_batches : int
        Number of embedding dimensions, following the cell type logits, that are L2-regularized during training.
    num_phases : int
        Number of embedding dimensions, following the batch dimensions, that are regressed onto the phase scores.
    num_features : int
        Input (and reconstruction) dimension of the network.

    """

    def __init__(self, args, num_celL_types, num_batches, num_phases, num_features):
        self.model = ScMoGCN(num_celL_types, num_batches, num_phases, num_features).to(args.device)
        self.args = args
        # One learnable logit per propagation layer that is kept (the first layer's output is discarded).
        self.wt = torch.tensor([0.] * (args.layers - 1)).to(args.device).requires_grad_(True)
        # Set up front so that ``predict`` can raise its RuntimeError instead of an AttributeError.
        self.fitted = False

    def fit(self, g_mod1, g_mod2, train_size, cell_type, batch_label, phase_score):
        """Fit function for training.

        Parameters
        ----------
        g_mod1 : dgl.DGLGraph
            Cell-feature graph of modality 1, passed to ``cell_feature_propagation``.
        g_mod2 : dgl.DGLGraph
            Cell-feature graph of modality 2, passed to ``cell_feature_propagation``.
        train_size : int or array_like
            Passed to ``np.random.permutation``; the first 90% of the permutation is used for training and the
            remainder for validation.
        cell_type : torch.Tensor
            Cell type labels, indexed by the sampled cell indices; target of the cross entropy loss.
        batch_label : torch.Tensor
            Accepted for interface compatibility. No loss term reads it; the batch dimensions of the embedding
            are only L2-regularized.
        phase_score : torch.Tensor
            Phase scores, indexed by the sampled cell indices; target of the phase regression loss.

        Returns
        -------
        None.

        """
        device = self.args.device

        # Train a fresh leaf copy so that calling ``fit`` again still optimizes the layer weights (a previous
        # fit leaves ``self.wt`` detached). ``score`` reads ``self.wt`` during validation, so rebind it.
        wt = self.wt.detach().clone().requires_grad_(True)
        self.wt = wt

        self.feat_mod1 = cell_feature_propagation(g_mod1, layers=self.args.layers, device=device)
        self.feat_mod2 = cell_feature_propagation(g_mod2, layers=self.args.layers, device=device)

        # One input matrix per kept propagation layer: both modalities concatenated along the feature axis.
        X = [torch.cat([mod1, mod2], dim=1).float().to(device) for mod1, mod2 in zip(self.feat_mod1, self.feat_mod2)]
        self.X = X
        cell_type = cell_type.to(device)
        phase_score = phase_score.float().to(device)

        idx = np.random.permutation(train_size)
        split = int(idx.shape[0] * 0.9)
        train_idx, val_idx = idx[:split], idx[split:]

        # Make sure the batch size is small enough to cover all splits, but never zero.
        batch_size = max(1, min(self.args.batch_size, len(val_idx)))

        train_loader = DataLoader(
            dataset=SimpleIndexDataset(train_idx),
            batch_size=batch_size,
            shuffle=True,
            num_workers=1,
        )

        ce = nn.CrossEntropyLoss()
        mse = nn.MSELoss()
        optimizer = torch.optim.AdamW([{'params': self.model.parameters()}, {'params': wt, 'weight_decay': 0}], lr=1e-4)

        # Start from the initial state so that a best snapshot always exists, even if every validation
        # loss turns out to be NaN.
        best_wt = wt.detach().clone()
        best_dict = deepcopy(self.model.state_dict())
        vals = []

        for epoch in range(60):
            self.model.train()
            total_loss = [0] * 5
            print('epoch', epoch)

            for batch_idx in train_loader:
                batch_x = propagation_layer_combination(X, batch_idx, wt)
                output = self.model(batch_x)

                loss1 = mse(output[0], batch_x)  # reconstruct the combined propagated features
                loss2 = ce(output[1], cell_type[batch_idx])
                loss3 = torch.norm(output[2], p=2, dim=-1).sum() * 1e-2
                loss4 = mse(output[3], phase_score[batch_idx])

                loss = loss1 * 0.7 + loss2 * 0.2 + loss3 * 0.05 + loss4 * 0.05

                total_loss[0] += loss1.item()
                total_loss[1] += loss2.item()
                total_loss[2] += loss3.item()
                total_loss[3] += loss4.item()
                total_loss[4] += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            for i in range(4):
                print(f'loss{i + 1}', total_loss[i] / len(train_loader), end=', ')
            print()

            loss1, loss2, loss3, loss4 = self.score(val_idx, cell_type, phase_score)
            weighted_loss = loss1 * 0.7 + loss2 * 0.2 + loss3 * 0.05 + loss4 * 0.05
            print('val-loss1', loss1, 'val-loss2', loss2, 'val-loss3', loss3, 'val-loss4', loss4)
            print('val score', weighted_loss)
            vals.append(weighted_loss)

            if min(vals) == vals[-1]:
                os.makedirs('models', exist_ok=True)
                torch.save(self.model.state_dict(), f'models/model_joint_embedding_{self.args.seed}.pth')
                # ``detach()`` alone shares storage with ``wt``, which the optimizer keeps updating in place;
                # clone so the snapshot really holds the weights of the best epoch.
                best_wt = wt.detach().clone()
                best_dict = deepcopy(self.model.state_dict())

            # Stop once the best validation score is older than the last 10 epochs.
            if min(vals) != min(vals[-10:]):
                print('Early stopped.')
                break

        self.wt = best_wt
        self.fitted = True
        self.model.load_state_dict(best_dict)

    def to(self, device):
        """Performs device conversion.

        Moves the network, the layer weights and, if they have been computed by ``fit``, the cached propagated
        features to ``device``.

        Parameters
        ----------
        device : str
            Target device.

        Returns
        -------
        self : ScMoGCNWrapper
            Converted model.

        """
        self.args.device = device
        self.model = self.model.to(device)

        requires_grad = self.wt.requires_grad
        self.wt = self.wt.detach().to(device).requires_grad_(requires_grad)

        # These attributes are lists of tensors (one per kept layer) and only exist after ``fit``.
        for name in ('feat_mod1', 'feat_mod2', 'X'):
            tensors = getattr(self, name, None)
            if tensors is not None:
                setattr(self, name, [tensor.to(device) for tensor in tensors])
        return self

    def load(self, path, map_location=None):
        """Load model parameters from checkpoint file.

        Only the network parameters are restored. The layer weights ``self.wt`` keep their current value and
        the propagated features ``self.X`` that ``predict`` and ``score`` read are not set by this method.

        Parameters
        ----------
        path : str
            Path to the checkpoint file.
        map_location : str optional
            Mapped device, forwarded to ``torch.load``.

        Returns
        -------
        None.

        """
        self.model.load_state_dict(torch.load(path, map_location=map_location))
        # Only flag the model as usable once the parameters were actually loaded.
        self.fitted = True

    def predict(self, idx):
        """Predict function to get latent representation of data.

        Parameters
        ----------
        idx : Iterable[int]
            Index of testing samples for prediction.

        Returns
        -------
        prediction : torch.Tensor
            Joint embedding of input data, i.e. the encoder output for the selected samples.

        Raises
        ------
        RuntimeError
            If the model has neither been fitted nor loaded.

        """
        if not self.fitted:
            raise RuntimeError('Model is not fitted yet.')
        self.model.eval()

        with torch.no_grad():
            X = propagation_layer_combination(self.X, idx, self.wt)
            return self.model.encoder(X)

    def score(self, idx, cell_type, phase_score=None, adata_sol=None, metric='loss'):
        """Score function to get score of prediction.

        Parameters
        ----------
        idx : Iterable[int]
            Index of testing samples for scoring.
        cell_type : torch.Tensor
            Cell type labels. For ``metric='loss'`` they are indexed with ``idx``; for ``metric='clustering'``
            they are compared as given against the predicted clusters.
        phase_score : torch.Tensor optional
            Cell cycle scores, indexed with ``idx``. Required when ``metric='loss'``.
        adata_sol : optional
            Solution object that receives the embedding in ``obsm['X_emb']``. Required when
            ``metric='openproblems'``.
        metric : str optional
            The type of evaluation metric, one of ``'loss'``, ``'clustering'`` and ``'openproblems'``,
            by default ``'loss'``.

        Returns
        -------
        tuple of float
            For ``metric='loss'``: reconstruction loss, cell type classification loss, batch regularization
            loss and cell cycle score loss.
        dict
            For ``metric='clustering'``: ``{'dance_nmi': ..., 'dance_ari': ...}``, both rounded to 3 digits.
        object
            For ``metric='openproblems'``: whatever ``integration_openproblems_evaluate`` returns.

        Raises
        ------
        ValueError
            If an input required by the selected metric is missing.
        NotImplementedError
            If ``metric`` is not one of the supported values.

        """
        self.model.eval()

        if metric == 'loss':
            if phase_score is None:
                raise ValueError('phase_score is required by `loss` evaluation but not provided.')
            ce = nn.CrossEntropyLoss()
            mse = nn.MSELoss()
            with torch.no_grad():
                X = propagation_layer_combination(self.X, idx, self.wt)
                output = self.model(X)
                loss1 = mse(output[0], X).item()
                loss2 = ce(output[1], cell_type[idx]).item()
                loss3 = (torch.norm(output[2], p=2, dim=-1).sum() * 1e-2).item()
                loss4 = mse(output[3], phase_score[idx]).item()
            return loss1, loss2, loss3, loss4

        if metric == 'clustering':
            emb = self.predict(idx).cpu().numpy()
            kmeans = KMeans(n_clusters=10, n_init=5, random_state=200)
            true_labels = cell_type
            pred_labels = kmeans.fit_predict(emb)
            NMI_score = round(normalized_mutual_info_score(true_labels, pred_labels, average_method='max'), 3)
            ARI_score = round(adjusted_rand_score(true_labels, pred_labels), 3)
            return {'dance_nmi': NMI_score, 'dance_ari': ARI_score}

        if metric == 'openproblems':
            # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
            if adata_sol is None:
                raise ValueError('adata_sol is required by `openproblems` evaluation but not provided.')
            emb = self.predict(idx).cpu().numpy()
            adata_sol.obsm['X_emb'] = emb
            return integration_openproblems_evaluate(adata_sol)

        raise NotImplementedError(f'Unknown metric {metric!r}.')
class ScMoGCN(nn.Module):
    """Autoencoder whose 61-dimensional bottleneck is sliced into prediction heads.

    Parameters
    ----------
    nb_cell_types
        Number of leading bottleneck dimensions returned as the second output.
    nb_batches
        Number of bottleneck dimensions, right after the cell type ones, returned as the third output.
    nb_phases
        Number of bottleneck dimensions, right after the batch ones, returned as the fourth output.
    input_dimension
        Width of the input and of the reconstruction.

    """

    def __init__(self, nb_cell_types, nb_batches, nb_phases, input_dimension):
        super().__init__()

        self.nb_cell_types = nb_cell_types
        self.nb_batches = nb_batches
        self.nb_phases = nb_phases

        # Encoder: input_dimension -> 150 -> 120 -> 100 -> 61.
        self.linear1 = nn.Linear(input_dimension, 150)
        self.linear2 = nn.Linear(150, 120)
        self.linear3 = nn.Linear(120, 100)
        self.linear4 = nn.Linear(100, 61)

        self.bn1 = nn.BatchNorm1d(150)
        self.bn2 = nn.BatchNorm1d(120)
        self.bn3 = nn.BatchNorm1d(100)

        self.act1 = nn.GELU()
        self.act2 = nn.GELU()
        self.act3 = nn.GELU()

        # Decoder: 61 -> 150 -> input_dimension, with a ReLU after each linear layer.
        self.decoder = nn.Sequential(
            nn.Linear(61, 150),
            nn.ReLU(),
            nn.Linear(150, input_dimension),
            nn.ReLU(),
        )

    def encoder(self, x):
        """Map the input to the 61-dimensional bottleneck embedding."""
        hidden_stages = (
            (self.linear1, self.act1, self.bn1),
            (self.linear2, self.act2, self.bn2),
            (self.linear3, self.act3, self.bn3),
        )
        # Each hidden stage is linear -> GELU -> batch norm -> dropout (active only in training mode).
        for linear, activation, norm in hidden_stages:
            x = norm(activation(linear(x)))
            x = F.dropout(x, p=0.3, training=self.training)
        return self.linear4(x)

    def forward(self, x):
        """Return the reconstruction followed by the cell type, batch and phase slices of the embedding."""
        embedding = self.encoder(x)

        cell_type_end = self.nb_cell_types
        batch_end = cell_type_end + self.nb_batches
        phase_end = batch_end + self.nb_phases

        return (
            self.decoder(embedding),
            embedding[:, :cell_type_end],
            embedding[:, cell_type_end:batch_end],
            embedding[:, batch_end:phase_end],
        )