Source code for dance.datasets.multimodality

import os
import os.path as osp
import pickle
from abc import ABC

import anndata as ad
import mudata as md
import numpy as np
import scanpy as sc
import scipy
import scipy.sparse as sp
import sklearn
from sklearn.utils import issparse

from dance import logger
from dance.data import Data
from dance.datasets.base import BaseDataset
from dance.registry import register_dataset
from dance.transforms.preprocess import lsiTransformer
from dance.typing import List
from dance.utils import is_numeric
from dance.utils.download import download_file, unzip_file


[docs]class MultiModalityDataset(BaseDataset, ABC): TASK = "N/A" URL_DICT = {} SUBTASK_NAME_MAP = {} AVAILABLE_DATA = [] def __init__(self, subtask, root="./data"): assert subtask in self.AVAILABLE_DATA, f"Undefined subtask {subtask!r}." assert self.TASK in ["predict_modality", "match_modality", "joint_embedding"] self.subtask = self.SUBTASK_NAME_MAP.get(subtask, subtask) self.data_url = self.URL_DICT[self.subtask] super().__init__(root=root, full_download=False)
[docs] def download(self): self.download_data()
def download_data(self): download_file(self.data_url, osp.join(self.root, f"{self.subtask}.zip")) unzip_file(osp.join(self.root, f"{self.subtask}.zip"), self.root) def download_pathway(self): download_file("https://www.dropbox.com/s/uqoakpalr3albiq/h.all.v7.4.entrez.gmt?dl=1", osp.join(self.root, "h.all.v7.4.entrez.gmt")) download_file("https://www.dropbox.com/s/yjrcsd2rpmahmfo/h.all.v7.4.symbols.gmt?dl=1", osp.join(self.root, "h.all.v7.4.symbols.gmt")) @property def data_paths(self) -> List[str]: if self.TASK == "joint_embedding": mod = "adt" if "cite" in self.subtask else "atac" meta = "cite" if "cite" in self.subtask else "multiome" paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_mod2.h5ad"), osp.join(self.root, self.subtask, f"{meta}_gex_processed_training.h5ad"), osp.join(self.root, self.subtask, f"{meta}_{mod}_processed_training.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_solution.h5ad"), ] if self.subtask.startswith("GSE140203"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_solution.h5ad"), ] if self.subtask.startswith("openproblems_2022"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_solution.h5ad"), ] elif self.TASK == "predict_modality": paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod2.h5ad") ] if self.subtask == "10k_pbmc": paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_test_mod2.h5ad") ] if self.subtask == "pbmc_cite": paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_test_mod2.h5ad") ] if self.subtask.startswith("5k_pbmc"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_test_mod2.h5ad") ] if self.subtask.startswith("openproblems_2022"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_test_mod2.h5ad") ] if self.subtask.startswith("GSE127064"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_test_mod2.h5ad") ] if self.subtask.startswith("GSE117089"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_test_mod2.h5ad") ] if self.subtask.startswith("GSE140203"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_test_mod2.h5ad") ] elif self.TASK == "match_modality": paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_sol.h5ad"), ] if self.subtask == "pbmc_cite": paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.citeanti_dataset.output_test_sol.h5ad"), ] if self.subtask.startswith("openproblems_2022"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.open_dataset.output_test_sol.h5ad"), ] if self.subtask.startswith("GSE127064"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE126074_dataset.output_test_sol.h5ad") ] if self.subtask.startswith("GSE117089"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE117089_dataset.output_test_sol.h5ad") ] if self.subtask.startswith("GSE140203"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.GSE140203_dataset.output_test_sol.h5ad"), ] if self.subtask == "10k_pbmc": paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.10kanti_dataset_subset.output_test_sol.h5ad") ] if self.subtask.startswith("5k_pbmc"): paths = [ osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_train_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_train_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_train_sol.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_test_mod1.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_test_mod2.h5ad"), osp.join(self.root, self.subtask, f"{self.subtask}.5kanti_dataset.output_test_sol.h5ad"), ] return paths
[docs] def is_complete(self) -> bool: return all(map(osp.exists, self.data_paths))
def _load_raw_data(self) -> List[ad.AnnData]: modalities = [] for mod_path in self.data_paths: logger.info(f"Loading {mod_path}") modalities.append(ad.read_h5ad(mod_path)) return modalities
[docs]@register_dataset("multimodality") class ModalityPredictionDataset(MultiModalityDataset): TASK = "predict_modality" URL_DICT = { "openproblems_bmmc_cite_phase2_mod2": "https://www.dropbox.com/s/snh8knscnlcq4um/openproblems_bmmc_cite_phase2_mod2.zip?dl=1", "openproblems_bmmc_cite_phase2_rna": "https://www.dropbox.com/s/xbfyhv830u9pupv/openproblems_bmmc_cite_phase2_rna.zip?dl=1", "openproblems_bmmc_multiome_phase2_mod2": "https://www.dropbox.com/s/p9ve2ljyy4yqna4/openproblems_bmmc_multiome_phase2_mod2.zip?dl=1", "openproblems_bmmc_multiome_phase2_rna": "https://www.dropbox.com/s/cz60vp7bwapz0kw/openproblems_bmmc_multiome_phase2_rna.zip?dl=1", "openproblems_bmmc_cite_phase2_rna_subset": "https://www.dropbox.com/s/veytldxkgzyoa8j/openproblems_bmmc_cite_phase2_rna_subset.zip?dl=1", "5k_pbmc": "https://www.dropbox.com/scl/fi/uoyis946glh0oo7g833qj/5k_pbmc.zip?rlkey=mw9cvqq7e12iowfbr9rp7av5u&dl=1", "5k_pbmc_subset": "https://www.dropbox.com/scl/fi/pykqc9zyt1fjypnjf4m1l/5k_pbmc_subset.zip?rlkey=brkmnqhfz5yl9axiuu0f8gmxy&dl=1", "10k_pbmc": "https://www.dropbox.com/scl/fi/npz3n36d3w089creppph2/10k_pbmc.zip?rlkey=6yyv61omv2rw7sqqmfp6u7m1s&dl=1", "pbmc_cite": "https://www.dropbox.com/scl/fi/8yvel9lu2f4pbemjeihzq/pbmc_cite.zip?rlkey=5f5jpjy1fcg14hwzot0hot7xd&dl=1", "openproblems_2022_multi_atac2gex": "https://www.dropbox.com/scl/fi/luzmc1jab7zvvxi2i4od5/openproblems_2022_multi_atac2gex.zip?rlkey=ht1acmhdpq8bbo1guevqgej5y&dl=1", "openproblems_2022_cite_gex2adt": "https://www.dropbox.com/scl/fi/ejioe3qqug0h2f7wvw9hq/openproblems_2022_cite_gex2adt.zip?rlkey=2f9kqz61s9ixdllgzic9tamc7&dl=1", "GSE127064_AdBrain_gex2atac": "https://www.dropbox.com/scl/fi/4ybsx6pgiuy6j9m0y92ly/GSE127064_AdBrain_gex2atac.zip?rlkey=6a5u7p7xr2dqsoduflzxjluja&dl=1", "GSE127064_p0Brain_gex2atac": "https://www.dropbox.com/scl/fi/k4p3nkkqq56ev6ljyo5se/GSE127064_p0Brain_gex2atac.zip?rlkey=y7kayqmk2l72jjogzlvfxtl74&dl=1", "GSE117089_mouse_gex2atac": "https://www.dropbox.com/scl/fi/hbo5eel8vtkctwhgelu5u/GSE117089_mouse_gex2atac.zip?rlkey=84t4kj1ls7ut09dpcbj86mtlc&dl=1", "GSE117089_sciCAR_gex2atac": "https://www.dropbox.com/scl/fi/hc0c48so824uohx0szs3h/GSE117089_sciCAR_gex2atac.zip?rlkey=4xjayirgijodd1fqcf7a42apo&dl=1", "GSE140203_3T3_HG19_atac2gex": "https://www.dropbox.com/scl/fi/v1vbypz87t1rz012vojkh/GSE140203_3T3_HG19_atac2gex.zip?rlkey=xmxrwso5e5ty3w53ctbm5bo9z&dl=1", "GSE140203_3T3_MM10_atac2gex": "https://www.dropbox.com/scl/fi/po9k064twny51subze6df/GSE140203_3T3_MM10_atac2gex.zip?rlkey=q0b4y58bsvacnjrmvsclk4jqu&dl=1", "GSE140203_12878.rep2_atac2gex": "https://www.dropbox.com/scl/fi/jqijimb7h6cv4w4hkax1q/GSE140203_12878.rep2_atac2gex.zip?rlkey=c837xkoacap4wjszffpfrmuak&dl=1", "GSE140203_12878.rep3_atac2gex": "https://www.dropbox.com/scl/fi/wlv9dhvylz78kq8ezncmd/GSE140203_12878.rep3_atac2gex.zip?rlkey=5r607plnqzlxdgxtc4le8d6o1&dl=1", "GSE140203_K562_HG19_atac2gex": "https://www.dropbox.com/scl/fi/n2he1br3u604p3mgniowz/GSE140203_K562_HG19_atac2gex.zip?rlkey=2lhe7s5run8ly5uk4b0vfemyj&dl=1", "GSE140203_K562_MM10_atac2gex": "https://www.dropbox.com/scl/fi/dhdorqy87915uah3xl07a/GSE140203_K562_MM10_atac2gex.zip?rlkey=ecwsy5sp7f2i2gtjo1qyaf4zt&dl=1", "GSE140203_LUNG_atac2gex": "https://www.dropbox.com/scl/fi/gabugiw244ky85j3ckq4d/GSE140203_LUNG_atac2gex.zip?rlkey=uj0we276s6ay2acpioj4tmfj3&dl=1" } SUBTASK_NAME_MAP = { "adt2gex": "openproblems_bmmc_cite_phase2_mod2", "atac2gex": "openproblems_bmmc_multiome_phase2_mod2", "gex2adt": "openproblems_bmmc_cite_phase2_rna", "gex2atac": "openproblems_bmmc_multiome_phase2_rna", "gex2adt_subset": "openproblems_bmmc_cite_phase2_rna_subset", } AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP)) def __init__(self, subtask, root="./data", preprocess=None, span=0.3): # TODO: factor our preprocess self.preprocess = preprocess self.span = span super().__init__(subtask, root) def _raw_to_dance(self, raw_data): train_mod1, train_mod2, test_mod1, test_mod2 = self._maybe_preprocess(raw_data) mod1 = ad.concat((train_mod1, test_mod1)) mod2 = ad.concat((train_mod2, test_mod2)) mod1.var_names_make_unique() mod2.var_names_make_unique() mdata = md.MuData({"mod1": mod1, "mod2": mod2}) mdata.var_names_make_unique() data = Data(mdata, train_size=train_mod1.shape[0]) data.set_config(feature_mod="mod1", label_mod="mod2") return data def _maybe_preprocess(self, raw_data, selection_threshold=10000): changed_count = 0 # keep track to modified entries due to ensuring count data type for i in range(4): m_data = raw_data[i].X int_data = m_data.astype(int) changed_count += np.sum(int_data != m_data) raw_data[i].X = int_data raw_data[i].layers["counts"] = raw_data[i].X if changed_count > 0: logger.warning("Implicit modification: to ensure count (integer type) data, " f"a total number of {changed_count} entries were modified.") if self.preprocess == "feature_selection": if raw_data[0].shape[1] > selection_threshold: sc.pp.highly_variable_genes(raw_data[0], layer="counts", flavor="seurat_v3", n_top_genes=selection_threshold, span=self.span) raw_data[2].var["highly_variable"] = raw_data[0].var["highly_variable"] for i in [0, 2]: raw_data[i] = raw_data[i][:, raw_data[i].var["highly_variable"]] elif self.preprocess not in (None, "none"): logger.info(f"Preprocessing method {self.preprocess!r} not supported.") logger.info("Preprocessing done.") return raw_data
[docs]@register_dataset("multimodality") class ModalityMatchingDataset(MultiModalityDataset): TASK = "match_modality" URL_DICT = { "openproblems_bmmc_cite_phase2_mod2": "https://www.dropbox.com/s/fa6zut89xx73itz/openproblems_bmmc_cite_phase2_mod2.zip?dl=1", "openproblems_bmmc_cite_phase2_rna": "https://www.dropbox.com/s/ep00mqcjmdu0b7v/openproblems_bmmc_cite_phase2_rna.zip?dl=1", "openproblems_bmmc_multiome_phase2_mod2": "https://www.dropbox.com/s/31qi5sckx768acw/openproblems_bmmc_multiome_phase2_mod2.zip?dl=1", "openproblems_bmmc_multiome_phase2_rna": "https://www.dropbox.com/s/h1s067wkefs1jh2/openproblems_bmmc_multiome_phase2_rna.zip?dl=1", "openproblems_bmmc_cite_phase2_rna_subset": "https://www.dropbox.com/s/3q4xwpzjbe81x58/openproblems_bmmc_cite_phase2_rna_subset.zip?dl=1", "pbmc_cite": "https://www.dropbox.com/scl/fi/eq9eg6hzoqj2plgi2003k/pbmc_cite.zip?rlkey=p7bgttr7v99jxu3qem7sh8qrh&dl=1", "openproblems_2022_multi_atac2gex_subset": "https://www.dropbox.com/scl/fi/2p8izdu5xwvgm705hdf16/openproblems_2022_multi_atac2gex_subset.zip?rlkey=v962rncxmc9jqab2vhk3438sp&dl=1", "openproblems_2022_cite_gex2adt_subset": "https://www.dropbox.com/scl/fi/o9ht00cqgkxwgixtaydxm/openproblems_2022_cite_gex2adt_subset.zip?rlkey=sqnodvi25btk1igowww2pen8h&dl=1", "5k_pbmc_subset": "https://www.dropbox.com/scl/fi/rhyzaqtxpkvcu2za8mqaq/5k_pbmc_subset.zip?rlkey=g019vyku5let92z814dor287w&dl=1", "10k_pbmc": "https://www.dropbox.com/scl/fi/1wi9u5zwzx7td9akk1cri/10k_pbmc.zip?rlkey=u9ir7b6d8s3t29sk2hu7v29au&dl=1", "GSE117089_mouse_gex2atac": "https://www.dropbox.com/scl/fi/dbxgretuwq1zekxibb2p0/GSE117089_mouse_gex2atac.zip?rlkey=wzqi309on9v1wllkiatnkpnhv&dl=1", "GSE117089_sciCAR_gex2atac": "https://www.dropbox.com/scl/fi/4sohkymkqyry5xkx34oiw/GSE117089_sciCAR_gex2atac.zip?rlkey=6exg6ybf5ufhagycj5g7hq5vi&dl=1", "GSE127064_AdBrain_gex2atac": "https://www.dropbox.com/scl/fi/mktue5y4bsf9w17t7jyq3/GSE127064_AdBrain_gex2atac.zip?rlkey=3qtazuova6v1rin630keryman&dl=1", "GSE127064_p0Brain_gex2atac": "https://www.dropbox.com/scl/fi/anlukciivt5ah4i9v5q8s/GSE127064_p0Brain_gex2atac.zip?rlkey=9q12rwqgbz2z45dkwz372grgk&dl=1", "GSE140203_3T3_HG19_atac2gex": "https://www.dropbox.com/scl/fi/840hsqkcbis0t35i04kdi/GSE140203_3T3_HG19_atac2gex.zip?rlkey=gurncv741zi4q6dqb9q293zsl&dl=1", "GSE140203_3T3_MM10_atac2gex": "https://www.dropbox.com/scl/fi/chtl13dchlteilm2hky7r/GSE140203_3T3_MM10_atac2gex.zip?rlkey=su1itxejsyzkqcxjngb1xbunj&dl=1", "GSE140203_12878.rep2_atac2gex": "https://www.dropbox.com/scl/fi/9axnm23b554tn7uenf98q/GSE140203_12878.rep2_atac2gex.zip?rlkey=dplthpb82qhvnh9fann5o1gvb&dl=1", "GSE140203_12878.rep3_atac2gex": "https://www.dropbox.com/scl/fi/1zgc35dbl1pyrwrqfmtj8/GSE140203_12878.rep3_atac2gex.zip?rlkey=lwkx6iv2z584m1315gqpcomw9&dl=1", "GSE140203_K562_HG19_atac2gex": "https://www.dropbox.com/scl/fi/kro3384oium84fdr46l77/GSE140203_K562_HG19_atac2gex.zip?rlkey=f9kyx8rz4o7tgf8vts64d5rpi&dl=1", "GSE140203_K562_MM10_atac2gex": "https://www.dropbox.com/scl/fi/2dwn8zzhaq86ojkfgh29q/GSE140203_K562_MM10_atac2gex.zip?rlkey=ek94g5d9w0xrafp72z9jx5wty&dl=1", "GSE140203_LUNG_atac2gex": "https://www.dropbox.com/scl/fi/zb7igtgg835pg73ec7f28/GSE140203_LUNG_atac2gex.zip?rlkey=19ohnkxpj1temqnfxje6jae8w&dl=1", } SUBTASK_NAME_MAP = { "adt2gex": "openproblems_bmmc_cite_phase2_mod2", "atac2gex": "openproblems_bmmc_multiome_phase2_mod2", "gex2adt": "openproblems_bmmc_cite_phase2_rna", "gex2atac": "openproblems_bmmc_multiome_phase2_rna", "gex2adt_subset": "openproblems_bmmc_cite_phase2_rna_subset", } AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP)) def __init__(self, subtask, root="./data", preprocess=None, pkl_path=None, span=0.3): # TODO: factor our preprocess self.preprocess = preprocess self.pkl_path = pkl_path self.span = span super().__init__(subtask, root) def _raw_to_dance(self, raw_data): train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label = self._maybe_preprocess(raw_data) # Align matched cells train_mod2 = train_mod2[train_label.to_df().values.argmax(1)] mod1 = ad.concat((train_mod1, test_mod1)) mod2 = ad.concat((train_mod2, test_mod2)) mod1.var_names_make_unique() mod2.var_names_make_unique() mod2.obs_names = mod1.obs_names train_size = train_mod1.shape[0] mod1.obsm["labels"] = np.concatenate([np.zeros(train_size), np.argmax(test_label.X.toarray(), 1)]) # Combine modalities into mudata mdata = md.MuData({"mod1": mod1, "mod2": mod2}) mdata.var_names_make_unique() data = Data(mdata, train_size=train_size) return data def _maybe_preprocess(self, raw_data, selection_threshold=10000): if self.preprocess is None: return raw_data train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label = raw_data modalities = [train_mod1, train_mod2, test_mod1, test_mod2] if is_numeric(train_mod2.obs_names[0]): train_mod2.obs_names = train_mod1.obs_names if is_numeric(test_mod2.obs_names[0]): test_mod2.obs_names = test_mod1.obs_names # TODO: support other two subtasks # assert self.subtask in ("openproblems_bmmc_cite_phase2_rna", "openproblems_bmmc_cite_phase2_rna_subset", # "openproblems_bmmc_multiome_phase2_rna","pbmc_cite","openproblems_2022_multi_atac2gex","openproblems_2022_cite_gex2adt"), "Currently not available." changed_count = 0 # keep track to modified entries due to ensuring count data type for i in range(4): m_data = modalities[i].X int_data = m_data.astype(int) changed_count += np.sum(int_data != m_data) modalities[i].X = int_data modalities[i].layers["counts"] = modalities[i].X if changed_count > 0: logger.warning("Implicit modification: to ensure count (integer type) data, " f"a total number of {changed_count} entries were modified.") if self.preprocess == "pca": if self.pkl_path and osp.exists(self.pkl_path): with open(self.pkl_path, "rb") as f: preprocessed_features = pickle.load(f) else: for i in range(2): sc.pp.filter_genes(modalities[i], min_cells=1, inplace=True) sc.pp.filter_genes(modalities[i + 2], min_cells=1, inplace=True) common_genes = list(set(modalities[i].var.index) & set(modalities[i + 2].var.index)) modalities[i] = modalities[i][:, common_genes] modalities[i + 2] = modalities[i + 2][:, common_genes] if self.subtask in ("openproblems_2022_cite_gex2adt_subset", "pbmc_cite", "openproblems_bmmc_cite_phase2_rna", "openproblems_bmmc_cite_phase2_rna_subset", "5k_pbmc_subset"): lsi_transformer_gex = lsiTransformer(n_components=256, drop_first=True) m1_train = lsi_transformer_gex.fit_transform(modalities[0]).values m1_test = lsi_transformer_gex.transform(modalities[2]).values m2_train = modalities[1].X.toarray() m2_test = modalities[3].X.toarray() elif self.subtask in ("GSE117089_mouse_gex2atac", "GSE117089_sciCAR_gex2atac", "GSE127064_AdBrain_gex2atac", "GSE127064_p0Brain_gex2atac", "openproblems_bmmc_multiome_phase2_rna", "10k_pbmc"): lsi_transformer_gex = lsiTransformer(n_components=256, drop_first=True) m1_train = lsi_transformer_gex.fit_transform(modalities[0]).values m1_test = lsi_transformer_gex.transform(modalities[2]).values lsi_transformer_atac = lsiTransformer(n_components=512, drop_first=True) m2_train = lsi_transformer_atac.fit_transform(modalities[1]).values m2_test = lsi_transformer_atac.transform(modalities[3]).values elif self.subtask in ("openproblems_2022_multi_atac2gex_subset", "GSE140203_3T3_HG19_atac2gex", "GSE140203_3T3_MM10_atac2gex", "GSE140203_12878.rep2_atac2gex", "GSE140203_12878.rep3_atac2gex", "GSE140203_K562_HG19_atac2gex", "GSE140203_K562_MM10_atac2gex", "GSE140203_LUNG_atac2gex"): lsi_transformer_atac = lsiTransformer(n_components=512, drop_first=True) m1_train = lsi_transformer_atac.fit_transform(modalities[0]).values m1_test = lsi_transformer_atac.transform(modalities[2]).values lsi_transformer_gex = lsiTransformer(n_components=256, drop_first=True) m2_train = lsi_transformer_gex.fit_transform(modalities[1]).values m2_test = lsi_transformer_gex.transform(modalities[3]).values else: raise ValueError(f"Unrecognized subtask name: {self.subtask}") preprocessed_features = { "mod1_train": m1_train, "mod2_train": m2_train, "mod1_test": m1_test, "mod2_test": m2_test } if self.pkl_path: with open(self.pkl_path, "wb") as f: pickle.dump(preprocessed_features, f) modalities[0].obsm["X_pca"] = preprocessed_features["mod1_train"] modalities[1].obsm["X_pca"] = preprocessed_features["mod2_train"] modalities[2].obsm["X_pca"] = preprocessed_features["mod1_test"] modalities[3].obsm["X_pca"] = preprocessed_features["mod2_test"] elif self.preprocess == "feature_selection": for i in [0, 2]: sc.pp.filter_cells(modalities[i], min_counts=1, inplace=True) sc.pp.filter_cells(modalities[i + 1], min_counts=1, inplace=True) common_cells = list(set(modalities[i].obs.index) & set(modalities[i + 1].obs.index)) modalities[i] = modalities[i][common_cells, :] modalities[i + 1] = modalities[i + 1][common_cells, :] if i == 0: train_label = train_label[common_cells, :] train_label = ad.AnnData(obs=train_label.obs, X=sp.csr_matrix(np.eye(len(train_label.obs)))) else: test_label = test_label[common_cells, :] test_label = ad.AnnData(obs=test_label.obs, X=sp.csr_matrix(np.eye(len(test_label.obs)))) for i in range(2): if modalities[i].shape[1] > selection_threshold: sc.pp.highly_variable_genes(modalities[i], layer="counts", flavor="seurat_v3", n_top_genes=selection_threshold, span=self.span) modalities[i + 2].var["highly_variable"] = modalities[i].var["highly_variable"] modalities[i] = modalities[i][:, modalities[i].var["highly_variable"]] modalities[i + 2] = modalities[i + 2][:, modalities[i + 2].var["highly_variable"]] for i in [0, 2]: sc.pp.filter_cells(modalities[i], min_counts=1, inplace=True) sc.pp.filter_cells(modalities[i + 1], min_counts=1, inplace=True) common_cells = list(set(modalities[i].obs.index) & set(modalities[i + 1].obs.index)) modalities[i] = modalities[i][common_cells, :] modalities[i + 1] = modalities[i + 1][common_cells, :] if i == 0: train_label = train_label[common_cells, :] train_label = ad.AnnData(obs=train_label.obs, X=sp.csr_matrix(np.eye(len(train_label.obs)))) else: test_label = test_label[common_cells, :] test_label = ad.AnnData(obs=test_label.obs, X=sp.csr_matrix(np.eye(len(test_label.obs)))) else: logger.info("Preprocessing method not supported.") logger.info("Preprocessing done.") train_mod1, train_mod2, test_mod1, test_mod2 = modalities return train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label
[docs]@register_dataset("multimodality") class JointEmbeddingNIPSDataset(MultiModalityDataset): TASK = "joint_embedding" URL_DICT = { "openproblems_bmmc_cite_phase2": "https://www.dropbox.com/s/hjr4dxuw55vin5z/openproblems_bmmc_cite_phase2.zip?dl=1", "openproblems_bmmc_multiome_phase2": "https://www.dropbox.com/s/40kjslupxhkg92s/openproblems_bmmc_multiome_phase2.zip?dl=1", "GSE140203_BRAIN_atac2gex": "https://www.dropbox.com/scl/fi/pa4zpj1fp00cqiavjadtx/GSE140203_BRAIN_atac2gex.zip?rlkey=oy73h59w4p9jsyhoxtaerxfw5&dl=1", "GSE140203_SKIN_atac2gex": "https://www.dropbox.com/scl/fi/2yuatq0icu6g5pc37jxq7/GSE140203_SKIN_atac2gex.zip?rlkey=o9fzlogrk3thcycv6u20jbyc6&dl=1", "openproblems_2022_cite_gex2adt": "https://www.dropbox.com/scl/fi/j3att18aems8ve8qhykeu/openproblems_2022_cite_gex2adt.zip?rlkey=i85wjp8iqkpxhbknywmwz8mz6&dl=1", "openproblems_2022_multi_atac2gex": "https://www.dropbox.com/scl/fi/fcw493eef1kmegwh9dpq9/openproblems_2022_multi_atac2gex.zip?rlkey=sd0dxbb9iadj84f84ai5cm5q5&dl=1" } SUBTASK_NAME_MAP = { "adt": "openproblems_bmmc_cite_phase2", "atac": "openproblems_bmmc_multiome_phase2", } AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP)) def __init__(self, subtask, root="./data", preprocess=None, normalize=False, pretrained_folder=".", selection_threshold=10000, span=0.3): # TODO: factor our preprocess self.preprocess = preprocess self.normalize = normalize self.pretrained_folder = pretrained_folder super().__init__(subtask, root) self.selection_threshold = selection_threshold self.span = span def _raw_to_dance(self, raw_data): mod1, mod2, meta1, meta2, test_sol = self._maybe_preprocess(raw_data) self._to_csr([mod1, mod2, meta1, meta2, test_sol]) assert all(mod2.obs_names == mod1.obs_names), "Modalities not aligned" mdata = md.MuData({"mod1": mod1, "mod2": mod2, "meta1": meta1, "meta2": meta2, "test_sol": test_sol}) train_size = meta1.shape[0] data = Data(mdata, train_size=train_size) return data def _to_csr(self, datas): for data in datas: if scipy.sparse.issparse(data.X): if not isinstance(data.X, scipy.sparse.csr_matrix): data.X = data.X.tocsr() # data.X = np.array(data.X.todense()).astype(float) if "counts" in data.layers and scipy.sparse.issparse(data.layers["counts"]): if not isinstance(data.layers["counts"], scipy.sparse.csr_matrix): data.layers["counts"] = data.layers["counts"].tocsr() # data.layers["counts"] = np.array(data.layers["counts"].todense()).astype(float) def _maybe_preprocess(self, raw_data): if self.preprocess is None: return raw_data mod1, mod2, meta1, meta2, test_sol = raw_data train_size = meta1.shape[0] # aux -> cell cycle analysis if self.preprocess == "aux": os.makedirs(self.pretrained_folder, exist_ok=True) if osp.exists(osp.join(self.pretrained_folder, f"preprocessed_data_{self.subtask}.pkl")): with open(osp.join(self.pretrained_folder, f"preprocessed_data_{self.subtask}.pkl"), "rb") as f: preprocessed_data = pickle.load(f) y_train = preprocessed_data["y_train"] mod1.obsm["X_pca"] = preprocessed_data["X_pca_0"] mod2.obsm["X_pca"] = preprocessed_data["X_pca_1"] mod1.obsm["cell_type"] = y_train[0] mod1.obsm["batch_label"] = np.concatenate( [y_train[1], np.zeros(y_train[0].shape[0] - train_size)], 0) mod1.obsm["phase_labels"] = np.concatenate( [y_train[2], np.zeros(y_train[0].shape[0] - train_size)], 0) mod1.obsm["S_scores"] = np.concatenate([y_train[3], np.zeros(y_train[0].shape[0] - train_size)], 0) mod1.obsm["G2M_scores"] = np.concatenate( [y_train[4], np.zeros(y_train[0].shape[0] - train_size)], 0) with open(osp.join(self.pretrained_folder, f"{self.subtask}_config.pk"), "rb") as f: # cell types, batch labels, cell cycle self.nb_cell_types, self.nb_batches, self.nb_phases = pickle.load(f) logger.info("Preprocessing done.") return mod1, mod2, meta1, meta2, test_sol # PCA mod1_name = mod1.var["feature_types"][0] mod2_name = mod2.var["feature_types"][0] if mod2_name == "ADT": if osp.exists(osp.join(self.pretrained_folder, f"lsi_cite_{mod1_name}.pkl")): with open(osp.join(self.pretrained_folder, f"lsi_cite_{mod1_name}.pkl"), "rb") as f: lsi_transformer_gex = pickle.load(f) else: lsi_transformer_gex = lsiTransformer(n_components=256, drop_first=True) lsi_transformer_gex.fit(mod1) with open(osp.join(self.pretrained_folder, f"lsi_cite_{mod1_name}.pkl"), "wb") as f: pickle.dump(lsi_transformer_gex, f) if mod2_name == "ATAC": if osp.exists(osp.join(self.pretrained_folder, f"lsi_multiome_{mod1_name}.pkl")): with open(osp.join(self.pretrained_folder, f"lsi_multiome_{mod1_name}.pkl"), "rb") as f: lsi_transformer_gex = pickle.load(f) else: lsi_transformer_gex = lsiTransformer(n_components=64, drop_first=True) lsi_transformer_gex.fit(mod1) with open(osp.join(self.pretrained_folder, f"lsi_multiome_{mod1_name}.pkl"), "wb") as f: pickle.dump(lsi_transformer_gex, f) if osp.exists(osp.join(self.pretrained_folder, f"lsi_multiome_{mod2_name}.pkl")): with open(osp.join(self.pretrained_folder, f"lsi_multiome_{mod2_name}.pkl"), "rb") as f: lsi_transformer_atac = pickle.load(f) else: # lsi_transformer_atac = TruncatedSVD(n_components=100, random_state=random_seed) lsi_transformer_atac = lsiTransformer(n_components=512, drop_first=True) lsi_transformer_atac.fit(mod2) with open(osp.join(self.pretrained_folder, f"lsi_multiome_{mod2_name}.pkl"), "wb") as f: pickle.dump(lsi_transformer_atac, f) # Data preprocessing # Only exploration dataset provides cell type information. # The exploration dataset is a subset of the full dataset. ad_mod1 = meta1 mod1_obs = ad_mod1.obs # Make sure exploration data match the full data assert ((mod1.obs["batch"].index[:mod1_obs.shape[0]] == mod1_obs["batch"].index).mean() == 1) if mod2_name == "ADT": mod1_pca = lsi_transformer_gex.transform(mod1).values mod2_pca = mod2.X.toarray() elif mod2_name == "ATAC": mod1_pca = lsi_transformer_gex.transform(mod1).values mod2_pca = lsi_transformer_atac.transform(mod2).values else: raise ValueError(f"Unknown modality 2: {mod2_name}") cell_cycle_genes = [ "MCM5", "PCNA", "TYMS", "FEN1", "MCM2", "MCM4", "RRM1", "UNG", "GINS2", "MCM6", "CDCA7", "DTL", "PRIM1", "UHRF1", "MLF1IP", "HELLS", "RFC2", "RPA2", "NASP", "RAD51AP1", "GMNN", "WDR76", "SLBP", "CCNE2", "UBR7", "POLD3", "MSH2", "ATAD2", "RAD51", "RRM2", "CDC45", "CDC6", "EXO1", "TIPIN", "DSCC1", "BLM", "CASP8AP2", "USP1", "CLSPN", "POLA1", "CHAF1B", "BRIP1", "E2F8", "HMGB2", "CDK1", "NUSAP1", "UBE2C", "BIRC5", "TPX2", "TOP2A", "NDC80", "CKS2", "NUF2", "CKS1B", "MKI67", "TMPO", "CENPF", "TACC3", "FAM64A", "SMC4", "CCNB2", "CKAP2L", "CKAP2", "AURKB", "BUB1", "KIF11", "ANP32E", "TUBB4B", "GTSE1", "KIF20B", "HJURP", "CDCA3", "HN1", "CDC20", "TTK", "CDC25C", "KIF2C", "RANGAP1", "NCAPD2", "DLGAP5", "CDCA2", "CDCA8", "ECT2", "KIF23", "HMMR", "AURKA", "PSRC1", "ANLN", "LBR", "CKAP5", "CENPE", "CTCF", "NEK2", "G2E3", "GAS2L3", "CBX5", "CENPA" ] logger.info(f"Data loading and pca done: {mod1_pca.shape=}, {mod2_pca.shape=}") logger.info("Start to calculate cell_cycle score. It may roughly take an hour.") cell_type_labels = test_sol.obs["cell_type"].to_numpy() batch_ids = mod1_obs["batch"] phase_labels = mod1_obs["phase"] nb_cell_types = len(np.unique(cell_type_labels)) nb_batches = len(np.unique(batch_ids)) nb_phases = len(np.unique(phase_labels)) - 1 # 2 cell_type_labels_unique = list(np.unique(cell_type_labels)) batch_ids_unique = list(np.unique(batch_ids)) phase_labels_unique = list(np.unique(phase_labels)) c_labels = np.array([cell_type_labels_unique.index(item) for item in cell_type_labels]) b_labels = np.array([batch_ids_unique.index(item) for item in batch_ids]) p_labels = np.array([phase_labels_unique.index(item) for item in phase_labels]) # 0:G1, 1:G2M, 2: S, only consider the last two s_genes = cell_cycle_genes[:43] g2m_genes = cell_cycle_genes[43:] sc.pp.log1p(ad_mod1) sc.pp.scale(ad_mod1) sc.tl.score_genes_cell_cycle(ad_mod1, s_genes=s_genes, g2m_genes=g2m_genes) S_scores = ad_mod1.obs["S_score"].values G2M_scores = ad_mod1.obs["G2M_score"].values # phase_scores = np.stack([S_scores, G2M_scores]).T # (nb_cells, 2) y_train = [c_labels, b_labels, p_labels, S_scores, G2M_scores] mod1.obsm["X_pca"] = mod1_pca mod2.obsm["X_pca"] = mod2_pca train_size = mod1_obs.shape[0] mod1.obsm["cell_type"] = c_labels mod1.obsm["batch_label"] = np.concatenate([y_train[1], np.zeros(mod1.shape[0] - train_size)], 0) mod1.obsm["phase_labels"] = np.concatenate([y_train[2], np.zeros(mod1.shape[0] - train_size)], 0) mod1.obsm["S_scores"] = np.concatenate([y_train[3], np.zeros(mod1.shape[0] - train_size)], 0) mod1.obsm["G2M_scores"] = np.concatenate([y_train[4], np.zeros(mod1.shape[0] - train_size)], 0) preprocessed_data = {"X_pca_0": mod1.obsm["X_pca"], "X_pca_1": mod2.obsm["X_pca"], "y_train": y_train} with open(osp.join(self.pretrained_folder, f"preprocessed_data_{self.subtask}.pkl"), "wb") as f: pickle.dump(preprocessed_data, f) with open(osp.join(self.pretrained_folder, f"{self.subtask}_config.pk"), "wb") as f: pickle.dump([nb_cell_types, nb_batches, nb_phases], f) self.nb_cell_types, self.nb_batches, self.nb_phases = nb_cell_types, nb_batches, nb_phases elif self.preprocess == "pca": sc.pp.filter_genes(mod1, min_counts=3) sc.pp.filter_genes(mod2, min_counts=3) meta1 = meta1[:, mod1.var.index] meta2 = meta2[:, mod2.var.index] test_sol = test_sol[:, mod1.var.index] lsi_transformer_gex = lsiTransformer(n_components=64, drop_first=True) mod1.obsm['X_pca'] = lsi_transformer_gex.fit_transform(mod1).values mod2.obsm['X_pca'] = lsi_transformer_gex.fit_transform(mod2).values elif self.preprocess == "feature_selection": sc.pp.filter_genes(mod1, min_counts=3) sc.pp.filter_genes(mod2, min_counts=3) meta1 = meta1[:, mod1.var.index] meta2 = meta2[:, mod2.var.index] test_sol = test_sol[:, mod1.var.index] if mod1.shape[1] > self.selection_threshold: sc.pp.highly_variable_genes(mod1, layer="counts", flavor="seurat_v3", n_top_genes=self.selection_threshold, span=self.span) mod1 = mod1[:, mod1.var["highly_variable"]] # Equivalent to subset=True and _inplace_subset_var if mod2.shape[1] > self.selection_threshold: sc.pp.highly_variable_genes(mod2, layer="counts", flavor="seurat_v3", n_top_genes=self.selection_threshold, span=self.span) mod2 = mod2[:, mod2.var["highly_variable"]] sc.pp.filter_cells(mod1, min_genes=1, inplace=True) sc.pp.filter_cells(mod2, min_genes=1, inplace=True) common_cells = list(set(mod1.obs.index) & set(mod2.obs.index)) mod1 = mod1[common_cells, :] mod2 = mod2[common_cells, :] test_sol = test_sol[common_cells, :] sc.pp.filter_cells(meta1, min_genes=1, inplace=True) sc.pp.filter_cells(meta2, min_genes=1, inplace=True) meta_common_cells = list(set(meta1.obs.index) & set(meta2.obs.index)) meta1 = meta1[meta_common_cells, :] meta2 = meta2[meta_common_cells, :] else: logger.info(f"Preprocessing method {self.preprocess!r} not supported.") # Normalization if self.normalize: sc.pp.scale(mod1) sc.pp.scale(mod2) logger.info("Preprocessing done.") return mod1, mod2, meta1, meta2, test_sol