Source code for dance.modules.multi_modality.match_modality.scmm

"""Reimplementation of scMM method.

Extended from https://github.com/kodaim1115/scMM

Reference
---------
Minoura, Kodai, et al. A mixture-of-experts deep generative model for integrated analysis of single-cell multiomics
data. Cell reports methods 1.5 (2021): 100071.

"""
import math
import os
from copy import deepcopy

import numpy as np
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
from numpy import prod
from pyro.distributions.zero_inflated import ZeroInflatedNegativeBinomial
from sklearn.cluster import DBSCAN, KMeans
from sklearn.neighbors import NearestNeighbors
from torch import optim
from torch.utils.data import DataLoader

from dance.utils import SimpleIndexDataset


def get_mean(d, K=100):
    """Return the mean of the distribution ``d``.

    The analytic ``d.mean`` is used whenever the distribution implements it. Otherwise
    (``NotImplementedError``) the mean is estimated by averaging ``K`` reparameterised
    samples along the new leading sample dimension.

    """
    try:
        analytic = d.mean
    except NotImplementedError:
        draws = d.rsample(torch.Size((K, )))
        return draws.mean(0)
    return analytic


def kl_divergence(d1, d2, K=100):
    """Compute KL(d1 || d2): closed form when available, else a Monte Carlo estimate.

    The closed form is obtained through :func:`torch.distributions.kl_divergence`, which
    dispatches on the registered ``(type, type)`` pairs *including subclasses* of the
    registered types. An exact-type lookup in the private registry missed those and fell
    back to a noisy estimate. If no closed form is registered, the KL is estimated from
    ``K`` reparameterised samples of ``d1``.

    Parameters
    ----------
    d1 : torch.distributions.Distribution
        Distribution the expectation is taken under.
    d2 : torch.distributions.Distribution
        Reference distribution.
    K : int
        Number of samples used by the Monte Carlo fallback.

    Returns
    -------
    torch.Tensor
        KL divergence, with the batch shape of the inputs.

    """
    try:
        return torch.distributions.kl_divergence(d1, d2)
    except NotImplementedError:
        # No registered closed form for this pair of distribution types.
        samples = d1.rsample(torch.Size([K]))
        return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)


def m_elbo_naive(model, x):
    """Multi-modal ELBO of ``model`` on the batch ``x`` --- NOT EXPOSED.

    ``model(x)`` yields one posterior per encoder and a cross-modal matrix of likelihoods
    ``px_zs[r][d]`` (decoder ``d`` applied to encoder ``r``'s latents). For every encoder the KL to
    the shared prior is accumulated, together with the ``llik_scaling``-weighted log-likelihoods of
    all modalities. The difference of the two totals is scaled by ``1 / len(model.vaes)`` and summed
    over the batch.

    Returns
    -------
    torch.Tensor
        Scalar objective to be maximised (training minimises its negation).

    """
    qz_xs, px_zs, _ = model(x)
    prior = model.pz(*model._get_pz_params)
    kl_terms = []
    recon_terms = []
    for enc_idx, posterior in enumerate(qz_xs):
        kl_terms.append(kl_divergence(posterior, prior).sum(-1))
        for dec_idx, likelihood in enumerate(px_zs[enc_idx]):
            weighted = likelihood.log_prob(x[dec_idx]) * model.vaes[dec_idx].llik_scaling
            recon_terms.append(weighted.sum(-1))
    total_recon = torch.stack(recon_terms).sum(0)
    total_kl = torch.stack(kl_terms).sum(0)
    per_sample = (1 / len(model.vaes)) * (total_recon - total_kl)
    return per_sample.sum()


def m_elbo_naive_warmup(model, x, beta):
    """Multi-modal ELBO with a KL weight ``beta`` (KL warm-up) --- NOT EXPOSED.

    Same objective as the naive multi-modal ELBO, except that the summed KL term is multiplied by
    ``beta`` before being subtracted from the summed ``llik_scaling``-weighted log-likelihoods. The
    result is scaled by ``1 / len(model.vaes)`` and summed over the batch.

    Parameters
    ----------
    model : MMVAE
        Multi-modal model; ``model(x)`` returns ``(qz_xs, px_zs, zss)``.
    x : sequence of torch.Tensor
        One tensor per modality.
    beta : float
        Weight applied to the KL term.

    Returns
    -------
    torch.Tensor
        Scalar objective to be maximised (training minimises its negation).

    """
    qz_xs, px_zs, _ = model(x)
    prior = model.pz(*model._get_pz_params)
    kl_terms = []
    recon_terms = []
    for enc_idx, posterior in enumerate(qz_xs):
        kl_terms.append(kl_divergence(posterior, prior).sum(-1))
        for dec_idx, likelihood in enumerate(px_zs[enc_idx]):
            weighted = likelihood.log_prob(x[dec_idx]) * model.vaes[dec_idx].llik_scaling
            recon_terms.append(weighted.sum(-1))
    total_recon = torch.stack(recon_terms).sum(0)
    total_kl = torch.stack(kl_terms).sum(0)
    per_sample = (1 / len(model.vaes)) * (total_recon - beta * total_kl)
    return per_sample.sum()


def protein_preprocessing(t1):
    """Normalise each row by the geometric mean of its positive entries, then ``log1p``.

    For row ``i`` the output is ``log1p(t1[i] / g_i)``. Here ``g_i`` is ``exp`` of the sum of
    ``log`` over the row's non-zero entries divided by the number of strictly positive entries. A
    row without any positive entry uses a count of 1 (so ``g_i = 1`` for an all-zero row, which then
    maps to all zeros). Previously such rows produced NaN through ``0 * (1 / 0)``.

    The input tensor is not modified.

    Parameters
    ----------
    t1 : torch.Tensor
        2-D tensor; rows are normalised independently.

    Returns
    -------
    torch.Tensor
        Normalised tensor with the same shape as ``t1``.

    """
    t0 = t1.clone()
    t0[t0 == 0] = 1  # log(1) == 0, so zero entries do not contribute to the sum of logs
    n_positive = torch.sum(t1 > 0, dim=1).clamp(min=1)
    geo_mean = torch.exp(torch.sum(torch.log(t0), dim=1) / n_positive)
    return torch.log1p(t1 / geo_mean.unsqueeze(-1))


def atac_preprocessing(t1):
    """Binarise accessibility data: every strictly positive entry becomes 1.

    Zero (and any non-positive) entries are kept unchanged. A new tensor is returned and the
    caller's tensor is left untouched. The previous version overwrote it in place, unlike
    ``protein_preprocessing``.

    Parameters
    ----------
    t1 : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Binarised copy of ``t1``.

    """
    binarised = t1.clone()
    binarised[binarised > 0] = 1
    return binarised


# TODO: Not implemented -- RNA input is currently passed through unchanged.
def rna_preprocessing(t1):
    """Placeholder RNA preprocessing: returns ``t1`` itself (identity, no copy)."""
    return t1


class Constants:
    """Numerical constants shared by the scMM modules."""

    # Bounds for values that are later exponentiated in float32.
    # exp(88) is still finite in float32 (exp(89) overflows to inf).
    logceilc = 88
    # NOTE(review): upstream describes this as the smallest v with exp(v) > 0, but in float32
    # exp(-104) rounds to 0 (the smallest subnormal is about exp(-103.28)) -- confirm intent.
    logfloorc = -104

    # Precomputed logarithms.
    log2pi = math.log(2 * math.pi)
    log2 = math.log(2)

    # Small offsets: `eta` is added to encoder scales, `eps` bounds decoder probabilities
    # away from exactly 0 and 1.
    eps = 1e-7
    eta = 1e-6


class ZINB(ZeroInflatedNegativeBinomial):
    """Zero-inflated negative binomial with a positional ``(total_count, probs, gate)`` constructor.

    Thin adapter over pyro's ``ZeroInflatedNegativeBinomial`` (whose ``probs``/``gate`` are
    keyword arguments). It lets the distribution be built directly from decoder outputs, as in
    ``ZINB(r, p, g)``.

    """

    def __init__(self, total_count, probs, gate):
        parameters = dict(total_count=total_count, probs=probs, gate=gate)
        super().__init__(**parameters)


class VAE(nn.Module):
    """Single-modality VAE used as one expert of the multi-modal model.

    Parameters
    ----------
    prior_dist : type
        Distribution class for the prior p(z) (stored as ``self.pz``).
    likelihood_dist : type
        Distribution class for p(x|z); instantiated from the decoder outputs.
    post_dist : type
        Distribution class for q(z|x); instantiated from the encoder outputs.
    enc : nn.Module
        Encoder returning the parameters of ``post_dist``.
    dec : nn.Module
        Decoder returning the parameters of ``likelihood_dist``.
    params : argparse.Namespace
        Model hyper-parameters (stored as ``self.params``).

    """

    def __init__(self, prior_dist, likelihood_dist, post_dist, enc, dec, params):
        super().__init__()
        self.pz = prior_dist
        self.px_z = likelihood_dist
        self.qz_x = post_dist
        self.enc = enc
        self.dec = dec
        self.modelName = None
        self.params = params
        self._pz_params = None  # defined in subclass
        self._qz_x_params = None  # populated in `forward`
        self.llik_scaling = 1.0

    @property
    def pz_params(self):
        return self._pz_params

    @property
    def qz_x_params(self):
        """Posterior parameters from the most recent ``forward`` call."""
        if self._qz_x_params is None:
            raise NameError("qz_x params not initialized yet!")
        return self._qz_x_params

    @staticmethod
    def getDataLoaders(batch_size, shuffle=True, device="cuda"):
        # handle merging individual datasets appropriately in sub-class
        raise NotImplementedError

    def forward(self, x):
        """Encode ``x``, draw one reparameterised latent sample and decode it.

        Returns
        -------
        tuple
            ``(qz_x, px_z, zs)``: posterior, likelihood, and the latent sample.

        """
        self._qz_x_params = self.enc(x)
        qz_x = self.qz_x(*self._qz_x_params)
        zs = qz_x.rsample()
        px_z = self.px_z(*self.dec(zs))
        return qz_x, px_z, zs

    def reconstruct(self, data):
        """Return the mean of p(x|z) for one posterior sample z (switches to eval mode, no grad)."""
        self.eval()
        with torch.no_grad():
            qz_x = self.qz_x(*self.enc(data))
            latents = qz_x.rsample()  # no dim expansion
            px_z = self.px_z(*self.dec(latents))
            recon = get_mean(px_z)
        return recon

    def reconstruct_sample(self, data):
        """Return one draw of p(x|z) for one posterior sample z (switches to eval mode, no grad)."""
        self.eval()
        with torch.no_grad():
            qz_x = self.qz_x(*self.enc(data))
            latents = qz_x.rsample()  # no dim expansion
            px_z = self.px_z(*self.dec(latents))
            # `Distribution` has no `_sample`; the public sampling API is `sample()`.
            recon = px_z.sample()
        return recon

    def latents(self, data, sampling=False):
        """Return the posterior mean of z, or one posterior sample when ``sampling`` is True.

        Switches the module to eval mode and runs without gradient tracking.

        """
        self.eval()
        with torch.no_grad():
            qz_x = self.qz_x(*self.enc(data))
            if not sampling:
                lats = get_mean(qz_x)
            else:
                lats = qz_x.sample()
        return lats


class Enc(nn.Module):
    """Encoder mapping depth-normalised features to posterior location and scale.

    Each row of ``x`` is divided by its total count and multiplied by ``scale_factor`` before
    passing through the Linear-BatchNorm-ReLU stack.

    Parameters
    ----------
    data_dim : int
        Number of input features (columns of ``x``).
    latent_dim : int
        Size of the latent space.
    num_hidden_layers : int
        Number of Linear-BatchNorm-ReLU blocks (at least one is always built).
    hidden_dim : int
        Width of the hidden layers.

    """

    def __init__(self, data_dim, latent_dim, num_hidden_layers, hidden_dim):  # added hidden_dim
        super().__init__()
        self.data_dim = data_dim
        modules = []
        modules.append(nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(True)))
        for _ in range(num_hidden_layers - 1):
            modules.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(True)))
        self.enc = nn.Sequential(*modules)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)
        self.scale_factor = 10000

    def read_count(self, x):
        """Return each row's total count, repeated across ``data_dim`` columns.

        The result has shape ``(x.shape[0], data_dim)``, so it can be multiplied element-wise with
        decoder outputs.

        """
        read = torch.sum(x, axis=1)
        read = read.repeat(self.data_dim, 1).t()
        return (read)

    def forward(self, x):
        """Return ``(loc, scale)`` of the approximate posterior for the batch ``x``.

        ``scale`` is a softmax over the latent dimensions multiplied by ``latent_dim`` (summing to
        ``latent_dim`` before the ``Constants.eta`` offset is added).

        """
        read = self.read_count(x)
        # A row with zero total count would give 0 / 0 = NaN, and BatchNorm (in training mode)
        # would spread that NaN to every row of the batch. Divide such rows by 1 instead, which
        # leaves their entries unchanged.
        read = torch.where(read == 0, torch.ones_like(read), read)
        x = x / read * self.scale_factor
        e = self.enc(x)
        lv = self.fc22(e).clamp(-12, 12)  # bound the logits fed to the softmax
        return self.fc21(e), F.softmax(lv, dim=-1) * lv.size(-1) + Constants.eta


class Dec(nn.Module):
    """Decoder mapping a latent sample to the parameters of the count likelihood.

    ``forward`` returns ``(r, p)``:

    - ``r = exp(log_r)``, with ``log_r`` clamped to ``[-12, 12]``;
    - ``p = sigmoid(...)``, clamped to ``[Constants.eps, 1 - Constants.eps]``.

    When ``modality == 'atac'`` it additionally returns a zero-inflation gate ``g = sigmoid(...)``.
    In this module the outputs are consumed as ``NegativeBinomial(total_count, probs)`` or
    ``ZINB(total_count, probs, gate)``.

    Parameters
    ----------
    data_dim : int
        Number of output features.
    latent_dim : int
        Size of the latent input.
    num_hidden_layers : int
        Number of Linear-BatchNorm-ReLU blocks (at least one is always built).
    hidden_dim : int
        Width of the hidden layers.
    modality : str
        ``'atac'`` adds the gate head; any other value yields ``(r, p)`` only.

    """

    def __init__(self, data_dim, latent_dim, num_hidden_layers, hidden_dim, modality):  # added hidden_dim
        super().__init__()
        self.modality = modality
        self.data_dim = data_dim

        modules = []
        modules.append(nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(True)))
        for _ in range(num_hidden_layers - 1):
            modules.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(True)))
        self.dec = nn.Sequential(*modules)
        self.fc31 = nn.Linear(hidden_dim, data_dim)
        self.fc32 = nn.Linear(hidden_dim, data_dim)

        if self.modality == 'atac':
            # zero-inflation gate head
            self.fc33 = nn.Linear(hidden_dim, data_dim)

    def forward(self, z):
        d = self.dec(z)
        log_r = self.fc31(d).clamp(-12, 12)  # restrict to avoid torch.exp() over/underflow
        r = torch.exp(log_r)
        p = self.fc32(d)
        p = torch.sigmoid(p).clamp(Constants.eps, 1 - Constants.eps)  # restrict to avoid probs = 0,1

        if self.modality == 'atac':
            g = self.fc33(d)
            g = torch.sigmoid(g)
            return r, p, g
        else:
            return r, p


class ATAC(VAE):
    """ATAC expert: Laplace prior and posterior with a zero-inflated negative-binomial likelihood.

    The decoder's total counts are rescaled by each cell's read depth
    (``r / scale_factor * read_count``) before building the likelihood.

    """

    def __init__(self, params):
        encoder = Enc(params.p_dim, params.latent_dim, params.num_hidden_layers, params.p_hidden_dim)
        decoder = Dec(params.p_dim, params.latent_dim, params.num_hidden_layers, params.p_hidden_dim, 'atac')
        super().__init__(dist.Laplace, ZINB, dist.Laplace, encoder, decoder, params)
        prior_requires_grad = params.learn_prior
        self._pz_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False),  # mu
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=prior_requires_grad),  # logvar
        ])
        self.modelName = 'atac'
        self.data_dim = self.params.p_dim
        self.llik_scaling = 1.
        self.scale_factor = 10000

    @property
    def pz_params(self):
        """Prior location and softmax-normalised scale (scaled by the latent size)."""
        loc, raw_scale = self._pz_params
        return loc, F.softmax(raw_scale, dim=1) * raw_scale.size(-1)

    @staticmethod
    def getDataLoaders(dataset, batch_size, shuffle=True, device="cuda"):
        loader_kwargs = dict(num_workers=1, pin_memory=True) if device == "cuda" else dict()
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False, **loader_kwargs)

    def forward(self, x):
        depth = self.enc.read_count(x)
        self._qz_x_params = self.enc(x)
        qz_x = self.qz_x(*self._qz_x_params)
        zs = qz_x.rsample()
        total_count, probs, gate = self.dec(zs)
        px_z = self.px_z(total_count / self.scale_factor * depth, probs, gate)
        return qz_x, px_z, zs


class Protein(VAE):
    """Protein expert: Laplace prior and posterior with a negative-binomial likelihood.

    The decoder's total counts are rescaled by each cell's read depth
    (``r / scale_factor * read_count``) before building the likelihood.

    """

    def __init__(self, params):
        encoder = Enc(params.p_dim, params.latent_dim, params.num_hidden_layers, params.p_hidden_dim)
        decoder = Dec(params.p_dim, params.latent_dim, params.num_hidden_layers, params.p_hidden_dim, 'protein')
        super().__init__(dist.Laplace, dist.NegativeBinomial, dist.Laplace, encoder, decoder, params)
        prior_requires_grad = params.learn_prior
        self._pz_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False),  # mu
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=prior_requires_grad),  # logvar
        ])
        self.modelName = 'protein'
        self.data_dim = self.params.p_dim
        self.llik_scaling = 1.
        self.scale_factor = 10000

    @property
    def pz_params(self):
        """Prior location and softmax-normalised scale (scaled by the latent size)."""
        loc, raw_scale = self._pz_params
        return loc, F.softmax(raw_scale, dim=1) * raw_scale.size(-1)

    @staticmethod
    def getDataLoaders(dataset, batch_size, shuffle=True, device="cuda"):
        loader_kwargs = dict(num_workers=1, pin_memory=True) if device == "cuda" else dict()
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False, **loader_kwargs)

    def forward(self, x):
        depth = self.enc.read_count(x)
        self._qz_x_params = self.enc(x)
        qz_x = self.qz_x(*self._qz_x_params)
        zs = qz_x.rsample()
        total_count, probs = self.dec(zs)
        px_z = self.px_z(total_count / self.scale_factor * depth, probs)
        return qz_x, px_z, zs


class RNA(VAE):
    """RNA expert: Laplace prior and posterior with a negative-binomial likelihood.

    The decoder's total counts are rescaled by each cell's read depth
    (``r / scale_factor * read_count``) before building the likelihood.

    """

    def __init__(self, params):
        encoder = Enc(params.r_dim, params.latent_dim, params.num_hidden_layers, params.r_hidden_dim)
        decoder = Dec(params.r_dim, params.latent_dim, params.num_hidden_layers, params.r_hidden_dim, 'rna')
        super().__init__(dist.Laplace, dist.NegativeBinomial, dist.Laplace, encoder, decoder, params)
        prior_requires_grad = params.learn_prior
        self._pz_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False),  # mu
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=prior_requires_grad),  # logvar
        ])
        self.modelName = 'rna'
        self.data_dim = self.params.r_dim
        self.llik_scaling = 1.
        self.scale_factor = 10000

    @property
    def pz_params(self):
        """Prior location and softmax-normalised scale (scaled by the latent size)."""
        loc, raw_scale = self._pz_params
        return loc, F.softmax(raw_scale, dim=1) * raw_scale.size(-1)

    @staticmethod
    def getDataLoaders(dataset, batch_size, shuffle=True, device="cuda"):
        loader_kwargs = dict(num_workers=1, pin_memory=True) if device == "cuda" else dict()
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False, **loader_kwargs)

    def forward(self, x):
        depth = self.enc.read_count(x)
        self._qz_x_params = self.enc(x)
        qz_x = self.qz_x(*self._qz_x_params)
        zs = qz_x.rsample()
        total_count, probs = self.dec(zs)
        px_z = self.px_z(total_count / self.scale_factor * depth, probs)
        return qz_x, px_z, zs


class MMVAE(nn.Module):
    """MMVAE class.

    Mixture-of-experts multi-modal VAE with one :class:`VAE` expert per modality and a shared
    Laplace latent prior.

    Parameters
    ----------
    subtask : str
        Name of the subtask, composed of the names of the two modalities: ``'rna-dna'`` (RNA + ATAC
        experts) or ``'rna-protein'`` (RNA + protein experts). It selects the modality-specific
        experts and preprocessing.
    params : argparse.Namespace
        A Namespace object that contains arguments of MMVAE. For details of parameters in parser
        args, please refer to link (parser help document).

    """

    def __init__(self, subtask, params):
        super().__init__()
        self.pz = dist.Laplace
        assert subtask in ('rna-dna', 'rna-protein')
        self.modelName = subtask
        if subtask == 'rna-dna':
            self.preprocessing = atac_preprocessing
            self.vaes = nn.ModuleList([RNA(params), ATAC(params)])
        else:
            self.preprocessing = protein_preprocessing
            self.vaes = nn.ModuleList([RNA(params), Protein(params)])
        self.params = params
        grad = {'requires_grad': params.learn_prior}
        self._pz_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, params.latent_dim), requires_grad=False),  # mu
            nn.Parameter(torch.zeros(1, params.latent_dim), **grad)  # logvar
        ])
        # Likelihood weight of the first (RNA) expert.
        # Default (`llik_scaling == 0`): ratio of the second modality's feature count to the first's.
        # Otherwise `params.llik_scaling` is used as given.
        # The experts expose their feature count as `data_dim`; the previously read `dataSize`
        # attribute does not exist and raised AttributeError.
        if params.llik_scaling == 0:
            self.vaes[0].llik_scaling = self.vaes[1].data_dim / self.vaes[0].data_dim
        else:
            self.vaes[0].llik_scaling = params.llik_scaling
        self.scale_factor = 10000

    @property
    def _get_pz_params(self):
        return self._pz_params[0], F.softmax(self._pz_params[1], dim=1) * self._pz_params[1].size(-1)

    def _get_cluster(self, data, modality='both', n_clusters=10, method='kmeans', device='cuda'):
        """Cluster cells in latent space.

        Parameters
        ----------
        data : list[torch.Tensor]
            Features of the two modalities.
        modality : str
            ``'both'`` averages the two posterior means. ``'rna'`` uses the first expert.
            ``'atac'`` or ``'protein'`` uses the second expert.
        n_clusters : int
            Number of clusters for k-means.
        method : str
            ``'kmeans'`` or ``'dbscan'``. Any other value requires a ``get_gamma`` method on the model.
        device : str
            Unused; kept for backward compatibility.

        Returns
        -------
        tuple
            ``(cluster_labels, fitted_estimator_or_None)``.

        Raises
        ------
        ValueError
            If ``modality`` or ``method`` is not supported.

        """
        self.eval()
        lats = self._get_latents(data, sampling=False)
        if modality == 'both':
            lat = sum(lats) / len(lats)
        elif modality == 'rna':
            lat = lats[0]
        elif modality in ('atac', 'protein'):
            lat = lats[1]
        else:
            raise ValueError(f"Unknown modality {modality!r}; expected 'both', 'rna', 'atac' or 'protein'.")
        if method == 'kmeans':
            fit = KMeans(n_clusters=n_clusters, random_state=0, init='k-means++').fit(lat.cpu().numpy())
            cluster = fit.labels_
        elif method == 'dbscan':
            fit = DBSCAN(eps=0.5, min_samples=50).fit(lat.cpu().numpy())
            cluster = fit.labels_
        elif hasattr(self, 'get_gamma'):
            gamma, _, _, _, _ = self.get_gamma(lat)
            cluster = torch.argmax(gamma, dim=1).detach().cpu().numpy()
            fit = None
        else:
            raise ValueError(f"Unknown clustering method {method!r}; expected 'kmeans' or 'dbscan'.")
        return cluster, fit

    def forward(self, x):
        """Forward function for torch.nn.Module.

        Parameters
        ----------
        x : list[torch.Tensor]
            Features of two modalities.

        Returns
        -------
        qz_xs : list[torch.distributions.Distribution]
            Approximate posterior q(z|x_m) of each modality.
        px_zs : list[list[torch.distributions.Distribution]]
            Cross-modal likelihood matrix: ``px_zs[e][d]`` is the likelihood of modality ``d``
            decoded from the latent sample of encoder ``e``.
        zss : list[torch.Tensor]
            Latent sample drawn from each modality's posterior.

        """
        qz_xs, zss, read_counts = [], [], []
        n_mod = len(self.vaes)
        # initialise cross-modal matrix
        px_zs = [[None] * n_mod for _ in range(n_mod)]
        for m, vae in enumerate(self.vaes):
            read_counts.append(vae.enc.read_count(x[m]))
            qz_x, px_z, zs = vae(x[m])
            qz_xs.append(qz_x)
            zss.append(zs)
            px_zs[m][m] = px_z  # fill-in diagonal
        for e, zs in enumerate(zss):
            for d, vae in enumerate(self.vaes):
                if e == d:
                    continue
                # Off-diagonal: decode encoder e's latents with decoder d. Its total counts are
                # rescaled by modality d's read depth. Decoders return (r, p) or (r, p, g).
                r, *rest = vae.dec(zs)
                r = r / self.scale_factor * read_counts[d]
                px_zs[e][d] = vae.px_z(r, *rest)
        return qz_xs, px_zs, zss

    def _reconstruct(self, data):
        self.eval()
        with torch.no_grad():
            _, px_zs, _ = self.forward(data)
            # cross-modal matrix of reconstructions of MEANS
            recons = [[get_mean(px_z) for px_z in r] for r in px_zs]
        return recons

    def _reconstruct_sample(self, data):
        self.eval()
        with torch.no_grad():
            _, px_zs, _ = self.forward(data)
            # cross-modal matrix of reconstructions of SAMPLES
            recons = [[px_z.sample() for px_z in r] for r in px_zs]
        return recons

    def _get_latents(self, data, sampling=False):
        """Posterior means (or one posterior sample each when ``sampling``) of both modalities."""
        self.eval()
        with torch.no_grad():
            qz_xs, _, _ = self.forward(data)
            if not sampling:
                lats = [get_mean(qz_x) for qz_x in qz_xs]
            else:
                # `Distribution` has no `_sample`; the public sampling API is `sample()`.
                lats = [qz_x.sample() for qz_x in qz_xs]
        return lats

    def fit(self, x_train, y_train, val_ratio=0.15):
        """Fit function for training.

        A random ``val_ratio`` share of the cells is held out for validation, and the model is
        trained on the rest. The state with the lowest validation loss is saved to
        ``models/model_{seed}.pth`` and restored at the end.

        Parameters
        ----------
        x_train : torch.Tensor
            Input modality for training.
        y_train : torch.Tensor
            Target modality for training.
        val_ratio : float
            Ratio for automatic train-validation split.

        """
        start_early_stop = self.params.deterministic_warmup
        idx = np.random.permutation(x_train.shape[0])
        # The first `val_ratio` share of the shuffled indices is held out for validation.
        # (The two slices used to be swapped, which trained on only `val_ratio` of the cells.)
        n_val = int(idx.shape[0] * val_ratio)
        val_idx = idx[:n_val]
        train_idx = idx[n_val:]

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.params.lr, amsgrad=True)
        assert (self.params.obj in ['m_elbo_naive', 'm_elbo_naive_warmup'])
        use_warmup = self.params.obj == 'm_elbo_naive_warmup'

        train_loader = DataLoader(
            dataset=SimpleIndexDataset(train_idx),
            batch_size=self.params.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True,
        )
        train_mod1 = x_train.float().to(self.params.device)
        train_mod2 = y_train.float().to(self.params.device)
        val_mod1 = train_mod1[val_idx]
        val_mod2 = train_mod2[val_idx]

        tr, vals = [], []
        best_dict = None
        for epoch in range(1, self.params.epochs + 1):
            self.train()
            b_loss = 0
            # KL weight grows linearly over the warm-up epochs, then stays at 1.
            beta = (epoch - 1) / start_early_stop if epoch <= start_early_stop else 1
            for i, batch_idx in enumerate(train_loader):
                data = (train_mod1[batch_idx], train_mod2[batch_idx])
                if data[0].size()[0] == 1:
                    # Skipped presumably because the BatchNorm1d layers cannot train on one sample.
                    continue
                optimizer.zero_grad()
                if use_warmup:
                    loss = -m_elbo_naive_warmup(self, data, beta)
                else:
                    # `m_elbo_naive` takes no beta argument.
                    loss = -m_elbo_naive(self, data)
                loss.backward()
                optimizer.step()
                b_loss += loss.item()
                if self.params.print_freq > 0 and i % self.params.print_freq == 0:
                    print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / self.params.batch_size))
            tr.append(b_loss / len(train_loader.dataset))
            print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, tr[-1]))
            if math.isnan(b_loss):
                break
            vals.append(self.score(val_mod1, val_mod2, metric='loss'))
            print('====> Valid loss: {:.4f}'.format(vals[-1]))
            if vals[-1] == min(vals):
                os.makedirs('models', exist_ok=True)
                torch.save(self.state_dict(), f'models/model_{self.params.seed}.pth')
                best_dict = deepcopy(self.state_dict())
            if epoch % 10 == 0:
                # Validation cells are aligned across modalities: cell i matches cell i.
                print('Valid Matching score:', self.score(val_mod1, val_mod2, torch.arange(val_idx.shape[0])))
            if epoch > start_early_stop and min(vals) != min(vals[-10:]):
                print('Early stopped.')
                break
        if best_dict is not None:
            self.load_state_dict(best_dict)

    def score(self, mod1, mod2, labels=None, metric='minkowski'):
        """Score function to get score of prediction.

        Parameters
        ----------
        mod1 : torch.Tensor
            Features of modality 1.
        mod2 : torch.Tensor
            Features of modality 2.
        labels : torch.Tensor optional
            Labels of matching modality, i.e. cell correspondence between two modalities, either as
            an index vector (``labels[i]`` is the partner of cell ``i``) or as an ``(n, n)`` matching
            matrix. Required when metric is not 'loss'.
        metric : str optional
            Metric of the score function, by default to be 'minkowski'. ``'loss'`` returns the mean
            negative ELBO per cell instead of a matching score.

        Returns
        -------
        score : float
            Score of predicted matching, according to specified metric.

        """
        self.eval()
        mod1 = mod1.float().to(self.params.device)
        mod2 = mod2.float().to(self.params.device)
        if labels is None:
            assert metric == 'loss', 'Unable to evaluate without labels.'
        if metric == 'loss':
            b_loss = 0
            data_loader = DataLoader(
                dataset=SimpleIndexDataset(np.arange(mod1.shape[0])),
                batch_size=self.params.batch_size,
                shuffle=False,
                num_workers=0,
                drop_last=False,
            )
            with torch.no_grad():
                for batch_idx in data_loader:
                    # The warm-up objective at beta=1 is exactly the plain ELBO, so the evaluation
                    # loss does not depend on `params.obj`. (The plain branch used to be a string.)
                    b_loss += -m_elbo_naive(self, [mod1[batch_idx], mod2[batch_idx]]).item()
            return b_loss / mod1.shape[0]
        pred = self.predict(mod1, mod2, metric=metric)
        labels = labels.to(pred.device)
        if labels.dim() == 2:
            # Matching matrix: average predicted mass on the true pairs.
            return ((pred * labels.to(pred.dtype)).sum() / pred.shape[0]).item()
        return (pred[torch.arange(pred.shape[0]).long(), labels.long()].mean()).item()

    def predict(self, mod1, mod2, metric='minkowski'):
        """Predict function to get score of prediction.

        Parameters
        ----------
        mod1 : torch.Tensor
            Features of the first modality.
        mod2 : torch.Tensor
            Features of the second modality (same number of cells as ``mod1``).
        metric : str optional
            Metric of the matching function, by default to be 'minkowski'.

        Returns
        -------
        pred : torch.Tensor
            Binary ``(n, n)`` matrix. Row ``j`` has a single 1 in the column of the modality-1 cell
            whose latent mean is nearest to that of modality-2 cell ``j``.

        """
        self.eval()
        n_cells = mod1.shape[0]
        if n_cells == 0:
            return torch.zeros(0, 0)
        mod1 = mod1.float().to(self.params.device)
        mod2 = mod2.float().to(self.params.device)
        data_loader = DataLoader(
            dataset=SimpleIndexDataset(np.arange(n_cells)),
            batch_size=self.params.batch_size * 10,
            shuffle=False,
            num_workers=0,
            drop_last=False,
        )
        chunks = None
        with torch.no_grad():
            for batch_idx in data_loader:
                batch_lats = self._get_latents([mod1[batch_idx], mod2[batch_idx]], sampling=False)
                if chunks is None:
                    chunks = [[lat] for lat in batch_lats]
                else:
                    for m, lat in enumerate(batch_lats):
                        chunks[m].append(lat)
        # Concatenate once at the end instead of re-concatenating on every batch.
        lats = [torch.cat(parts, dim=0).cpu() for parts in chunks]
        knn = NearestNeighbors(metric=metric)
        knn.fit(lats[0])
        transp_nearest_neighbor = torch.tensor(knn.kneighbors(lats[1], 1, return_distance=False)).squeeze(-1)
        n = lats[0].shape[0]
        pred = torch.zeros(n, n)
        pred[torch.arange(n), transp_nearest_neighbor] = 1
        return pred