"""Reimplementation of Cross-Model AutoEncoder method.
Extended from https://github.com/uhlerlab/cross-modal-autoencoders
Reference
---------
Yang, Karren Dai, et al. "Multi-domain translation between single-cell imaging and sequencing data using autoencoders." Nature communications 12.1 (2021): 1-10.
"""
import math
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dance.utils import SimpleIndexDataset
class Discriminator(nn.Module):
def __init__(self, input_dim, params):
super().__init__()
self.gan_type = params['gan_type']
self.dim = params['dim']
self.norm = params['norm']
self.input_dim = input_dim
self.net = self._make_net()
def _make_net(self):
return nn.Sequential(
nn.Linear(self.input_dim, self.input_dim),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(self.input_dim, self.dim),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(self.dim, 1),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, x):
return self.net(x)
def calc_dis_loss(self, input_fake, input_real):
# calculate the loss to train D
outs0 = [self.forward(input_fake)]
outs1 = [self.forward(input_real)]
loss = 0
for it, (out0, out1) in enumerate(zip(outs0, outs1)):
loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
return loss
def calc_gen_loss(self, input_fake):
# calculate the loss to train G
outs0 = [self.forward(input_fake)]
loss = 0
for it, (out0) in enumerate(outs0):
# 1 = real data
loss += torch.mean((out0 - 1)**2)
return loss
def calc_gen_loss_reverse(self, input_real):
# calculate the loss to train G
outs0 = [self.forward(input_real)]
loss = 0
for it, (out0) in enumerate(outs0):
# 0 = fake data
loss += torch.mean((out0 - 0)**2)
return loss
def calc_gen_loss_half(self, input_fake):
# calculate the loss to train G
outs0 = [self.forward(input_fake)]
loss = 0
for it, (out0) in enumerate(outs0):
if self.gan_type == 'lsgan':
loss += torch.mean((out0 - 0.5)**2)
elif self.gan_type == 'nsgan':
all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
else:
assert 0, "Unsupported GAN type: {}".format(self.gan_type)
return loss
##################################################################################
# Generator
##################################################################################
class VAEGen(nn.Module):
# VAE architecture
def __init__(self, input_dim, params, shared_layer=False):
super().__init__()
self.dim = params['dim']
self.latent = params['latent']
self.input_dim = input_dim
# encoder_layers = [nn.Linear(self.input_dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.input_dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.input_dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.input_dim, self.dim),
# nn.LeakyReLU(0.2, inplace=True)]
#
# decoder_layers = [nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.input_dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.input_dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(self.input_dim, self.input_dim),
# nn.LeakyReLU(0.2, inplace=True)]
if self.input_dim > 1000:
hid_size = 1000
else:
hid_size = self.input_dim
encoder_layers = [
nn.Linear(self.input_dim, hid_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hid_size, hid_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hid_size, hid_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hid_size, self.dim),
nn.LeakyReLU(0.2, inplace=True)
]
decoder_layers = [
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(self.dim, hid_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hid_size, hid_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hid_size, hid_size),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hid_size, self.input_dim),
nn.LeakyReLU(0.2, inplace=True)
]
if shared_layer:
encoder_layers += [shared_layer["enc"], nn.LeakyReLU(0.2, inplace=True)]
decoder_layers = [shared_layer["dec"]] + decoder_layers
else:
encoder_layers += [nn.Linear(self.dim, self.latent), nn.LeakyReLU(0.2, inplace=True)]
decoder_layers = [nn.Linear(self.latent, self.dim)] + decoder_layers
self.enc = nn.Sequential(*encoder_layers)
self.dec = nn.Sequential(*decoder_layers)
def forward(self, images):
# This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
hiddens = self.encode(images)
if self.training == True:
noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
images_recon = self.decode(hiddens + noise)
else:
images_recon = self.decode(hiddens)
return images_recon, hiddens
def encode(self, images):
hiddens = self.enc(images)
noise = torch.randn_like(hiddens)
return hiddens, noise
def decode(self, hiddens):
images = self.dec(hiddens)
return images
##################################################################################
# Classifier
##################################################################################
class Classifier(nn.Module):
def __init__(self, input_dim, cls=3):
super().__init__()
self.input_dim = input_dim
self.classes = cls
self.net = self._make_net()
self.cel = nn.CrossEntropyLoss()
def _make_net(self):
return nn.Sequential(nn.Linear(self.input_dim, self.classes))
def forward(self, x):
return self.net(x)
def class_loss(self, input, target):
return self.cel(input, target)
def weights_init(init_type='gaussian'):
def init_fun(m):
classname = m.__class__.__name__
if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
# print m.__class__.__name__
if init_type == 'gaussian':
init.normal_(m.weight.data, 0.0, 0.02)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=math.sqrt(2))
elif init_type == 'default':
pass
else:
assert 0, "Unsupported initialization: {}".format(init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
return init_fun
# Get model list for resume
def get_model_list(dirname, key):
if os.path.exists(dirname) is False:
return None
gen_models = [
os.path.join(dirname, f) for f in os.listdir(dirname)
if os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f
]
if gen_models is None:
return None
gen_models.sort()
last_model_name = gen_models[-1]
return last_model_name
def get_scheduler(optimizer, hyperparameters, iterations=-1):
if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
scheduler = None # constant scheduler
elif hyperparameters['lr_policy'] == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
gamma=hyperparameters['gamma'], last_epoch=iterations)
else:
return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy'])
return scheduler
[docs]class CMAE(nn.Module):
"""CMAE class.
Parameters
----------
hyperparameters : dictionary
A dictionary that contains arguments of CMAE. For details of parameters in parser args, please refer to link (parser help document).
"""
def __init__(self, hyperparameters):
super().__init__()
lr = hyperparameters['lr']
# Initiate the networks
shared_layer = False
if "shared_layer" in hyperparameters and hyperparameters["shared_layer"]:
shared_layer = {}
shared_layer["dec"] = nn.Linear(hyperparameters['gen']['latent'], hyperparameters['gen']['dim'])
shared_layer["enc"] = nn.Linear(hyperparameters['gen']['dim'], hyperparameters['gen']['latent'])
self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'],
shared_layer) # auto-encoder for domain a
self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'],
shared_layer) # auto-encoder for domain b
self.dis_latent = Discriminator(hyperparameters['gen']['latent'],
hyperparameters['dis']) # discriminator for latent space
self.classifier = Classifier(hyperparameters['gen']['latent'],
cls=hyperparameters['num_of_classes']) # classifier on the latent space
# Setup the optimizers
beta1 = hyperparameters['beta1']
beta2 = hyperparameters['beta2']
dis_params = list(self.dis_latent.parameters())
gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) + list(self.classifier.parameters())
self.dis_opt = torch.optim.AdamW([p for p in dis_params if p.requires_grad], lr=lr, betas=(beta1, beta2),
weight_decay=hyperparameters['weight_decay'])
self.gen_opt = torch.optim.AdamW([p for p in gen_params if p.requires_grad], lr=lr, betas=(beta1, beta2),
weight_decay=hyperparameters['weight_decay'])
self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
# Network weight initialization
self.apply(weights_init(hyperparameters['init']))
self.dis_latent.apply(weights_init('gaussian'))
self.hyperparameters = hyperparameters
def _mae_loss(self, input, target):
"""A simple criterion function for MAE loss.
Parameters
----------
inputs : torch.Tensor
A input tensor.
target : torch.Tensor
A target tensor.
Returns
-------
loss : float
MAE loss between input and target tensors.
"""
return torch.mean(torch.abs(input - target))
[docs] def predict(self, mod1):
"""Predict function to get prediction of target modality features.
Parameters
----------
mod1 : torch.Tensor
Input modality features.
Returns
-------
pred : torch.Tensor
Predicted features of target modality.
"""
with torch.no_grad():
emb, _ = self.gen_a.encode(mod1)
pred = self.gen_b.decode(emb)
return pred
[docs] def score(self, mod1, mod2):
"""Score function to get score of prediction.
Parameters
----------
mod1 : torch.Tensor
Input modality features.
mod2 : torch.Tensor
Output modality features.
Returns
-------
score : float
RMSE loss of predicted output modality features.
"""
with torch.no_grad():
pred = self.predict(mod1)
mse = nn.MSELoss()
score = math.sqrt(mse(pred, mod2))
return score
[docs] def forward(self, mod1, mod2):
"""Forward function for torch.nn.Module.
Parameters
----------
mod1 : torch.Tensor
Input modality features.
mod2 : torch.Tensor
Target modality features.
Returns
-------
x_ab : torch.Tensor
Prediction of target modality from input modality.
x_ba : torch.Tensor
Prediction of input modality from target modality.
"""
self.eval()
h_a, _ = self.gen_a.encode(mod1)
h_b, _ = self.gen_b.encode(mod2)
x_ba = self.gen_a.decode(h_b)
x_ab = self.gen_b.decode(h_a)
self.train()
return x_ab, x_ba
def _gen_update(self, x_a, x_b, super_a, super_b, hyperparameters, a_labels=None, b_labels=None, variational=True):
true_samples = Variable(torch.randn(200, hyperparameters['gen']['latent']), requires_grad=False).cuda()
self.gen_opt.zero_grad()
# encode
h_a, n_a = self.gen_a.encode(x_a)
h_b, n_b = self.gen_b.encode(x_b)
# decode (within domain)
if variational:
h_a = h_a + n_a
h_b = h_b + n_b
x_a_recon = self.gen_a.decode(h_a)
x_b_recon = self.gen_b.decode(h_b)
classes_a = self.classifier.forward(h_a)
classes_b = self.classifier.forward(h_b)
# reconstruction loss
self.loss_gen_recon_x_a = self._mae_loss(x_a_recon, x_a)
self.loss_gen_recon_x_b = self._mae_loss(x_b_recon, x_b)
# GAN loss
self.loss_latent_a = self.dis_latent.calc_gen_loss(h_a)
self.loss_latent_b = self.dis_latent.calc_gen_loss_reverse(h_b)
# Classification Loss
if a_labels is not None and b_labels is not None:
self.loss_class_a = self.classifier.class_loss(classes_a, a_labels)
self.loss_class_b = self.classifier.class_loss(classes_b, b_labels)
else:
self.loss_class_a = self.loss_class_b = 0
# supervision
s_a, n_a = self.gen_a.encode(super_a)
s_b, n_b = self.gen_b.encode(super_b)
self.loss_supervision = self._mae_loss(s_a, s_b)
class_weight = hyperparameters['gan_w'] if "class_w" not in hyperparameters else hyperparameters["class_w"]
# total loss
self.loss_gen_total = hyperparameters['gan_w'] * self.loss_latent_a + \
hyperparameters['gan_w'] * self.loss_latent_b + \
class_weight * self.loss_class_a + \
class_weight * self.loss_class_b + \
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
hyperparameters['super_w'] * self.loss_supervision
if variational:
self.loss_gen_total += hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b
self.loss_gen_total.backward()
self.gen_opt.step()
def _sample(self, x_a, x_b):
self.eval()
x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
for i in range(x_a.size(0)):
h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
x_a_recon.append(self.gen_a.decode(h_a))
x_b_recon.append(self.gen_b.decode(h_b))
x_ba.append(self.gen_a.decode(h_b))
x_ab.append(self.gen_b.decode(h_a))
x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
x_ba = torch.cat(x_ba)
x_ab = torch.cat(x_ab)
self.train()
return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba
def _dis_update(self, x_a, x_b, hyperparameters):
self.dis_opt.zero_grad()
# encode
h_a, n_a = self.gen_a.encode(x_a)
h_b, n_b = self.gen_b.encode(x_b)
# D loss
self.loss_dis_latent = self.dis_latent.calc_dis_loss(h_a, h_b)
self.loss_dis_total = hyperparameters['gan_w'] * (self.loss_dis_latent)
self.loss_dis_total.backward()
self.dis_opt.step()
def _update_learning_rate(self):
if self.dis_scheduler is not None:
self.dis_scheduler.step()
if self.gen_scheduler is not None:
self.gen_scheduler.step()
[docs] def resume(self, checkpoint_dir):
"""Resume function to resume from checkpoint file.
Parameters
----------
checkpoint_dir : str
Path to the checkpoint file.
Returns
-------
iterations : int
Current iteration number of resumed model.
"""
hyperparameters = self.hyperparameters
# Load generators
last_model_name = get_model_list(checkpoint_dir, "gen")
state_dict = torch.load(last_model_name)
self.gen_a.load_state_dict(state_dict['a'])
self.gen_b.load_state_dict(state_dict['b'])
iterations = int(last_model_name[-11:-3])
# Load discriminators
last_model_name = get_model_list(checkpoint_dir, "dis")
state_dict = torch.load(last_model_name)
self.dis_latent.load_state_dict(state_dict['latent'])
# Load optimizers
state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
self.dis_opt.load_state_dict(state_dict['dis'])
self.gen_opt.load_state_dict(state_dict['gen'])
# Reinitilize schedulers
self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
print('Resume from iteration %d' % iterations)
return iterations
[docs] def save(self, checkpoint_dir, iterations):
"""Save function to save parameters to checkpoint file.
Parameters
----------
checkpoint_dir : str
Path to the checkpoint file.
iterations : int
Current number of training iterations.
Returns
-------
None.
"""
# Save generators, discriminators, and optimizers
gen_name = os.path.join(checkpoint_dir, 'gen_%08d.pt' % (iterations + 1))
dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % (iterations + 1))
opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
torch.save(
{
'a': self.gen_a.state_dict(),
'b': self.gen_b.state_dict(),
"classifier": self.classifier.state_dict()
}, gen_name)
torch.save({'latent': self.dis_latent.state_dict()}, dis_name)
torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
[docs] def fit(self, train_mod1, train_mod2, aux_labels=None, checkpoint_directory='./checkpoint', val_ratio=0.15):
"""Train CMAE.
Parameters
----------
train_mod1 : torch.Tensor
Features of input modality.
train_mod2 : torch.Tensor
Features of target modality.
aux_labels : torch.Tensor optional
Auxiliary labels for extra supervision during training.
checkpoint_directory : str optional
Path to the checkpoint file, by default to be './checkpoint'.
val_ratio : float
Ratio for automatic train-validation split.
Returns
-------
None.
"""
hyperparameters = self.hyperparameters
idx = torch.randperm(train_mod1.shape[0])
train_idx = idx[:int(idx.shape[0] * (1 - val_ratio))]
val_idx = idx[int(idx.shape[0] * (1 - val_ratio)):]
train_dataset = SimpleIndexDataset(train_idx)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=hyperparameters['batch_size'],
shuffle=True,
num_workers=0,
drop_last=True,
)
# Start training
iterations = self.resume(checkpoint_directory,
hyperparameters=hyperparameters) if hyperparameters['resume'] else 0
num_disc = 1 if "num_disc" not in hyperparameters else hyperparameters["num_disc"]
num_gen = 1 if "num_gen" not in hyperparameters else hyperparameters["num_gen"]
while True:
print('Iteration: ', iterations)
for it, batch_idx in enumerate(train_loader):
mod1, mod2 = train_mod1[batch_idx], train_mod2[batch_idx]
for _ in range(num_disc):
self._dis_update(mod1, mod2, hyperparameters)
for _ in range(num_gen):
if aux_labels is not None:
self._gen_update(mod1, mod2, mod1, mod2, hyperparameters, aux_labels[batch_idx],
aux_labels[batch_idx], variational=False)
else:
self._gen_update(mod1, mod2, mod1, mod2, hyperparameters, variational=False)
self._update_learning_rate()
print('RMSE Loss:', self.score(train_mod1[val_idx], train_mod2[val_idx]))
iterations += 1
if iterations >= hyperparameters['max_epochs']:
print('Finish training')
break