# Source code for dance.modules.multi_modality.joint_embedding.jae

"""Reimplementation of the JAE model, which is adapted from scDEC.

Extended from https://github.com/kimmo1019/JAE

Reference
---------
Liu Q, Chen S, Jiang R, et al. Simultaneous deep generative modelling and clustering of single-cell genomic data[J].
Nature Machine Intelligence, 2021, 3(6): 536-544.

"""

import os
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dance.utils import SimpleIndexDataset
from dance.utils.metrics import *


def random_classification_loss(y_pred, nb_batches):
    """Cross-entropy between softmax(y_pred) and the uniform distribution over its classes.

    Minimising this value pushes the predicted distribution over the last dimension of ``y_pred``
    toward uniform. In this module it is applied to the batch-logit slice of the JAE output, so the
    embedding is discouraged from encoding batch identity.

    Parameters
    ----------
    y_pred : torch.Tensor
        Logits of shape ``(n_samples, n_classes)``; the softmax is taken over the last dimension.
    nb_batches : torch.Tensor
        Batch labels of the samples. Kept for backward compatibility only. Its length used to be
        (mistakenly) taken as the number of classes, which made the loss scale with the mini-batch size.

    Returns
    -------
    torch.Tensor
        Scalar loss, averaged over samples.

    """
    n_classes = y_pred.shape[-1]
    if n_classes == 0:
        # Nothing to regularise; a zero-width head contributes no loss.
        return y_pred.new_zeros(())
    # The 1e-7 epsilon keeps log() finite when a probability underflows to zero.
    log_probs = torch.log(torch.softmax(y_pred, -1) + 1e-7)
    # A uniform target puts probability 1 / n_classes on each class.
    return -(log_probs.sum(-1) / n_classes).mean()


class JAEWrapper:
    """Training / inference wrapper around the ``JAE`` network.

    Parameters
    ----------
    args : argparse.Namespace
        Namespace of JAE options. Attributes read by this class: ``device`` (target device) and
        ``batch_size`` (mini-batch size used by :meth:`fit`).
    num_celL_types : int
        Number of cell types (width of the cell-type slice of the JAE latent output).
    num_batches : int
        Number of experimental batches (width of the batch slice of the JAE latent output).
    num_phases : int
        Width of the cell-cycle phase-score slice of the JAE latent output.
    num_features : int
        Dimension of the input modality features.

    """

    # Weights of the (reconstruction, cell type, batch, phase) losses; shared by training and
    # validation so the two objectives cannot drift apart.
    _LOSS_WEIGHTS = (0.7, 0.2, 0.05, 0.05)

    def __init__(self, args, num_celL_types, num_batches, num_phases, num_features):
        # NOTE: the ``num_celL_types`` spelling is kept so existing keyword callers keep working.
        self.model = JAE(num_celL_types, num_batches, num_phases, num_features).to(args.device)
        print(num_celL_types, num_batches, num_phases, num_features)
        self.args = args

    @classmethod
    def _weighted_loss(cls, loss1, loss2, loss3, loss4):
        """Combine the four loss terms (tensors or floats) with ``_LOSS_WEIGHTS``."""
        w1, w2, w3, w4 = cls._LOSS_WEIGHTS
        return loss1 * w1 + loss2 * w2 + loss3 * w3 + loss4 * w4

    def fit(self, inputs, cell_type, batch_label, phase_score, max_epochs=60, lr=1e-4, patience=10):
        """Fit function for training.

        A random 90% of the cells is used for training and the remaining 10% for validation.
        After training, the weights with the lowest weighted validation loss are restored.

        Parameters
        ----------
        inputs : torch.Tensor
            Modality features.
        cell_type : torch.Tensor
            Cell type labels.
        batch_label : torch.Tensor
            Batch labels.
        phase_score : torch.Tensor
            Phase scores.
        max_epochs : int optional
            Maximum number of training epochs, by default 60.
        lr : float optional
            Adam learning rate, by default 1e-4.
        patience : int optional
            Stop once the best validation score is older than the last ``patience`` epochs,
            by default 10. Should be >= 1.

        Returns
        -------
        None.

        """
        device = self.args.device
        X = inputs.float().to(device)
        Y = [cell_type.to(device), batch_label.to(device), phase_score.float().to(device)]
        idx = np.random.permutation(X.shape[0])
        n_train = int(idx.shape[0] * 0.9)
        train_idx = idx[:n_train]
        val_idx = idx[n_train:]

        train_loader = DataLoader(
            dataset=SimpleIndexDataset(train_idx),
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=1,
        )
        n_steps = max(len(train_loader), 1)  # avoid division by zero when there is no training data

        ce = nn.CrossEntropyLoss()
        mse = nn.MSELoss()
        optimizer = torch.optim.Adam([{'params': self.model.parameters()}], lr=lr)
        vals = []
        best_dict = None

        for epoch in range(max_epochs):
            self.model.train()
            total_loss = [0] * 5
            print('epoch', epoch)

            for batch_idx in train_loader:
                batch_x = X[batch_idx]
                output = self.model(batch_x)

                # Reconstruct the input features (reconstructing propagated features was an
                # earlier, abandoned option).
                loss1 = mse(output[0], batch_x)
                loss2 = ce(output[1], Y[0][batch_idx])
                loss3 = random_classification_loss(output[2], Y[1][batch_idx])
                loss4 = mse(output[3], Y[2][batch_idx])
                loss = self._weighted_loss(loss1, loss2, loss3, loss4)

                for i, term in enumerate((loss1, loss2, loss3, loss4, loss)):
                    total_loss[i] += term.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            for i in range(4):
                print(f'loss{i + 1}', total_loss[i] / n_steps, end=', ')
            print()

            val_losses = self.score(X, val_idx, Y[0], Y[1], Y[2])
            loss1, loss2, loss3, loss4 = val_losses
            weighted_loss = self._weighted_loss(*val_losses)
            print('val-loss1', loss1, 'val-loss2', loss2, 'val-loss3', loss3, 'val-loss4', loss4)
            print('val score', weighted_loss)
            vals.append(weighted_loss)
            if weighted_loss == min(vals):
                # Keep an in-memory snapshot of the best weights seen so far.
                best_dict = deepcopy(self.model.state_dict())

            if min(vals) != min(vals[-patience:]):
                print('Early stopped.')
                break

        if best_dict is not None:
            self.model.load_state_dict(best_dict)

    def to(self, device):
        """Performs device conversion.

        Parameters
        ----------
        device : str
            Target device.

        Returns
        -------
        self : JAEWrapper
            Converted model.

        """
        self.args.device = device
        self.model = self.model.to(device)
        return self

    def load(self, path, map_location=None):
        """Load model parameters from checkpoint file.

        Parameters
        ----------
        path : str
            Path to the checkpoint file. ``torch.load`` may unpickle arbitrary objects, so only
            load checkpoints from trusted sources.
        map_location : str optional
            Mapped device. This parameter will be passed to torch.load function if not none.

        Returns
        -------
        None.

        """
        if map_location is None:
            state_dict = torch.load(path)
        else:
            state_dict = torch.load(path, map_location=map_location)
        self.model.load_state_dict(state_dict)

    def predict(self, inputs, idx):
        """Predict function to get latent representation of data.

        Parameters
        ----------
        inputs : torch.Tensor
            Multimodality features.
        idx : Iterable(int)
            Index of cells to predict.

        Returns
        -------
        prediction : torch.Tensor
            Joint embedding of input data.

        """
        self.model.eval()
        with torch.no_grad():
            prediction = self.model.encoder(inputs[idx])
        return prediction

    def score(self, inputs, idx, cell_type, batch_label=None, phase_score=None, adata_sol=None, metric='loss'):
        """Score function to get score of prediction.

        Parameters
        ----------
        inputs : torch.Tensor
            Multimodality features.
        idx : Iterable[int]
            Index of testing cells for scoring.
        cell_type : torch.Tensor
            Cell type labels. For ``'loss'`` it is indexed with ``idx``. For ``'clustering'`` it is
            used as-is, so it must already be aligned with ``inputs[idx]``.
        batch_label : torch.Tensor optional
            Batch labels. Required by ``'loss'``.
        phase_score : torch.Tensor optional
            Cell cycle phase scores. Required by ``'loss'``.
        adata_sol : anndata.AnnData optional
            The solution anndata containing cell types, phase scores and batches. Required by
            ``'openproblems'`` evaluation.
        metric : str optional
            One of ``'loss'``, ``'clustering'`` or ``'openproblems'``, by default ``'loss'``.

        Returns
        -------
        tuple of float or dict
            For ``'loss'``: (reconstruction loss, cell type classification loss, batch
            regularization loss, cell cycle score loss). For ``'clustering'``: a dict with
            ``'dance_nmi'`` and ``'dance_ari'``. For ``'openproblems'``: the result of
            ``integration_openproblems_evaluate``.

        Raises
        ------
        ValueError
            If an input required by the chosen metric is missing.
        NotImplementedError
            If ``metric`` is not supported.

        """
        self.model.eval()

        with torch.no_grad():
            if metric == 'loss':
                if batch_label is None or phase_score is None:
                    raise ValueError('batch_label and phase_score are required by `loss` evaluation.')
                ce = nn.CrossEntropyLoss()
                mse = nn.MSELoss()
                X = inputs[idx]
                output = self.model(X)
                loss1 = mse(output[0], X).item()
                loss2 = ce(output[1], cell_type[idx]).item()
                loss3 = random_classification_loss(output[2], batch_label[idx]).item()
                loss4 = mse(output[3], phase_score[idx]).item()
                return loss1, loss2, loss3, loss4

            if metric == 'clustering':
                emb = self.predict(inputs, idx).cpu().numpy()
                kmeans = KMeans(n_clusters=10, n_init=5, random_state=200)
                true_labels = cell_type
                pred_labels = kmeans.fit_predict(emb)
                print(true_labels, pred_labels)
                NMI_score = round(normalized_mutual_info_score(true_labels, pred_labels, average_method='max'), 3)
                ARI_score = round(adjusted_rand_score(true_labels, pred_labels), 3)
                return {'dance_nmi': NMI_score, 'dance_ari': ARI_score}

            if metric == 'openproblems':
                # Validate before doing the (potentially expensive) forward pass.
                if adata_sol is None:
                    raise ValueError('adata_sol is required by `openproblems` evaluation but not provided.')
                emb = self.predict(inputs, idx).cpu().numpy()
                adata_sol.obsm['X_emb'] = emb
                return integration_openproblems_evaluate(adata_sol)

            raise NotImplementedError


class JAE(nn.Module):
    """JAE network: an MLP encoder whose latent vector doubles as prediction heads, plus an MLP decoder.

    ``forward`` slices the latent vector into consecutive heads: the first ``nb_cell_types`` units
    (cell-type logits), the next ``nb_batches`` units (batch logits) and the following ``nb_phases``
    units (phase scores). Any remaining units are only used by the decoder, which reconstructs the
    input from the full latent vector.

    Parameters
    ----------
    nb_cell_types : int
        Width of the cell-type head.
    nb_batches : int
        Width of the batch head.
    nb_phases : int
        Width of the phase-score head.
    input_dimension : int
        Number of input features; also the decoder's output size.
    latent_dim : int optional
        Width of the latent embedding, by default 61. Must be at least
        ``nb_cell_types + nb_batches + nb_phases`` so that no head is truncated.

    Raises
    ------
    ValueError
        If the heads do not fit into the latent embedding.

    """

    def __init__(self, nb_cell_types, nb_batches, nb_phases, input_dimension, latent_dim=61):
        super().__init__()
        n_head_units = nb_cell_types + nb_batches + nb_phases
        if n_head_units > latent_dim:
            raise ValueError(f'nb_cell_types + nb_batches + nb_phases ({n_head_units}) exceeds '
                             f'latent_dim ({latent_dim}).')
        self.nb_cell_types = nb_cell_types
        self.nb_batches = nb_batches
        self.nb_phases = nb_phases
        self.latent_dim = latent_dim

        # Submodule names form the state_dict keys of saved checkpoints and the creation order fixes
        # the parameter-initialisation RNG sequence, so both are kept exactly as before.
        self.linear1 = nn.Linear(input_dimension, 150)
        self.linear2 = nn.Linear(150, 120)
        self.linear3 = nn.Linear(120, 100)
        self.linear4 = nn.Linear(100, latent_dim)
        self.bn1 = nn.BatchNorm1d(150)
        self.bn2 = nn.BatchNorm1d(120)
        self.bn3 = nn.BatchNorm1d(100)
        self.act1 = nn.GELU()
        self.act2 = nn.GELU()
        self.act3 = nn.GELU()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 150),
            nn.ReLU(),
            nn.Linear(150, input_dimension),
            nn.ReLU(),  # the reconstruction is therefore non-negative
        )

    def encoder(self, x):
        """Map input features ``(n_cells, input_dimension)`` to the latent embedding ``(n_cells, latent_dim)``."""
        # Each hidden stage is Linear -> GELU -> BatchNorm -> Dropout(0.2); dropout is active only in training mode.
        x = self.linear1(x)
        x = self.act1(x)
        x = self.bn1(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.linear2(x)
        x = self.act2(x)
        x = self.bn2(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.linear3(x)
        x = self.act3(x)
        x = self.bn3(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.linear4(x)
        return x

    def forward(self, x):
        """Encode ``x`` and return ``(reconstruction, cell_type_logits, batch_logits, phase_scores)``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        cell_type_end = self.nb_cell_types
        batch_end = cell_type_end + self.nb_batches
        phase_end = batch_end + self.nb_phases
        return (
            reconstruction,
            latent[:, :cell_type_end],
            latent[:, cell_type_end:batch_end],
            latent[:, batch_end:phase_end],
        )