# Source code for dance.modules.multi_modality.predict_modality.babel

"""Reimplementation of BABEL method.

Extended from https://github.com/wukevin/babel

Reference
---------
Wu, Kevin E., Kathryn E. Yost, Howard Y. Chang, and James Zou. "BABEL enables cross-modality translation between
multiomic profiles at single-cell resolution." Proceedings of the National Academy of Sciences 118, no. 15 (2021).

"""
import math
from copy import deepcopy
from typing import Callable, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import dance.utils.loss as loss_functions
from dance import logger

# Keyword arguments presumably meant for ``torch.optim.lr_scheduler.ReduceLROnPlateau`` (not used in this module
# itself; NOTE(review): confirm against callers).
REDUCE_LR_ON_PLATEAU_PARAMS = dict(
    mode="min",
    factor=0.1,
    patience=10,
    min_lr=1e-6,
)


def recursive_to_device(t, device="cpu"):
    """Move ``t`` to ``device``, descending into nested lists and tuples.

    Every list or tuple encountered (at any depth) comes back as a tuple of moved items; any other object must
    provide a ``.to(device)`` method.
    """
    if not isinstance(t, (tuple, list)):
        return t.to(device)
    moved = [recursive_to_device(item, device=device) for item in t]
    return tuple(moved)


class Exp(nn.Module):
    """Element-wise ``torch.exp`` whose output is clamped to ``[minimum, maximum]`` for training stability."""

    def __init__(self, minimum=1e-5, maximum=1e6):
        """Store the clamp bounds (default values taken from DCA)."""
        super().__init__()
        self.min_value, self.max_value = minimum, maximum

    def forward(self, input):
        exponentiated = torch.exp(input)
        return exponentiated.clamp(min=self.min_value, max=self.max_value)


class ClippedSoftplus(nn.Module):
    """Softplus activation whose output is clamped to ``[minimum, maximum]``.

    ``beta`` and ``threshold`` are forwarded unchanged to ``torch.nn.functional.softplus``.
    """

    def __init__(self, beta=1, threshold=20, minimum=1e-4, maximum=1e3):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        self.min_value = minimum
        self.max_value = maximum

    def forward(self, input):
        activated = F.softplus(input, self.beta, self.threshold)
        return activated.clamp(min=self.min_value, max=self.max_value)

    def extra_repr(self):
        return f"beta={self.beta}, threshold={self.threshold}, min={self.min_value}, max={self.max_value}"


class DeepCountAutoencoder(nn.Module):
    """Replicate of DCA (deep count autoencoder).

    The trunk is ``input -> inter_dim -> bottle_dim -> inter_dim`` with Linear + BatchNorm + ReLU at each step. It is
    followed by per-gene output heads whose number depends on ``mode``:

    * ``"poisson"``: mean only;
    * ``"nb"``: mean and dispersion;
    * ``"zinb"``: mean, dispersion and dropout probability.

    The global torch RNG is re-seeded with 1234 during construction so that the initialization is fixed.
    """

    def __init__(
        self,
        input_dim: int,
        inter_dim: int = 64,
        bottle_dim: int = 32,
        output_dim: int = None,
        mode: str = "zinb",
    ):
        super().__init__()
        torch.manual_seed(1234)  # Fixed initialization
        assert mode in ["zinb", "nb", "poisson"], f"Unrecognized mode: {mode}"
        self.mode = mode

        self.encoder = self._glorot_linear(input_dim, inter_dim)
        self.bn1 = nn.BatchNorm1d(num_features=inter_dim)
        self.bottleneck = self._glorot_linear(inter_dim, bottle_dim)
        self.bn2 = nn.BatchNorm1d(num_features=bottle_dim)
        self.decoder = self._glorot_linear(bottle_dim, inter_dim)
        self.bn3 = nn.BatchNorm1d(num_features=inter_dim)

        # Output heads are per-gene; by default they reconstruct the input dimensionality.
        n_genes = input_dim if output_dim is None else output_dim
        self.mean = self._glorot_linear(inter_dim, n_genes)
        if "nb" in mode:  # true for both "nb" and "zinb"
            self.disp = self._glorot_linear(inter_dim, n_genes)
        if mode == "zinb":
            self.dropout = self._glorot_linear(inter_dim, n_genes)

        # Calling the module decodes by default; assign ``forward_no_decode`` to obtain latent codes instead.
        self.forward = self.forward_with_decode

    @staticmethod
    def _glorot_linear(in_features, out_features):
        """Build an ``nn.Linear`` whose weight is Xavier/Glorot-uniform initialized."""
        layer = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(layer.weight)
        return layer

    def encode(self, x):
        """Map the input to its bottleneck (latent) representation."""
        hidden = F.relu(self.bn1(self.encoder(x)))
        return F.relu(self.bn2(self.bottleneck(hidden)))

    def decode(self, x, size_factors):
        """Decode a latent representation into the distribution parameters.

        Returns ``mu`` (scaled row-wise by ``size_factors``) for ``"poisson"``, ``(mu, theta)`` for ``"nb"`` and
        ``(mu, theta, pi)`` for ``"zinb"``. ``mu`` and ``theta`` are clamped for stability; ``pi`` is a sigmoid.
        """
        hidden = F.relu(self.bn3(self.decoder(x)))
        mu = torch.clamp(torch.exp(self.mean(hidden)), min=1e-5, max=1e6)
        scale = size_factors.view(-1, 1).repeat(1, mu.shape[1])
        mu_scaled = mu * scale

        if "nb" not in self.mode:  # Poisson: only the mean is modelled
            return mu_scaled
        theta = torch.clamp(F.softplus(self.disp(hidden)), min=1e-4, max=1e3)
        if self.mode != "zinb":
            return mu_scaled, theta
        pi = torch.sigmoid(self.dropout(hidden))
        return mu_scaled, theta, pi

    def forward_with_decode(self, x, size_factors):
        """Encode then decode ``x``, returning the distribution parameters from :meth:`decode`.

        The mean output replaces the original counts, giving the denoised, library-size normalized expression matrix.
        """
        return self.decode(self.encode(x), size_factors)

    def forward_no_decode(self, x, _size_factors):
        """Return only the latent representation (size factors are ignored).

        Useful for probing the latent space; enable with ``model.forward = model.forward_no_decode``.
        """
        return self.encode(x)


class Encoder(nn.Module):
    """Two-layer MLP encoder: (Linear -> BatchNorm1d -> activation) twice.

    Parameters
    ----------
    num_inputs : int
        Input feature dimension.
    num_units : int
        Output (latent) dimension, by default 32.
    activation : type
        Activation class; a fresh instance is created for each layer. Default ``nn.PReLU``.
    intermediate_dim : int
        Width of the hidden layer, by default 64 (mirrors ``Decoder``'s ``intermediate_dim``).

    """

    def __init__(self, num_inputs: int, num_units=32, activation=nn.PReLU, intermediate_dim: int = 64):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_units = num_units

        self.encode1 = nn.Linear(self.num_inputs, intermediate_dim)
        nn.init.xavier_uniform_(self.encode1.weight)
        self.bn1 = nn.BatchNorm1d(intermediate_dim)
        self.act1 = activation()

        self.encode2 = nn.Linear(intermediate_dim, self.num_units)
        nn.init.xavier_uniform_(self.encode2.weight)
        self.bn2 = nn.BatchNorm1d(num_units)
        self.act2 = activation()

    def forward(self, x):
        x = self.act1(self.bn1(self.encode1(x)))
        x = self.act2(self.bn2(self.encode2(x)))
        return x


class Decoder(nn.Module):
    """MLP decoder with three parallel output heads.

    A shared hidden layer (Linear -> BatchNorm1d -> activation) feeds three independent Linear heads. Optional final
    activations are applied per head: pass a single module (first head only), or a list/tuple of up to three
    modules or ``None`` entries (one per head).
    """

    def __init__(
        self,
        num_outputs: int,
        num_units: int = 32,
        intermediate_dim: int = 64,
        activation=nn.PReLU,
        final_activation=None,
    ):
        super().__init__()
        self.num_outputs = num_outputs
        self.num_units = num_units

        self.decode1 = nn.Linear(self.num_units, intermediate_dim)
        nn.init.xavier_uniform_(self.decode1.weight)
        self.bn1 = nn.BatchNorm1d(intermediate_dim)
        self.act1 = activation()

        self.decode21 = nn.Linear(intermediate_dim, self.num_outputs)
        nn.init.xavier_uniform_(self.decode21.weight)
        self.decode22 = nn.Linear(intermediate_dim, self.num_outputs)
        nn.init.xavier_uniform_(self.decode22.weight)
        self.decode23 = nn.Linear(intermediate_dim, self.num_outputs)
        nn.init.xavier_uniform_(self.decode23.weight)

        # Keys "act1".."act3" map to the corresponding output head.
        self.final_activations = nn.ModuleDict()
        if final_activation is None:
            return
        if isinstance(final_activation, (list, tuple)):
            assert len(final_activation) <= 3
            for head_no, act in enumerate(final_activation, start=1):
                if act is not None:
                    self.final_activations[f"act{head_no}"] = act
        elif isinstance(final_activation, nn.Module):
            self.final_activations["act1"] = final_activation
        else:
            raise ValueError(f"Unrecognized type for final_activation: {type(final_activation)}")

    def forward(self, x, size_factors=None):
        """Return the three head outputs; ``size_factors`` (if given) row-scales the first output (the counts)."""
        hidden = self.act1(self.bn1(self.decode1(x)))

        results = []
        for head_no, head in enumerate((self.decode21, self.decode22, self.decode23), start=1):
            out = head(hidden)
            key = f"act{head_no}"
            if key in self.final_activations.keys():
                out = self.final_activations[key](out)
            if head_no == 1 and size_factors is not None:
                out = out * size_factors.view(-1, 1).repeat(1, out.shape[1])
            results.append(out)
        return tuple(results)


class ChromEncoder(nn.Module):
    """Encoder that consumes one feature vector per chromosome.

    Each chromosome's input is reduced independently to 16 features by its own
    ``Linear(n, 32) -> BN -> act -> Linear(32, 16) -> BN -> act`` stack. The per-chromosome results are concatenated
    and projected to a single latent vector of size ``latent_dim``.
    """

    def __init__(self, num_inputs: List[int], latent_dim: int = 32, activation=nn.PReLU):
        super().__init__()
        self.num_inputs = num_inputs
        self.act = activation

        self.initial_modules = nn.ModuleList()
        for width in self.num_inputs:
            assert isinstance(width, int)
            stack = []
            for in_f, out_f in ((width, 32), (32, 16)):
                linear = nn.Linear(in_f, out_f)
                nn.init.xavier_uniform_(linear.weight)
                stack.extend([linear, nn.BatchNorm1d(out_f), self.act()])
            self.initial_modules.append(nn.ModuleList(stack))

        self.encode2 = nn.Linear(16 * len(self.num_inputs), latent_dim)
        nn.init.xavier_uniform_(self.encode2.weight)
        self.bn2 = nn.BatchNorm1d(latent_dim)
        self.act2 = self.act()

    def forward(self, x):
        assert len(x) == len(self.num_inputs), f"Expected {len(self.num_inputs)} inputs but got {len(x)}"
        encoded_per_chrom = []
        for layers, features in zip(self.initial_modules, x):
            for layer in layers:
                features = layer(features)
            encoded_per_chrom.append(features)
        # Concatenate along the feature dim (dim=1), not the batch dim.
        joined = torch.cat(encoded_per_chrom, dim=1)
        return self.act2(self.bn2(self.encode2(joined)))


class ChromDecoder(nn.Module):
    """Network that is per-chromosome aware, but does not output per-chromosome values; instead it concatenates
    them into a single vector.

    The latent vector is projected to ``16 * len(num_outputs)`` features and split into equal 16-wide chunks, one per
    chromosome. Each chunk goes through its own head
    ``Linear(16, 32) -> BN -> act -> three parallel Linear(32, n)``. The three outputs of every head are
    concatenated across chromosomes along the feature dimension.

    Parameters
    ----------
    num_outputs : List[int]
        Output size for each chromosome.
    latent_dim : int
        Size of the latent input.
    activation : type
        Hidden activation class; a fresh instance is created for each hidden layer.
    final_activations : nn.Module, list/tuple of (nn.Module or None), or None
        Activations for the first, second and third outputs respectively. A single module applies to the first
        output only; ``None`` entries (or ``None`` overall) leave the corresponding outputs unactivated.

    """

    def __init__(
        self,
        num_outputs: List[int],  # Per-chromosome list of output sizes
        latent_dim: int = 32,
        activation=nn.PReLU,
        final_activations=[Exp(), ClippedSoftplus()],
    ):
        super().__init__()
        self.num_outputs = num_outputs
        self.latent_dim = latent_dim

        self.decode1 = nn.Linear(self.latent_dim, len(self.num_outputs) * 16)
        nn.init.xavier_uniform_(self.decode1.weight)
        self.bn1 = nn.BatchNorm1d(len(self.num_outputs) * 16)
        self.act1 = activation()

        # NOTE: the default activation instances are created once, at class definition, and are shared by every
        # ChromDecoder built with the default. Exp and ClippedSoftplus hold no parameters, so sharing is harmless.
        self.final_activations = nn.ModuleDict()
        if final_activations is not None:
            if isinstance(final_activations, (list, tuple)):
                assert len(final_activations) <= 3
                for i, act in enumerate(final_activations, start=1):
                    if act is not None:
                        self.final_activations[f"act{i}"] = act
            elif isinstance(final_activations, nn.Module):
                self.final_activations["act1"] = final_activations
            else:
                raise ValueError(f"Unrecognized type for final_activation: {type(final_activations)}")
        logger.info(f"ChromDecoder with {len(self.final_activations)} output activations")

        # One head per chromosome: [Linear(16, 32), BatchNorm1d(32), activation, out1, out2, out3].
        self.final_decoders = nn.ModuleList()
        for n in self.num_outputs:
            layer0 = nn.Linear(16, 32)
            nn.init.xavier_uniform_(layer0.weight)
            bn0 = nn.BatchNorm1d(32)
            act0 = activation()
            layer1 = nn.Linear(32, n)
            nn.init.xavier_uniform_(layer1.weight)
            layer2 = nn.Linear(32, n)
            nn.init.xavier_uniform_(layer2.weight)
            layer3 = nn.Linear(32, n)
            nn.init.xavier_uniform_(layer3.weight)
            self.final_decoders.append(nn.ModuleList([layer0, bn0, act0, layer1, layer2, layer3]))

    def forward(self, x):
        """Decode ``x`` into three tensors, each the per-chromosome outputs concatenated along dim 1."""
        x = self.act1(self.bn1(self.decode1(x)))
        # Reverse of the concatenation: one 16-wide chunk per chromosome.
        x_chunked = torch.chunk(x, chunks=len(self.num_outputs), dim=1)

        retval1, retval2, retval3 = [], [], []
        for chunk, processors in zip(x_chunked, self.final_decoders):
            decode1, bn1, act1, decode21, decode22, decode23 = processors
            chunk = act1(bn1(decode1(chunk)))
            temp1 = decode21(chunk)
            temp2 = decode22(chunk)
            temp3 = decode23(chunk)

            if "act1" in self.final_activations.keys():
                temp1 = self.final_activations["act1"](temp1)
            if "act2" in self.final_activations.keys():
                temp2 = self.final_activations["act2"](temp2)
            if "act3" in self.final_activations.keys():
                temp3 = self.final_activations["act3"](temp3)
            retval1.append(temp1)
            retval2.append(temp2)
            retval3.append(temp3)
        retval1 = torch.cat(retval1, dim=1)
        retval2 = torch.cat(retval2, dim=1)
        retval3 = torch.cat(retval3, dim=1)
        return retval1, retval2, retval3


class AutoEncoder(nn.Module):
    """Vanilla autoencoder composed of an ``Encoder`` and a ``Decoder``.

    The global torch RNG is re-seeded with ``seed`` before the submodules are built. ``forward`` returns
    ``(decoded, encoded)`` when ``output_encoded`` is true, otherwise only ``decoded``; ``decoded`` is the decoder's
    first output (the extra parameter outputs are dropped).
    """

    def __init__(
        self,
        num_inputs: int,
        num_units: int = 32,
        num_outputs: int = None,
        activation=nn.PReLU,
        final_activation=None,
        seed=8947,
        output_encoded: bool = True,
    ):
        super().__init__()
        torch.manual_seed(seed)
        self.output_encoded = output_encoded
        self.num_inputs = num_inputs
        if num_outputs is None:
            num_outputs = num_inputs
        self.num_outputs = num_outputs
        self.num_units = num_units

        self.encoder = Encoder(num_inputs, num_units=num_units, activation=activation)
        self.decoder = Decoder(
            num_outputs,
            num_units=num_units,
            activation=activation,
            final_activation=final_activation,
        )

    def forward(self, X):
        latent = self.encoder(X)
        reconstruction = self.decoder(latent)[0]  # keep only the primary (counts) output
        return (reconstruction, latent) if self.output_encoded else reconstruction


class PairedAutoEncoder(nn.Module):
    """Pair of independent per-domain models.

    Cross-domain prediction is done naively by handing one domain's encoding to the other domain's model through its
    ``from_encoded`` method.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, x):
        """Run each model on its own input; ``x`` must be a pair ``(input1, input2)``."""
        first, second = x
        out_first = self.model1(first)
        out_second = self.model2(second)
        return out_first, out_second

    def translate_1_to_2(self, encoded1):
        """Produce a domain-2 output from a domain-1 encoding."""
        return self.model2.from_encoded(encoded1)

    def translate_2_to_1(self, encoded2):
        """Produce a domain-1 output from a domain-2 encoding."""
        return self.model1.from_encoded(encoded2)


class SplicedAutoEncoder(nn.Module):
    """Spliced autoencoder: two encoders and two decoders, all four pairings combined.

    Every encoder is paired with every decoder through a shared latent space of size ``hidden_dim``, giving the four
    translations 1->1, 1->2, 2->1 and 2->2. This class does not handle chromosome-split features.
    """

    def __init__(
        self,
        input_dim1: int,
        input_dim2: int,
        hidden_dim: int = 16,
        final_activations1: list = [Exp(), nn.Softplus()],
        final_activations2=[Exp(), nn.Softplus(), nn.Sigmoid()],
        flat_mode: bool = False,  # Controls if we have to re-split inputs
        seed=182822,
    ):
        super().__init__()
        torch.manual_seed(seed)

        self.flat_mode = flat_mode
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        # Number of non-latent outputs per domain: one per configured final activation.
        self.num_outputs1, self.num_outputs2 = (
            len(acts) if isinstance(acts, (list, set, tuple)) else 1
            for acts in (final_activations1, final_activations2))

        self.encoder1 = Encoder(num_inputs=input_dim1, num_units=hidden_dim)
        self.encoder2 = Encoder(num_inputs=input_dim2, num_units=hidden_dim)

        self.decoder1 = Decoder(num_outputs=input_dim1, num_units=hidden_dim, final_activation=final_activations1)
        self.decoder2 = Decoder(num_outputs=input_dim2, num_units=hidden_dim, final_activation=final_activations2)

    def split_catted_input(self, x):
        """Split a concatenated ``[domain1 | domain2]`` input along the last dim into its two parts."""
        sizes = [self.input_dim1, self.input_dim2]
        return torch.split(x, sizes, dim=-1)

    def _combine_output_and_encoded(self, decoded, encoded, num_outputs: int):
        """Return a tuple of the decoder output(s) followed by the latent encoding.

        With ``num_outputs > 1`` every element of ``decoded`` is kept; otherwise only its first element (or
        ``decoded`` itself if it is not a tuple).
        """
        if num_outputs > 1:
            retval = (*decoded, encoded)
        else:
            primary = decoded[0] if isinstance(decoded, tuple) else decoded
            retval = (primary, encoded)
        assert isinstance(retval, (list, tuple))
        assert isinstance(retval[0], (torch.TensorType, torch.Tensor)), f"Expected tensor but got {type(retval[0])}"
        return retval

    def forward_single(self, x, size_factors=None, in_domain: int = 1, out_domain: int = 1):
        """Run a single encoder/decoder pairing (``size_factors`` is accepted but unused)."""
        assert in_domain in [1, 2] and out_domain in [1, 2]

        if in_domain == 1:
            encoder = self.encoder1
        else:
            encoder = self.encoder2
        if out_domain == 1:
            decoder, n_out = self.decoder1, self.num_outputs1
        else:
            decoder, n_out = self.decoder2, self.num_outputs2

        latent = encoder(x)
        return self._combine_output_and_encoded(decoder(latent), latent, n_out)

    def forward(self, x, size_factors=None, mode: Union[None, Tuple[int, int]] = None):
        """Return all four pairings as ``(11, 12, 21, 22)``, or only the one selected by ``mode``."""
        if self.flat_mode:
            x = self.split_catted_input(x)
        assert isinstance(x, (tuple, list))
        assert len(x) == 2, "There should be two inputs to spliced autoencoder"
        latent1 = self.encoder1(x[0])
        latent2 = self.encoder2(x[1])

        # Call order matters in training mode (BatchNorm running statistics), so it is kept fixed.
        combine = self._combine_output_and_encoded
        out11 = combine(self.decoder1(latent1), latent1, self.num_outputs1)
        out12 = combine(self.decoder2(latent1), latent1, self.num_outputs2)
        out22 = combine(self.decoder2(latent2), latent2, self.num_outputs2)
        out21 = combine(self.decoder1(latent2), latent2, self.num_outputs1)

        if mode is None:
            return out11, out12, out21, out22
        by_mode = {(1, 1): out11, (1, 2): out12, (2, 1): out21, (2, 2): out22}
        if mode not in by_mode:
            raise ValueError(f"Invalid mode code: {mode}")
        return by_mode[mode]


class NaiveSplicedAutoEncoder(SplicedAutoEncoder):
    """Naive "spliced" autoencoder that does not use shared branches and instead simply trains four separate models.

    Domain 1 is a flat feature vector of size ``input_dim1``. Domain 2 is split per chromosome with sizes
    ``input_dim2`` and handled by ``ChromEncoder``/``ChromDecoder``. Each of the four translations (1->1, 1->2, 2->1,
    2->2) has its own encoder/decoder pair.
    """

    def __init__(
        self,
        input_dim1: int,
        input_dim2: List[int],  # per-chromosome sizes (used with sum()/torch.split and ChromEncoder/ChromDecoder)
        hidden_dim: int = 16,
        final_activations1: Union[Callable, List[Callable]] = [
            Exp(),
            ClippedSoftplus(),
        ],
        final_activations2: Union[Callable, List[Callable]] = nn.Sigmoid(),
        flat_mode: bool = True,  # Controls if we have to re-split inputs
        seed: int = 182822,
    ):
        # Skip SplicedAutoEncoder.__init__ (it would build shared Encoder/Decoder modules); only nn.Module setup.
        nn.Module.__init__(self)
        torch.manual_seed(seed)

        self.flat_mode = flat_mode
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.num_outputs1 = (len(final_activations1) if isinstance(final_activations1, (list, set, tuple)) else 1)
        self.num_outputs2 = (len(final_activations2) if isinstance(final_activations2, (list, set, tuple)) else 1)

        # Construction order is kept fixed so the seeded initialization is reproducible.
        self.model11 = nn.ModuleList([
            Encoder(num_inputs=input_dim1, num_units=hidden_dim),
            Decoder(
                num_outputs=input_dim1,
                num_units=hidden_dim,
                final_activation=final_activations1,
            ),
        ])
        self.model12 = nn.ModuleList([
            Encoder(num_inputs=input_dim1, num_units=hidden_dim),
            ChromDecoder(
                num_outputs=input_dim2,
                latent_dim=hidden_dim,
                final_activations=final_activations2,
            ),
        ])
        self.model22 = nn.ModuleList([
            ChromEncoder(num_inputs=input_dim2, latent_dim=hidden_dim),
            ChromDecoder(
                num_outputs=input_dim2,
                latent_dim=hidden_dim,
                final_activations=final_activations2,
            ),
        ])
        self.model21 = nn.ModuleList([
            ChromEncoder(num_inputs=input_dim2, latent_dim=hidden_dim),
            Decoder(
                num_outputs=input_dim1,
                num_units=hidden_dim,
                final_activation=final_activations1,
            ),
        ])
        # Each of these has exactly two items: (encoder, decoder).
        self.modlist_dict = nn.ModuleDict({
            "11": self.model11,
            "12": self.model12,
            "22": self.model22,
            "21": self.model21,
        })

    def split_catted_input(self, x):
        """Split the input into chunks that goes to each input to model."""
        a, b = torch.split(x, [self.input_dim1, sum(self.input_dim2)], dim=-1)
        return (a, torch.split(b, self.input_dim2, dim=-1))

    def forward_single(self, x, size_factors=None, in_domain: int = 1, out_domain: int = 1):
        """Return output of a single domain combination (``size_factors`` is accepted but unused).

        Returns the decoder output(s) followed by the latent encoding, via the inherited
        ``_combine_output_and_encoded``. An unknown domain pair raises ``KeyError``.
        """
        num_non_latent_out = self.num_outputs1 if out_domain == 1 else self.num_outputs2
        encoder, decoder = self.modlist_dict[f"{in_domain}{out_domain}"]

        encoded = encoder(x)
        decoded = decoder(encoded)
        return self._combine_output_and_encoded(decoded, encoded, num_non_latent_out)

    def forward(self, x, size_factors=None, mode: Union[None, Tuple[int, int]] = None):
        """Return all four translations as ``(11, 12, 21, 22)``, or only the one selected by ``mode``.

        Raises
        ------
        ValueError
            If ``mode`` is not one of ``(1, 1)``, ``(1, 2)``, ``(2, 1)``, ``(2, 2)``.

        """
        if self.flat_mode:
            x = self.split_catted_input(x)
        assert isinstance(x, (tuple, list))
        assert len(x) == 2, "There should be two inputs to spliced autoencoder"

        if mode is not None:
            # The four models share nothing, so only the requested one needs to run. Validate before any compute.
            if mode not in ((1, 1), (1, 2), (2, 1), (2, 2)):
                raise ValueError(f"Invalid mode code: {mode}")
            in_domain, out_domain = mode
            return self.forward_single(x[in_domain - 1], size_factors=size_factors, in_domain=in_domain,
                                       out_domain=out_domain)

        retval11 = self.forward_single(x[0], size_factors=size_factors, in_domain=1, out_domain=1)
        retval12 = self.forward_single(x[0], size_factors=size_factors, in_domain=1, out_domain=2)
        retval21 = self.forward_single(x[1], size_factors=size_factors, in_domain=2, out_domain=1)
        retval22 = self.forward_single(x[1], size_factors=size_factors, in_domain=2, out_domain=2)
        return retval11, retval12, retval21, retval22


class AssymSplicedAutoEncoder(SplicedAutoEncoder):
    """Asymmetric spliced autoencoder where branch 2 is a per-chromosome autoencoder.

    Domain 1 uses ``Encoder``/``Decoder``. Domain 2 uses ``ChromEncoder``/``ChromDecoder``, with one chunk per entry
    of ``input_dim2``. ``forward`` and the output-combination logic are inherited from ``SplicedAutoEncoder``.
    """

    def __init__(
        self,
        input_dim1: int,
        input_dim2: List[int],
        hidden_dim: int = 16,
        final_activations1: list = [Exp(), ClippedSoftplus()],
        final_activations2=nn.Sigmoid(),
        flat_mode: bool = True,  # Controls if we have to re-split inputs
        seed: int = 182822,
    ):
        # Bypass SplicedAutoEncoder.__init__, which would build plain Encoder/Decoder modules for domain 2.
        # https://stackoverflow.com/questions/9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way
        nn.Module.__init__(self)
        torch.manual_seed(seed)

        def count_heads(acts):
            return len(acts) if isinstance(acts, (list, set, tuple)) else 1

        self.flat_mode = flat_mode
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.num_outputs1 = count_heads(final_activations1)
        self.num_outputs2 = count_heads(final_activations2)

        self.encoder1 = Encoder(num_inputs=input_dim1, num_units=hidden_dim)
        self.encoder2 = ChromEncoder(num_inputs=input_dim2, latent_dim=hidden_dim)

        self.decoder1 = Decoder(num_outputs=input_dim1, num_units=hidden_dim, final_activation=final_activations1)
        self.decoder2 = ChromDecoder(num_outputs=input_dim2, latent_dim=hidden_dim,
                                     final_activations=final_activations2)

    def split_catted_input(self, x):
        """Split a flat input into the domain-1 tensor and a tuple of per-chromosome domain-2 tensors."""
        chrom_total = sum(self.input_dim2)
        head, tail = torch.split(x, [self.input_dim1, chrom_total], dim=-1)
        return head, torch.split(tail, self.input_dim2, dim=-1)


class BabelWrapper:
    """Babel class.

    Parameters
    ----------
    args : argparse.Namespace
        A Namespace object that contains arguments of Babel. For details of parameters in parser args, please refer
        to link (parser help document).
    dim_in : int
        Input dimension.
    dim_out: int
        Output dimension.

    """

    def __init__(self, args, dim_in, dim_out):
        self.args = args
        # Device currently holding the model. Kept separately from ``args`` so that ``to`` can change it without
        # mutating the caller's namespace; all data is moved to this device.
        self.device = args.device
        model_class = NaiveSplicedAutoEncoder if args.naive else AssymSplicedAutoEncoder
        self.model = model_class(
            hidden_dim=args.hidden,
            input_dim1=dim_in,
            input_dim2=[dim_out],
            final_activations1=nn.ReLU(),
            final_activations2=nn.ReLU(),
            flat_mode=True,
            seed=args.seed,
        ).to(self.device)

    def to(self, device):
        """Performs device conversion.

        Parameters
        ----------
        device : str
            Target device.

        Returns
        -------
        self : BabelWrapper
            Converted model.

        """
        self.model.to(device)
        self.device = device  # subsequent fit/predict/score move their inputs here
        return self

    def score(self, test_mod1, test_mod2):
        """Score function to get score of prediction.

        Parameters
        ----------
        test_mod1 : torch.Tensor
            Input modality features.
        test_mod2 : torch.Tensor
            Target modality features.

        Returns
        -------
        score : float
            RMSE loss of prediction.

        """
        mse = nn.MSELoss()
        self.model.eval()
        with torch.no_grad():
            pred = self.predict(test_mod1)
            score = math.sqrt(mse(pred, test_mod2.to(self.device)))
        return score

    def predict(self, test_mod1):
        """Predict function to get prediction of target modality features.

        Parameters
        ----------
        test_mod1 : torch.Tensor
            Input modality features.

        Returns
        -------
        pred : torch.Tensor
            Prediction of target modality features.

        """
        self.model.eval()
        with torch.no_grad():
            inputs = test_mod1.to(self.device)
            # Use the modality-1 encoder with the modality-2 decoder (the 1 -> 2 translation).
            if self.args.naive:
                encoder, decoder = self.model.model12
            else:
                encoder, decoder = self.model.encoder1, self.model.decoder2
            pred = decoder(encoder(inputs))[0]
        return pred

    def fit(self, x_train, y_train, max_epochs=500, val_ratio=0.15):
        """Fit function for training.

        After every epoch the 1 -> 2 prediction RMSE is measured on a random validation split. Whenever it reaches a
        new minimum, the weights are saved to ``{args.outdir}/BABEL_best_{args.seed}.pth``. When training ends
        (after ``max_epochs`` or early stopping) the best weights are restored.

        Parameters
        ----------
        x_train : torch.Tensor
            Training input modality.
        y_train : torch.Tensor
            Training output modality.
        max_epochs : int optional
            Maximum number of training epochs, by default to be 500.
        val_ratio : float
            Fraction of samples held out for validation, by default 0.15. At least one sample is always held out
            and at least one is kept for training.

        Raises
        ------
        ValueError
            If fewer than two samples are given.

        """
        criterion = loss_functions.QuadLoss(loss1=loss_functions.RMSELoss, loss2=loss_functions.RMSELoss,
                                            loss2_weight=self.args.lossweight)
        device = self.device
        train_loader, val_loader = self._make_loaders(x_train, y_train, val_ratio)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

        val = []
        # Fallback so ``best_dict`` is always bound, e.g. when every validation score is NaN (NaN never equals min).
        best_dict = deepcopy(self.model.state_dict())
        for i in range(max_epochs):
            self.model.train()
            total_loss = 0
            for train_batch in train_loader:
                train_batch = train_batch.to(device)
                logits = self.model(train_batch)
                loss = criterion(logits, train_batch)
                total_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                optimizer.step()

            val.append(self._validation_rmse(val_loader))
            print('epoch: ', i + 1)
            print('training (sum of 4 losses):', total_loss / len(train_loader))
            print('validation (prediction loss):', val[-1])
            if min(val) == val[-1]:
                torch.save(self.model.state_dict(), f'{self.args.outdir}/BABEL_best_{self.args.seed}.pth')
                best_dict = deepcopy(self.model.state_dict())
            if i > self.args.earlystop and min(val) != min(val[-self.args.earlystop:]):
                print('Early stopped.')
                break

        self.model.load_state_dict(best_dict)

    def _make_loaders(self, x_train, y_train, val_ratio):
        """Randomly split the rows into training/validation sets and wrap each (inputs and targets side by side)
        in a DataLoader."""
        total_size = x_train.shape[0]
        if total_size < 2:
            raise ValueError("At least two samples are required to build a train/validation split.")
        # Keep at least one validation row (an empty loader would divide by zero) and at least one training row.
        val_size = min(max(int(total_size * val_ratio), 1), total_size - 1)
        rand_idx = torch.randperm(total_size)
        train_idx = rand_idx[:total_size - val_size]
        val_idx = rand_idx[total_size - val_size:]

        batch_size = self.args.batchsize
        # BatchNorm1d cannot normalize a single-sample batch in training mode, so drop a trailing batch of size one.
        drop_last = len(train_idx) > 1 and len(train_idx) % batch_size == 1
        train_loader = DataLoader(torch.hstack((x_train[train_idx], y_train[train_idx])), shuffle=True,
                                  batch_size=batch_size, drop_last=drop_last)
        val_loader = DataLoader(torch.hstack((x_train[val_idx], y_train[val_idx])), batch_size=batch_size)
        return train_loader, val_loader

    def _validation_rmse(self, val_loader):
        """Return the square root of the mean per-batch MSE of the 1 -> 2 prediction over ``val_loader``."""
        mse = nn.MSELoss()
        self.model.eval()
        with torch.no_grad():
            loss = 0
            for val_batch in val_loader:
                logits = self.model(val_batch.to(self.device))
                # forward() returns (11, 12, 21, 22); index 1 is the 1 -> 2 translation, whose first item is the
                # prediction. The target columns are the trailing ones of the concatenated batch.
                pred = logits[1][0]
                target = val_batch[:, -pred.shape[1]:].to(self.device)
                loss += mse(pred, target).item()
        return math.sqrt(loss / len(val_loader))