"""Reimplementation of BABEL method.
Extended from https://github.com/wukevin/babel
Reference
---------
Wu, Kevin E., Kathryn E. Yost, Howard Y. Chang, and James Zou. "BABEL enables cross-modality translation between
multiomic profiles at single-cell resolution." Proceedings of the National Academy of Sciences 118, no. 15 (2021).
"""
import math
from copy import deepcopy
from typing import Callable, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import dance.utils.loss as loss_functions
from dance import logger
# Settings for a reduce-on-plateau learning-rate schedule.
# NOTE(review): presumably meant as keyword arguments for
# torch.optim.lr_scheduler.ReduceLROnPlateau; no consumer is visible in this
# part of the file -- confirm against callers.
REDUCE_LR_ON_PLATEAU_PARAMS = dict(
    mode="min",
    factor=0.1,
    patience=10,
    min_lr=1e-6,
)
def recursive_to_device(t, device="cpu"):
    """Move ``t`` to ``device``, descending into nested tuples and lists.

    Every tuple or list (at any nesting depth) is returned as a tuple; any
    other object is moved by calling its ``.to(device)`` method.
    """
    if not isinstance(t, (tuple, list)):
        return t.to(device)
    moved = []
    for item in t:
        moved.append(recursive_to_device(item, device=device))
    return tuple(moved)
class Exp(nn.Module):
    """Element-wise exponential whose result is clamped for training stability."""

    def __init__(self, minimum=1e-5, maximum=1e6):
        """Store the clamp bounds (default values taken from DCA)."""
        super().__init__()
        self.min_value = minimum
        self.max_value = maximum

    def forward(self, input):
        """Return ``exp(input)`` limited to ``[min_value, max_value]``."""
        exponentiated = torch.exp(input)
        return exponentiated.clamp(min=self.min_value, max=self.max_value)
class ClippedSoftplus(nn.Module):
    """Softplus activation whose output is clamped to ``[minimum, maximum]``."""

    def __init__(self, beta=1, threshold=20, minimum=1e-4, maximum=1e3):
        super().__init__()
        self.beta, self.threshold = beta, threshold
        self.min_value, self.max_value = minimum, maximum

    def forward(self, input):
        """Apply softplus with the stored ``beta``/``threshold``, then clamp."""
        softened = F.softplus(input, beta=self.beta, threshold=self.threshold)
        return softened.clamp(min=self.min_value, max=self.max_value)

    def extra_repr(self):
        """Describe the configuration for the module's ``repr``."""
        fields = (self.beta, self.threshold, self.min_value, self.max_value)
        return "beta={}, threshold={}, min={}, max={}".format(*fields)
class DeepCountAutoencoder(nn.Module):
    """Replicate of the deep count autoencoder (DCA).

    The input passes through ``encoder -> bottleneck -> decoder`` (each followed
    by batch norm and ReLU). One linear output head is then created per
    distribution parameter, depending on ``mode``:

    * ``"poisson"``: mean only
    * ``"nb"``: mean and dispersion
    * ``"zinb"``: mean, dispersion and dropout

    The global torch seed is reset to 1234 in the constructor so that the
    initial weights are always the same.
    """

    def __init__(
        self,
        input_dim: int,
        inter_dim: int = 64,
        bottle_dim: int = 32,
        output_dim: int = None,
        mode: str = "zinb",
    ):
        super().__init__()
        torch.manual_seed(1234)  # Fixed initialization
        assert mode in ["zinb", "nb", "poisson"], f"Unrecognized mode: {mode}"
        self.mode = mode

        def glorot_linear(n_in, n_out):
            # Linear layer whose weight is re-drawn with Xavier (Glorot) uniform init.
            layer = nn.Linear(n_in, n_out)
            nn.init.xavier_uniform_(layer.weight)
            return layer

        self.encoder = glorot_linear(input_dim, inter_dim)
        self.bn1 = nn.BatchNorm1d(num_features=inter_dim)

        self.bottleneck = glorot_linear(inter_dim, bottle_dim)
        self.bn2 = nn.BatchNorm1d(num_features=bottle_dim)

        self.decoder = glorot_linear(bottle_dim, inter_dim)
        self.bn3 = nn.BatchNorm1d(num_features=inter_dim)

        # Without an explicit output size the network reconstructs its input.
        if output_dim is None:
            output_dim = input_dim

        # Output heads: one unit per output feature for every parameter.
        self.mean = glorot_linear(inter_dim, output_dim)
        if "nb" in self.mode:  # true for both "nb" and "zinb"
            self.disp = glorot_linear(inter_dim, output_dim)
        if self.mode == "zinb":
            self.dropout = glorot_linear(inter_dim, output_dim)

        # Default to returning the decoded (denoised) parameters.
        self.forward = self.forward_with_decode

    def encode(self, x):
        """Return the bottleneck (latent) representation of ``x``."""
        hidden = F.relu(self.bn1(self.encoder(x)))
        latent = F.relu(self.bn2(self.bottleneck(hidden)))
        return latent

    def decode(self, x, size_factors):
        """Turn a latent representation into the output parameters.

        Returns the size-factor-scaled mean alone in ``"poisson"`` mode,
        ``(mean, dispersion)`` in ``"nb"`` mode and
        ``(mean, dispersion, dropout)`` in ``"zinb"`` mode.
        """
        hidden = F.relu(self.bn3(self.decoder(x)))

        # Mean: clamped exponential, then scaled row-wise by the size factors.
        mu = torch.exp(self.mean(hidden)).clamp(min=1e-5, max=1e6)
        per_row_scale = size_factors.view(-1, 1).repeat(1, mu.shape[1])
        mu_scaled = mu * per_row_scale

        if "nb" not in self.mode:  # poisson
            return mu_scaled

        # Dispersion: clamped softplus.
        theta = F.softplus(self.disp(hidden)).clamp(min=1e-4, max=1e3)
        if self.mode != "zinb":
            return mu_scaled, theta

        # Dropout probability.
        pi = torch.sigmoid(self.dropout(hidden))
        return mu_scaled, theta, pi

    def forward_with_decode(self, x, size_factors):
        """Encode ``x`` and decode it again, returning the output parameters.

        The denoised matrix is generated by replacing the original count values
        with the mean of the negative binomial component as predicted in the
        output layer. This matrix represents the denoised and library size
        normalized expression matrix, the final output of the method.
        """
        latent = self.encode(x)
        return self.decode(latent, size_factors)

    def forward_no_decode(self, x, _size_factors):
        """Return only the hidden representation, skipping the decoder.

        Useful for probing the latent dimension; enable it by setting
        ``model.forward = model.forward_no_decode``.
        """
        return self.encode(x)
class Encoder(nn.Module):
    """Two-stage fully connected encoder: ``num_inputs -> 64 -> num_units``.

    Each stage is a Xavier-initialised linear layer followed by batch norm and
    a fresh instance of ``activation``.
    """

    def __init__(self, num_inputs: int, num_units=32, activation=nn.PReLU):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_units = num_units

        # Stage 1: input features -> 64 hidden units.
        self.encode1 = nn.Linear(num_inputs, 64)
        nn.init.xavier_uniform_(self.encode1.weight)
        self.bn1 = nn.BatchNorm1d(64)
        self.act1 = activation()

        # Stage 2: 64 hidden units -> latent units.
        self.encode2 = nn.Linear(64, num_units)
        nn.init.xavier_uniform_(self.encode2.weight)
        self.bn2 = nn.BatchNorm1d(num_units)
        self.act2 = activation()

    def forward(self, x):
        """Run both stages and return the latent representation."""
        stages = (
            (self.encode1, self.bn1, self.act1),
            (self.encode2, self.bn2, self.act2),
        )
        for linear, norm, act in stages:
            x = act(norm(linear(x)))
        return x
class Decoder(nn.Module):
    """Fully connected decoder with three parallel output heads.

    A shared hidden stage (linear, batch norm, activation) feeds three linear
    heads of size ``num_outputs``. ``final_activation`` optionally attaches an
    activation to the heads: a single module applies to the first head only,
    while a list/tuple of up to three entries maps position ``i`` to head
    ``i + 1`` (``None`` entries are skipped).
    """

    def __init__(
        self,
        num_outputs: int,
        num_units: int = 32,
        intermediate_dim: int = 64,
        activation=nn.PReLU,
        final_activation=None,
    ):
        super().__init__()
        self.num_outputs = num_outputs
        self.num_units = num_units

        # Shared hidden stage.
        self.decode1 = nn.Linear(num_units, intermediate_dim)
        nn.init.xavier_uniform_(self.decode1.weight)
        self.bn1 = nn.BatchNorm1d(intermediate_dim)
        self.act1 = activation()

        # Three parallel output heads.
        self.decode21 = nn.Linear(intermediate_dim, num_outputs)
        nn.init.xavier_uniform_(self.decode21.weight)
        self.decode22 = nn.Linear(intermediate_dim, num_outputs)
        nn.init.xavier_uniform_(self.decode22.weight)
        self.decode23 = nn.Linear(intermediate_dim, num_outputs)
        nn.init.xavier_uniform_(self.decode23.weight)

        # Optional per-head activations, keyed "act1" .. "act3".
        self.final_activations = nn.ModuleDict()
        if final_activation is None:
            return
        if isinstance(final_activation, (list, tuple)):
            assert len(final_activation) <= 3
            for position, act in enumerate(final_activation, start=1):
                if act is not None:
                    self.final_activations[f"act{position}"] = act
        elif isinstance(final_activation, nn.Module):
            self.final_activations["act1"] = final_activation
        else:
            raise ValueError(f"Unrecognized type for final_activation: {type(final_activation)}")

    def forward(self, x, size_factors=None):
        """Decode ``x`` into the three head outputs.

        When ``size_factors`` is given, the first output is multiplied row-wise
        by it (after its activation, if any); the other two are never scaled.
        """
        hidden = self.act1(self.bn1(self.decode1(x)))
        acts = self.final_activations

        # Head 1 is invariably the counts.
        out1 = self.decode21(hidden)
        if "act1" in acts:
            out1 = acts["act1"](out1)
        if size_factors is not None:
            scale = size_factors.view(-1, 1).repeat(1, out1.shape[1])
            out1 = out1 * scale

        out2 = self.decode22(hidden)
        if "act2" in acts:
            out2 = acts["act2"](out2)

        out3 = self.decode23(hidden)
        if "act3" in acts:
            out3 = acts["act3"](out3)

        return out1, out2, out3
class ChromEncoder(nn.Module):
    """Encoder that consumes one feature vector per chromosome.

    Every chromosome gets its own small stack (``n -> 32 -> 16``); the 16-dim
    results are concatenated and projected to a single latent vector of size
    ``latent_dim``.
    """

    def __init__(self, num_inputs: List[int], latent_dim: int = 32, activation=nn.PReLU):
        super().__init__()
        self.num_inputs = num_inputs
        self.act = activation

        def stage(n_in, n_out):
            # Xavier-initialised linear layer + batch norm + fresh activation.
            linear = nn.Linear(n_in, n_out)
            nn.init.xavier_uniform_(linear.weight)
            return [linear, nn.BatchNorm1d(n_out), self.act()]

        # One independent two-stage stack per chromosome.
        self.initial_modules = nn.ModuleList()
        for n in self.num_inputs:
            assert isinstance(n, int)
            self.initial_modules.append(nn.ModuleList(stage(n, 32) + stage(32, 16)))

        # Joint projection of the concatenated per-chromosome encodings.
        self.encode2 = nn.Linear(16 * len(self.num_inputs), latent_dim)
        nn.init.xavier_uniform_(self.encode2.weight)
        self.bn2 = nn.BatchNorm1d(latent_dim)
        self.act2 = self.act()

    def forward(self, x):
        """Encode a sequence of per-chromosome tensors into one latent tensor."""
        assert len(x) == len(self.num_inputs), f"Expected {len(self.num_inputs)} inputs but got {len(x)}"
        per_chrom = []
        for stack, features in zip(self.initial_modules, x):
            for layer in stack:
                features = layer(features)
            per_chrom.append(features)
        # Concatenate along the feature dim, not the batch dim.
        merged = torch.cat(per_chrom, dim=1)
        return self.act2(self.bn2(self.encode2(merged)))
class ChromDecoder(nn.Module):
    """Per-chromosome aware decoder that returns concatenated outputs.

    The latent vector is expanded to 16 units per chromosome and split into one
    chunk per chromosome. Each chunk goes through its own hidden stage and three
    linear heads. The per-chromosome results of each head are concatenated, so
    the module does not output per-chromosome values but three single tensors.

    ``final_activations`` follows the same convention as in ``Decoder``: a single
    module applies to head 1, a list/tuple of up to three entries maps position
    ``i`` to head ``i + 1`` (``None`` entries are skipped).
    """

    def __init__(
        self,
        num_outputs: List[int],  # Per-chromosome list of output sizes
        latent_dim: int = 32,
        activation=nn.PReLU,
        final_activations=[Exp(), ClippedSoftplus()],
    ):
        super().__init__()
        self.num_outputs = num_outputs
        self.latent_dim = latent_dim

        # Shared expansion: latent -> 16 units for every chromosome.
        expanded_dim = len(self.num_outputs) * 16
        self.decode1 = nn.Linear(self.latent_dim, expanded_dim)
        nn.init.xavier_uniform_(self.decode1.weight)
        self.bn1 = nn.BatchNorm1d(expanded_dim)
        self.act1 = activation()

        # Optional per-head activations, keyed "act1" .. "act3".
        self.final_activations = nn.ModuleDict()
        if final_activations is not None:
            if isinstance(final_activations, (list, tuple)):
                assert len(final_activations) <= 3
                for position, act in enumerate(final_activations, start=1):
                    if act is not None:
                        self.final_activations[f"act{position}"] = act
            elif isinstance(final_activations, nn.Module):
                self.final_activations["act1"] = final_activations
            else:
                raise ValueError(f"Unrecognized type for final_activation: {type(final_activations)}")
        logger.info(f"ChromDecoder with {len(self.final_activations)} output activations")

        # Per chromosome: [hidden linear, batch norm, activation, head1, head2, head3].
        self.final_decoders = nn.ModuleList()
        for n in self.num_outputs:
            hidden = nn.Linear(16, 32)
            nn.init.xavier_uniform_(hidden.weight)
            stack = [hidden, nn.BatchNorm1d(32), activation()]
            for _ in range(3):
                head = nn.Linear(32, n)
                nn.init.xavier_uniform_(head.weight)
                stack.append(head)
            self.final_decoders.append(nn.ModuleList(stack))

    def forward(self, x):
        """Decode a latent tensor into three concatenated output tensors."""
        x = self.act1(self.bn1(self.decode1(x)))
        # This is the reverse operation of cat: one chunk per chromosome.
        chunks = torch.chunk(x, chunks=len(self.num_outputs), dim=1)

        collected = ([], [], [])
        for chunk, (expand, norm, act, *heads) in zip(chunks, self.final_decoders):
            hidden = act(norm(expand(chunk)))
            for position, (bucket, head) in enumerate(zip(collected, heads), start=1):
                out = head(hidden)
                key = f"act{position}"
                if key in self.final_activations:
                    out = self.final_activations[key](out)
                bucket.append(out)

        retval1, retval2, retval3 = (torch.cat(bucket, dim=1) for bucket in collected)
        return retval1, retval2, retval3
class AutoEncoder(nn.Module):
    """Vanilla autoencoder built from one ``Encoder`` and one ``Decoder``.

    The global torch seed is set to ``seed`` before the sub-modules are
    created. When ``num_outputs`` is ``None`` the decoder reconstructs
    ``num_inputs`` features.
    """

    def __init__(
        self,
        num_inputs: int,
        num_units: int = 32,
        num_outputs: int = None,
        activation=nn.PReLU,
        final_activation=None,
        seed=8947,
        output_encoded: bool = True,
    ):
        super().__init__()
        torch.manual_seed(seed)

        self.output_encoded = output_encoded
        self.num_inputs = num_inputs
        if num_outputs is None:
            num_outputs = num_inputs
        self.num_outputs = num_outputs
        self.num_units = num_units

        self.encoder = Encoder(num_inputs, num_units=num_units, activation=activation)
        self.decoder = Decoder(
            num_outputs,
            num_units=num_units,
            activation=activation,
            final_activation=final_activation,
        )

    def forward(self, X):
        """Return the reconstruction, plus the latent code if ``output_encoded``."""
        latent = self.encoder(X)
        # Keep only the first decoder output; the remaining ones hold
        # additional parameters that this model does not use.
        reconstruction, *_ = self.decoder(latent)
        if not self.output_encoded:
            return reconstruction
        return reconstruction, latent
class PairedAutoEncoder(nn.Module):
    """Pair of autoencoders, one per domain.

    Supports cross-domain prediction by naively combining/swapping encoder and
    decoder: a latent code from one domain is handed to the other model's
    ``from_encoded`` method.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1, self.model2 = model1, model2

    def forward(self, x):
        """Run each model on its own input; ``x`` must unpack into two inputs."""
        first_input, second_input = x
        first_output = self.model1(first_input)
        second_output = self.model2(second_input)
        return (first_output, second_output)

    def translate_1_to_2(self, encoded1):
        """Given encoded data from domain 1, output domain 2."""
        return self.model2.from_encoded(encoded1)

    def translate_2_to_1(self, encoded2):
        """Given encoded data from domain 2, output domain 1."""
        return self.model1.from_encoded(encoded2)
class SplicedAutoEncoder(nn.Module):
    """Spliced autoencoder: two encoders and two decoders that are all combined.

    Every encoder can be paired with every decoder, giving the four
    combinations 1->1, 1->2, 2->1 and 2->2. This does not work when you have
    chromosome split features.
    """

    def __init__(
        self,
        input_dim1: int,
        input_dim2: int,
        hidden_dim: int = 16,
        final_activations1: list = [Exp(), nn.Softplus()],
        final_activations2=[Exp(), nn.Softplus(), nn.Sigmoid()],
        flat_mode: bool = False,  # Controls if we have to re-split inputs
        seed=182822,
    ):
        super().__init__()
        torch.manual_seed(seed)

        self.flat_mode = flat_mode
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2

        # A collection of activations counts its entries; anything else counts as one.
        collections = (list, set, tuple)
        self.num_outputs1 = len(final_activations1) if isinstance(final_activations1, collections) else 1
        self.num_outputs2 = len(final_activations2) if isinstance(final_activations2, collections) else 1

        self.encoder1 = Encoder(num_inputs=input_dim1, num_units=hidden_dim)
        self.encoder2 = Encoder(num_inputs=input_dim2, num_units=hidden_dim)
        self.decoder1 = Decoder(num_outputs=input_dim1, num_units=hidden_dim, final_activation=final_activations1)
        self.decoder2 = Decoder(num_outputs=input_dim2, num_units=hidden_dim, final_activation=final_activations2)

    def split_catted_input(self, x):
        """Split concatenated input along the last dim into the two domains."""
        sizes = [self.input_dim1, self.input_dim2]
        return torch.split(x, sizes, dim=-1)

    def _combine_output_and_encoded(self, decoded, encoded, num_outputs: int):
        """Bundle decoder output and latent code into one tuple.

        With more than one declared output, all decoder outputs are kept;
        otherwise only the first decoder output is kept. The latent code is
        always the last element.
        """
        if num_outputs > 1:
            combined = (*decoded, encoded)
        else:
            first = decoded[0] if isinstance(decoded, tuple) else decoded
            combined = (first, encoded)
        assert isinstance(combined, (list, tuple))
        assert isinstance(combined[0], (torch.TensorType, torch.Tensor)), f"Expected tensor but got {type(combined[0])}"
        return combined

    def forward_single(self, x, size_factors=None, in_domain: int = 1, out_domain: int = 1):
        """Return output of a single domain combination.

        ``size_factors`` is accepted but not used by this method.
        """
        assert in_domain in [1, 2] and out_domain in [1, 2]
        if in_domain == 1:
            encoder = self.encoder1
        else:
            encoder = self.encoder2
        if out_domain == 1:
            decoder, num_non_latent_out = self.decoder1, self.num_outputs1
        else:
            decoder, num_non_latent_out = self.decoder2, self.num_outputs2
        encoded = encoder(x)
        decoded = decoder(encoded)
        return self._combine_output_and_encoded(decoded, encoded, num_non_latent_out)

    def forward(self, x, size_factors=None, mode: Union[None, Tuple[int, int]] = None):
        """Compute all four encoder/decoder combinations.

        Returns the tuple ``(1->1, 1->2, 2->1, 2->2)`` when ``mode`` is ``None``
        and the single requested combination otherwise. All four combinations
        are evaluated in either case; ``size_factors`` is not used.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the four ``(in, out)`` domain pairs.
        """
        if self.flat_mode:
            x = self.split_catted_input(x)
        assert isinstance(x, (tuple, list))
        assert len(x) == 2, "There should be two inputs to spliced autoencoder"

        latents = {1: self.encoder1(x[0]), 2: self.encoder2(x[1])}
        decoders = {1: (self.decoder1, self.num_outputs1), 2: (self.decoder2, self.num_outputs2)}

        # The evaluation order (11, 12, 22, 21) is kept on purpose: it fixes the
        # order in which the decoders see their inputs.
        results = {}
        for src, dst in ((1, 1), (1, 2), (2, 2), (2, 1)):
            decoder, num_outputs = decoders[dst]
            decoded = decoder(latents[src])
            results[(src, dst)] = self._combine_output_and_encoded(decoded, latents[src], num_outputs)

        if mode is None:
            return results[(1, 1)], results[(1, 2)], results[(2, 1)], results[(2, 2)]
        if mode not in results:
            raise ValueError(f"Invalid mode code: {mode}")
        return results[mode]
class NaiveSplicedAutoEncoder(SplicedAutoEncoder):
    """Naive "spliced" autoencoder that does not use shared branches and instead simply
    trains four separate models.

    Each of the four domain combinations (1->1, 1->2, 2->2, 2->1) owns its own
    encoder/decoder pair. Domain 1 uses the plain ``Encoder``/``Decoder``;
    domain 2 uses the per-chromosome ``ChromEncoder``/``ChromDecoder``.

    Parameters
    ----------
    input_dim1 : int
        Number of features of domain 1.
    input_dim2 : List[int]
        Per-chromosome feature counts of domain 2. It is summed and used as
        split sizes, so it has to be a list of ints rather than a single int.
    hidden_dim : int
        Size of the latent representation.
    final_activations1, final_activations2
        Output activation(s) for the domain 1 / domain 2 decoders.
    flat_mode : bool
        If True, ``forward`` expects one concatenated tensor and splits it.
    seed : int
        Value passed to ``torch.manual_seed`` before the models are created.
    """

    def __init__(
        self,
        input_dim1: int,
        input_dim2: List[int],
        hidden_dim: int = 16,
        final_activations1: Union[Callable, List[Callable]] = [
            Exp(),
            ClippedSoftplus(),
        ],
        final_activations2: Union[Callable, List[Callable]] = nn.Sigmoid(),
        flat_mode: bool = True,  # Controls if we have to re-split inputs
        seed: int = 182822,
    ):
        # The parent constructor is skipped on purpose: it would build the
        # shared encoders/decoders that this class replaces.
        nn.Module.__init__(self)
        torch.manual_seed(seed)

        self.flat_mode = flat_mode
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.num_outputs1 = (len(final_activations1) if isinstance(final_activations1, (list, set, tuple)) else 1)
        self.num_outputs2 = (len(final_activations2) if isinstance(final_activations2, (list, set, tuple)) else 1)

        self.model11 = nn.ModuleList([
            Encoder(num_inputs=input_dim1, num_units=hidden_dim),
            Decoder(
                num_outputs=input_dim1,
                num_units=hidden_dim,
                final_activation=final_activations1,
            ),
        ])
        self.model12 = nn.ModuleList([
            Encoder(num_inputs=input_dim1, num_units=hidden_dim),
            ChromDecoder(
                num_outputs=input_dim2,
                latent_dim=hidden_dim,
                final_activations=final_activations2,
            ),
        ])
        self.model22 = nn.ModuleList([
            ChromEncoder(num_inputs=input_dim2, latent_dim=hidden_dim),
            ChromDecoder(
                num_outputs=input_dim2,
                latent_dim=hidden_dim,
                final_activations=final_activations2,
            ),
        ])
        self.model21 = nn.ModuleList([
            ChromEncoder(num_inputs=input_dim2, latent_dim=hidden_dim),
            Decoder(
                num_outputs=input_dim1,
                num_units=hidden_dim,
                final_activation=final_activations1,
            ),
        ])
        # Lookup by "<in_domain><out_domain>"; each entry is [encoder, decoder].
        self.modlist_dict = nn.ModuleDict({
            "11": self.model11,
            "12": self.model12,
            "22": self.model22,
            "21": self.model21,
        })

    def split_catted_input(self, x):
        """Split the input into chunks that goes to each input to model.

        Returns ``(domain1, domain2_chunks)`` where ``domain2_chunks`` holds one
        tensor per entry of ``input_dim2``.
        """
        a, b = torch.split(x, [self.input_dim1, sum(self.input_dim2)], dim=-1)
        return (a, torch.split(b, self.input_dim2, dim=-1))

    def forward_single(self, x, size_factors=None, in_domain: int = 1, out_domain: int = 1):
        """Return output of a single domain combination.

        ``size_factors`` is accepted but not used. An unknown domain pair
        raises ``KeyError`` from the model lookup.
        """
        num_non_latent_out = self.num_outputs1 if out_domain == 1 else self.num_outputs2
        encoder, decoder = self.modlist_dict[f"{in_domain}{out_domain}"]
        encoded = encoder(x)
        decoded = decoder(encoded)
        # Reuse the inherited helper instead of duplicating its logic here.
        return self._combine_output_and_encoded(decoded, encoded, num_non_latent_out)

    def forward(self, x, size_factors=None, mode: Union[None, Tuple[int, int]] = None):
        """Evaluate all four models.

        Returns the tuple ``(1->1, 1->2, 2->1, 2->2)`` when ``mode`` is ``None``
        and the single requested combination otherwise.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the four ``(in, out)`` domain pairs.
        """
        if self.flat_mode:
            x = self.split_catted_input(x)
        assert isinstance(x, (tuple, list))
        assert len(x) == 2, "There should be two inputs to spliced autoencoder"

        results = {
            (src, dst): self.forward_single(x[src - 1], size_factors=size_factors, in_domain=src, out_domain=dst)
            for src, dst in ((1, 1), (1, 2), (2, 1), (2, 2))
        }
        if mode is None:
            return tuple(results.values())
        if mode not in results:
            raise ValueError(f"Invalid mode code: {mode}")
        return results[mode]
class AssymSplicedAutoEncoder(SplicedAutoEncoder):
    """Asymmetric spliced autoencoder where branch 2 is a chrom AE.

    Branch 1 uses the plain ``Encoder``/``Decoder``; branch 2 uses
    ``ChromEncoder``/``ChromDecoder`` and therefore takes a per-chromosome list
    of feature counts as ``input_dim2``.
    """

    def __init__(
        self,
        input_dim1: int,
        input_dim2: List[int],
        hidden_dim: int = 16,
        final_activations1: list = [Exp(), ClippedSoftplus()],
        final_activations2=nn.Sigmoid(),
        flat_mode: bool = True,  # Controls if we have to re-split inputs
        seed: int = 182822,
    ):
        # Initialise nn.Module directly; the parent constructor is skipped
        # because this class builds its own encoders/decoders.
        # https://stackoverflow.com/questions/9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way
        nn.Module.__init__(self)
        torch.manual_seed(seed)

        self.flat_mode = flat_mode
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2

        # A collection of activations counts its entries; anything else counts as one.
        collections = (list, set, tuple)
        self.num_outputs1 = len(final_activations1) if isinstance(final_activations1, collections) else 1
        self.num_outputs2 = len(final_activations2) if isinstance(final_activations2, collections) else 1

        self.encoder1 = Encoder(num_inputs=input_dim1, num_units=hidden_dim)
        self.encoder2 = ChromEncoder(num_inputs=input_dim2, latent_dim=hidden_dim)
        self.decoder1 = Decoder(num_outputs=input_dim1, num_units=hidden_dim, final_activation=final_activations1)
        self.decoder2 = ChromDecoder(num_outputs=input_dim2, latent_dim=hidden_dim, final_activations=final_activations2)

    def split_catted_input(self, x):
        """Split the concatenated input into the pieces each encoder expects.

        Returns ``(domain1, domain2_chunks)`` where ``domain2_chunks`` holds one
        tensor per entry of ``input_dim2``.
        """
        domain2_width = sum(self.input_dim2)
        domain1, domain2 = torch.split(x, [self.input_dim1, domain2_width], dim=-1)
        per_chrom = torch.split(domain2, self.input_dim2, dim=-1)
        return (domain1, per_chrom)
class BabelWrapper:
    """Babel class.

    Parameters
    ----------
    args : argparse.Namespace
        A Namespace object that contains arguments of Babel. For details of parameters in parser args, please refer to
        link (parser help document). The attributes read by this class are ``naive``, ``hidden``, ``seed``,
        ``device``, ``lossweight``, ``batchsize``, ``lr``, ``outdir`` and ``earlystop``.
    dim_in : int
        Input dimension.
    dim_out: int
        Output dimension.

    """

    def __init__(self, args, dim_in, dim_out):
        self.args = args
        model_class = NaiveSplicedAutoEncoder if args.naive else AssymSplicedAutoEncoder
        self.model = model_class(
            hidden_dim=args.hidden,
            input_dim1=dim_in,
            input_dim2=[dim_out],  # the target modality is treated as a single "chromosome"
            final_activations1=nn.ReLU(),
            final_activations2=nn.ReLU(),
            flat_mode=True,
            seed=args.seed,
        ).to(args.device)

    def to(self, device):
        """Performs device conversion.

        Only the wrapped model is moved; ``self.args.device``, which the other
        methods use to place their inputs, is left untouched.

        Parameters
        ----------
        device : str
            Target device.

        Returns
        -------
        self : BabelWrapper
            Converted model.

        """
        self.model.to(device)
        return self

    def score(self, test_mod1, test_mod2):
        """Score function to get score of prediction.

        Parameters
        ----------
        test_mod1 : torch.Tensor
            Input modality features.
        test_mod2 : torch.Tensor
            Target modality features.

        Returns
        -------
        score : float
            RMSE loss of prediction.

        """
        mse = nn.MSELoss()
        self.model.eval()
        with torch.no_grad():
            pred = self.predict(test_mod1)
            target = test_mod2.to(self.args.device)
            return math.sqrt(mse(pred, target).item())

    def predict(self, test_mod1):
        """Predict function to get prediction of target modality features.

        Parameters
        ----------
        test_mod1 : torch.Tensor
            Input modality features.

        Returns
        -------
        pred : torch.Tensor
            Prediction of target modality features.

        """
        # Pick the encoder of modality 1 and the decoder of modality 2.
        if self.args.naive:
            encoder, decoder = self.model.model12[0], self.model.model12[1]
        else:
            encoder, decoder = self.model.encoder1, self.model.decoder2

        self.model.eval()
        with torch.no_grad():
            emb = encoder(test_mod1.to(self.args.device))
            # The decoder returns several outputs; the first one is the prediction.
            return decoder(emb)[0]

    def fit(self, x_train, y_train, max_epochs=500, val_ratio=0.15):
        """Fit function for training.

        The data is randomly split into a training and a validation part. After
        every epoch the validation RMSE of the modality 1 -> modality 2
        prediction is computed. Whenever it reaches a new minimum the weights
        are saved to ``{args.outdir}/BABEL_best_{args.seed}.pth``. Training
        stops early once the best value is older than ``args.earlystop`` epochs.
        The best weights are loaded back into the model at the end.

        Parameters
        ----------
        x_train : torch.Tensor
            Training input modality.
        y_train : torch.Tensor
            Training output modality.
        max_epochs : int optional
            Maximum number of training epochs, by default to be 500.
        val_ratio : float
            Fraction of the samples held out for validation.

        Raises
        ------
        ValueError
            If the split would leave the training or the validation set empty.

        """
        total_size = x_train.shape[0]
        val_size = int(total_size * val_ratio)
        train_size = total_size - val_size
        # A negative-index split (``idx[:-val_size]``) silently inverts itself
        # for ``val_size == 0``, so sizes are checked and slicing is explicit.
        if val_size < 1 or train_size < 1:
            raise ValueError(f"Cannot split {total_size} samples with val_ratio={val_ratio}: "
                             "both the training and the validation set must be non-empty.")

        criterion = loss_functions.QuadLoss(loss1=loss_functions.RMSELoss, loss2=loss_functions.RMSELoss,
                                            loss2_weight=self.args.lossweight)
        device = self.args.device

        rand_idx = torch.randperm(total_size)
        train_idx, val_idx = rand_idx[:train_size], rand_idx[train_size:]

        # Both modalities are concatenated feature-wise; the model runs in flat
        # mode and splits them again internally.
        train_loader = DataLoader(torch.hstack((x_train[train_idx], y_train[train_idx])), shuffle=True,
                                  batch_size=self.args.batchsize)
        val_loader = DataLoader(torch.hstack((x_train[val_idx], y_train[val_idx])), batch_size=self.args.batchsize)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        mse = nn.MSELoss()

        # Start from the current weights so that there is always a state to
        # restore, even with zero epochs or a validation loss that is never finite.
        best_dict = deepcopy(self.model.state_dict())
        val = []

        for i in range(max_epochs):
            self.model.train()
            total_loss = 0
            for train_batch in train_loader:
                batch = train_batch.to(device)
                logits = self.model(batch)
                loss = criterion(logits, batch)
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                optimizer.step()

            self.model.eval()
            with torch.no_grad():
                val_loss = 0
                for val_batch in val_loader:
                    batch = val_batch.to(device)
                    # logits[1] is the 1 -> 2 combination; its first entry is the prediction.
                    pred = self.model(batch)[1][0]
                    # The target modality occupies the trailing columns of the batch.
                    val_loss += mse(pred, batch[:, -pred.shape[1]:]).item()
                val.append(math.sqrt(val_loss / len(val_loader)))

            print('epoch: ', i + 1)
            print('training (sum of 4 losses):', total_loss / len(train_loader))
            print('validation (prediction loss):', val[-1])

            if min(val) == val[-1]:
                torch.save(self.model.state_dict(), f'{self.args.outdir}/BABEL_best_{self.args.seed}.pth')
                best_dict = deepcopy(self.model.state_dict())

            # Stop when the overall best is no longer within the last `earlystop` epochs.
            if i > self.args.earlystop and min(val) != min(val[-self.args.earlystop:]):
                print('Early stopped.')
                break

        self.model.load_state_dict(best_dict)