import itertools
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial
from operator import is_not
from pprint import pformat
import anndata
import mudata
import numpy as np
import omegaconf
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import torch
from deprecated import deprecated
from dance import logger
from dance.typing import Any, Dict, FeatType, Iterator, List, ListConfig, Literal, Optional, Sequence, Tuple, Union
def _ensure_iter(val: Optional[Union[List[str], str]]) -> Iterator[Optional[str]]:
if val is None:
val = itertools.repeat(None)
elif isinstance(val, str):
val = [val]
elif not isinstance(val, list):
raise TypeError(f"Input to _ensure_iter must be list, str, or None. Got {type(val)}.")
return val
def _check_types_and_sizes(types, sizes):
if len(types) == 0:
return
elif len(types) > 1:
raise TypeError(f"Found mixed types: {types}. Input configs must be either all str or all lists.")
elif ((type_ := types.pop()) == list) and (len(sizes) > 1):
raise ValueError(f"Found mixed sizes lists: {sizes}. Input configs must be of same length.")
elif type_ not in (list, str, ListConfig):
raise TypeError(f"Unknownn type {type_} found in config.")
class BaseData(ABC):
    """Base data object.

    The ``dance`` data object is a wrapper of the :class:`~anndata.AnnData` object, with several utility methods to
    help retrieving data in specific splits in specific format (see :meth:`~BaseData.get_split_idx` and
    :meth:`~BaseData.get_feature`). The :class:`~anndata.AnnData` object is saved in the attribute ``data`` and can be
    accessed directly.

    Warning
    -------
    Since the underlying data object is a reference to the input :class:`~anndata.AnnData` object, please be extra
    cautious **not** to initialize two different dance ``data`` objects using the same :class:`~anndata.AnnData`
    object! If you are unsure, we recommend always initializing the dance ``data`` object using a ``copy`` of the
    input :class:`~anndata.AnnData` object, e.g.,

    >>> adata = anndata.AnnData(...)
    >>> ddata = dance.data.Data(adata.copy())

    Note
    ----
    You can directly access some main properties of :class:`~anndata.AnnData` (or :class:`~mudata.MuData` depending on
    which type of data you passed in), such as ``X``, ``obs``, ``var``, and etc. These attributes are re-bound to the
    new underlying object whenever a method of this class replaces it (:meth:`append`, :meth:`pop`,
    :meth:`filter_by_mask`, :meth:`filter_cells`).

    Parameters
    ----------
    data
        Cell data.
    train_size
        Number of cells to be used for training. If not specified, no splits will be generated. If set to ``"all"``
        (case-insensitive), all cells are assigned to the training split.
    val_size
        Number of cells to be used for validation. If set to -1, use what's left from training and testing.
    test_size
        Number of cells to be used for testing. If set to -1, use what's left from training and validation.
    split_index_range_dict
        Mapping from split name to a ``(start, end)`` tuple of ints; the split receives indices ``range(start, end)``.
        When given, ``train_size``, ``val_size`` and ``test_size`` are ignored. Cannot be combined with
        ``full_split_name``.
    full_split_name
        If given, all cells are assigned to a single split with this name, and ``train_size``, ``val_size`` and
        ``test_size`` are ignored. Cannot be combined with ``split_index_range_dict``.

    """

    _FEATURE_CONFIGS: List[str] = ["feature_mod", "feature_channel", "feature_channel_type"]
    _LABEL_CONFIGS: List[str] = ["label_mod", "label_channel", "label_channel_type"]
    _DATA_CHANNELS: List[str] = ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns"]

    def __init__(self, data: Union[anndata.AnnData, mudata.MuData], train_size: Optional[Union[int, str]] = None,
                 val_size: int = 0, test_size: int = -1,
                 split_index_range_dict: Optional[Dict[str, Tuple[int, int]]] = None,
                 full_split_name: Optional[str] = None):
        super().__init__()

        # Check data type; MuData additionally exposes its per-modality mapping ``mod``.
        if isinstance(data, anndata.AnnData):
            additional_channels = ["X"]
        elif isinstance(data, mudata.MuData):
            additional_channels = ["X", "mod"]
        else:
            raise TypeError(f"Unknown data type {type(data)}, must be either AnnData or MuData.")

        # Store data and pass through some main properties over
        self._data = data
        for prop in self._DATA_CHANNELS + additional_channels:
            # Internal invariant: a pass-through attribute must not shadow anything already defined on the class.
            assert not hasattr(self, prop)
            setattr(self, prop, getattr(data, prop))

        # TODO: move _split_idx_dict into data.uns
        self._split_idx_dict: Dict[str, Sequence[int]] = {}
        self._setup_splits(train_size, val_size, test_size, split_index_range_dict, full_split_name)

        # The config dict is stored in the wrapped object's ``.uns`` (see ``config``); create it on first use.
        if "dance_config" not in self._data.uns:
            self._data.uns["dance_config"] = dict()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} object that wraps (.data):\n{self.data}"

    def _refresh_data_channels(self):
        """Re-bind the pass-through attributes (``X``, ``obs``, ...) to the current ``self._data`` object.

        Must be called after any operation that replaces ``self._data``; otherwise e.g. ``self.obs`` would keep
        pointing at the old, discarded object.

        """
        channels = self._DATA_CHANNELS + ["X"]
        if isinstance(self._data, mudata.MuData):
            channels.append("mod")
        for prop in channels:
            if hasattr(self._data, prop):
                setattr(self, prop, getattr(self._data, prop))

    # WARNING: need to be careful about subsampling cells as the index are not automatically updated!!
    def _setup_splits(
        self,
        train_size: Optional[Union[int, str]],
        val_size: int,
        test_size: int,
        split_index_range_dict: Optional[Dict[str, Tuple[int, int]]],
        full_split_name: Optional[str],
    ):
        """Dispatch to the split setup strategy selected by the constructor arguments."""
        if (split_index_range_dict is not None) and (full_split_name is not None):
            raise ValueError("Only one of split_index_range_dict, full_split_name can be specified, but not both")
        elif split_index_range_dict is not None:
            self._setup_splits_range(split_index_range_dict)
        elif full_split_name is not None:
            self._setup_splits_full(full_split_name)
        else:
            self._setup_splits_default(train_size, val_size, test_size)

    def _setup_splits_default(self, train_size: Optional[Union[int, str]], val_size: int, test_size: int):
        """Create contiguous train/val/test splits, in that order, from the requested sizes.

        At most one size may be -1, meaning "all remaining cells". Empty splits are not created.

        """
        if train_size is None:
            return
        elif isinstance(train_size, str) and train_size.lower() == "all":
            train_size = -1
            val_size = test_size = 0
        elif any(not isinstance(i, (int, np.integer)) for i in (train_size, val_size, test_size)):
            raise TypeError("Split sizes must be of type int")

        split_names = ["train", "val", "test"]
        split_sizes = np.array((train_size, val_size, test_size))

        # Only one -1 (complementary size) is allowed
        if (split_sizes == -1).sum() > 1:
            raise ValueError("Only one split can be specified as -1")

        # Each size must be bounded between -1 and the data size
        data_size = self.num_cells
        for name, size in zip(split_names, split_sizes):
            if size < -1:
                raise ValueError(f"{name} must be integer no less than -1, got {size!r}")
            elif size > data_size:
                raise ValueError(f"{name}={size:,} exceeds total number of samples {data_size:,}")

        # Sum of sizes must be bounded by the data size
        if (tot_size := split_sizes.clip(0).sum()) > data_size:
            raise ValueError(f"Total size {tot_size:,} exceeds total number of samples {data_size:,}")

        logger.debug(f"Split sizes before conversion: {split_sizes.tolist()}")
        split_sizes[split_sizes == -1] = data_size - split_sizes.clip(0).sum()
        logger.debug(f"Split sizes after conversion: {split_sizes.tolist()}")

        split_thresholds = split_sizes.cumsum()
        for i, split_name in enumerate(split_names):
            start = split_thresholds[i - 1] if i > 0 else 0
            end = split_thresholds[i]
            if end - start > 0:  # skip empty split
                self._split_idx_dict[split_name] = list(range(start, end))

    def _setup_splits_range(self, split_index_range_dict: Dict[str, Tuple[int, int]]):
        """Create one split per ``(start, end)`` range; empty ranges are skipped."""
        for split_name, index_range in split_index_range_dict.items():
            if (not isinstance(index_range, tuple)) or (len(index_range) != 2):
                raise TypeError("The split index range must be a two-tuple containing the start and end index. "
                                f"Got {index_range!r} for key {split_name!r}")
            elif any(not isinstance(i, int) for i in index_range):
                raise TypeError("The split index range must be a two-tuple of int type. "
                                f"Got {index_range!r} for key {split_name!r}")

            start, end = index_range
            if end - start > 0:  # skip empty split
                self._split_idx_dict[split_name] = list(range(start, end))

    def _setup_splits_full(self, full_split_name: str):
        """Assign every cell to the single split ``full_split_name``."""
        self._split_idx_dict[full_split_name] = list(range(self.shape[0]))

    def __getitem__(self, idx: Sequence[int]) -> Any:
        return self.data[idx]

    @property
    def data(self):
        """The wrapped :class:`~anndata.AnnData` / :class:`~mudata.MuData` object."""
        return self._data

    @property
    @abstractmethod
    def x(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def y(self):
        raise NotImplementedError

    @property
    def config(self) -> Dict[str, Any]:
        """Return the dance data object configuration dict.

        Notes
        -----
        The configuration dictionary is saved in the ``data`` attribute, which is an :class:`~anndata.AnnData`
        object. In particular, the config will be saved in the ``.uns`` attribute with the key ``"dance_config"``.

        """
        return self._data.uns["dance_config"]

    def set_config(self, *, overwrite: bool = False, **kwargs):
        """Set dance data object configuration.

        See :meth:`~BaseData.set_config_from_dict`.

        """
        self.set_config_from_dict(kwargs, overwrite=overwrite)

    def set_config_from_dict(self, config_dict: Dict[str, Any], *, overwrite: bool = False):
        """Set dance data object configuration from a config dict.

        Parameters
        ----------
        config_dict
            Configuration dictionary. Keys must be among the feature configs (``feature_mod``, ``feature_channel``,
            ``feature_channel_type``) and label configs (``label_mod``, ``label_channel``, ``label_channel_type``).
            OmegaConf ``ListConfig`` values are converted to plain lists before being validated and stored.
        overwrite
            Used to determine the behaviour of resolving config conflicts. In the case of a conflict, where the config
            dict passed contains a key with value that differs from an existing setting, if ``overwrite`` is set to
            ``False``, then raise a ``KeyError``. Otherwise, overwrite the configuration with the new values.

        Raises
        ------
        KeyError
            If an unknown config key is passed, or a conflicting value is set while ``overwrite`` is ``False``.
        TypeError, ValueError
            If the non-``None`` feature (or label) configs passed in this call mix str and lists, or are lists of
            different lengths.

        """
        # Check config key validity
        all_configs = set(self._FEATURE_CONFIGS + self._LABEL_CONFIGS)
        if (unknown_options := set(config_dict).difference(all_configs)):
            raise KeyError(f"Unknown config option(s): {unknown_options}, available options are: {all_configs}")

        # Normalize OmegaConf lists up front so that (i) the type/size validation below compares like with like, and
        # (ii) both the "new" and the "overwrite" paths store plain lists, which ``_ensure_iter`` requires.
        normalized: Dict[str, Any] = {}
        for config_key, config_val in config_dict.items():
            if isinstance(config_val, ListConfig):
                config_val = omegaconf.OmegaConf.to_object(config_val)
                logger.warning(f"transform ListConfig {config_val} to List")
            normalized[config_key] = config_val

        feature_configs = [j for i, j in normalized.items() if i in self._FEATURE_CONFIGS and j is not None]
        label_configs = [j for i, j in normalized.items() if i in self._LABEL_CONFIGS and j is not None]

        # Check type and length consistencies for feature and label configs
        for configs in (feature_configs, label_configs):
            _check_types_and_sizes(set(map(type, configs)), set(map(len, configs)))

        # Finally, update the configs
        for config_key, config_val in normalized.items():
            # New config
            if config_key not in self.config:
                self.config[config_key] = config_val
                logger.info(f"Setting config {config_key!r} to {config_val!r}")
                continue

            # Existing config
            if (old_config_val := self.config[config_key]) == config_val:  # new value is the same as before
                continue
            elif overwrite:  # new value differs from before and overwrite setting is turned on
                self.config[config_key] = config_val
                logger.warning(f"Overwriting config {config_key!r} to {config_val!r} (previously {old_config_val!r})")
            else:  # new value differs from before but overwrite setting is not on
                raise KeyError(f"Config {config_key!r} exists with value {old_config_val!r} but trying to set to a "
                               f"different value {config_val!r}. If you want to overwrite the config, please specify "
                               "`overwrite=True` when calling the set config function.")

    @property
    def num_cells(self) -> int:
        return self.data.shape[0]

    @property
    def num_features(self) -> int:
        return self.data.shape[1]

    @property
    def cells(self) -> List[str]:
        return self.data.obs.index.tolist()

    @property
    def train_idx(self) -> Sequence[int]:
        return self.get_split_idx("train", error_on_miss=False)

    @property
    def val_idx(self) -> Sequence[int]:
        return self.get_split_idx("val", error_on_miss=False)

    @property
    def test_idx(self) -> Sequence[int]:
        return self.get_split_idx("test", error_on_miss=False)

    @property
    def shape(self) -> Tuple[int, int]:
        return self.data.shape

    def copy(self):
        """Return a deep copy of this dance data object (including the wrapped data and the split indices)."""
        return deepcopy(self)

    def set_split_idx(self, split_name: str, split_idx: Sequence[int]):
        """Set cell indices for a particular split.

        Parameters
        ----------
        split_name
            Name of the split to set.
        split_idx
            Indices to be used in this split.

        """
        self._split_idx_dict[split_name] = split_idx

    def get_split_idx(self, split_name: str, error_on_miss: bool = False):
        """Obtain cell indices for a particular split.

        Parameters
        ----------
        split_name
            Name of the split to retrieve. If ``None``, return the indices of all cells.
        error_on_miss
            If set to True, raise KeyError if the queried split does not exist, otherwise return None.

        See Also
        --------
        :meth:`~get_split_mask`

        """
        if split_name is None:
            return list(range(self.shape[0]))
        elif split_name in self._split_idx_dict:
            return self._split_idx_dict[split_name]
        elif error_on_miss:
            raise KeyError(f"Unknown split {split_name!r}. Please set the split indices via set_split_idx first.")
        else:
            return None

    def get_split_mask(self, split_name: str, return_type: FeatType = "numpy") -> Union[np.ndarray, torch.Tensor]:
        """Obtain mask representation of a particular split.

        Parameters
        ----------
        split_name
            Name of the split to retrieve.
        return_type
            Return numpy array if set to 'numpy', or torch Tensor if set to 'torch'.

        """
        split_idx = self.get_split_idx(split_name, error_on_miss=True)
        if return_type == "numpy":
            mask = np.zeros(self.shape[0], dtype=bool)
        elif return_type == "torch":
            mask = torch.zeros(self.shape[0], dtype=torch.bool)
        else:
            raise ValueError(f"Unsupported return_type {return_type!r}. Available options are 'numpy' and 'torch'.")
        mask[split_idx] = True
        return mask

    def get_split_data(self, split_name: str) -> Union[anndata.AnnData, mudata.MuData]:
        """Obtain the underlying data of a particular split.

        Parameters
        ----------
        split_name
            Name of the split to retrieve.

        """
        split_idx = self.get_split_idx(split_name, error_on_miss=True)
        return self.data[split_idx]

    @staticmethod
    def _get_feature(
        in_data: Union[anndata.AnnData, mudata.MuData],
        channel: Optional[str],
        channel_type: Optional[str],
        mod: Optional[str],
    ) -> Union[np.ndarray, sp.spmatrix, pd.DataFrame]:
        """Look up a raw feature object (no format conversion, no split selection).

        ``channel_type`` ``"X"`` / ``"raw_X"`` return ``.X`` / ``.raw.X``; otherwise ``channel_type`` (defaulting to
        ``"obsm"``) names an attribute in ``_DATA_CHANNELS`` and ``channel`` is the key looked up in it.

        """
        # Pick modality
        if mod is None:
            data = in_data
        elif not isinstance(in_data, mudata.MuData):
            raise AttributeError("`mod` option is only available when using multimodality data.")
        elif mod not in in_data.mod:
            # NOTE: must reference ``in_data`` here; ``data`` is not bound yet on this path.
            raise KeyError(f"Unknown modality {mod!r}, available options are {sorted(in_data.mod)}")
        else:
            data = in_data.mod[mod]

        if channel_type == "X":
            return data.X
        if channel_type == "raw_X":
            return data.raw.X

        # Pick channels (obsm, varm, ...)
        channel_type = channel_type or "obsm"  # default to obsm
        if channel_type not in (options := BaseData._DATA_CHANNELS):
            raise ValueError(f"Unknown channel type {channel_type!r}. Available options are {options}")

        # Pick feature from a specific channel
        if channel is None:
            # FIX: channel default change to "X".
            warnings.warn(
                "The `None` option for channel when channel_type is no longer supported "
                "and will raise an exception in the near future version. Please change "
                "channel_type to 'X' to preserve the current behavior", DeprecationWarning, stacklevel=2)
            return data.X
        return getattr(data, channel_type)[channel]

    def get_feature(self, *, split_name: Optional[str] = None, return_type: FeatType = "numpy",
                    channel: Optional[str] = None, channel_type: Optional[str] = "obsm",
                    mod: Optional[str] = None):  # yapf: disable
        """Retrieve features from data.

        Parameters
        ----------
        split_name
            Name of the split to retrieve. If not set, return all. Row selection is applied for channel types ``X``,
            ``raw_X``, ``obs``, ``obsm``, ``layers`` (and both axes for ``obsp``); split indices beyond the feature's
            first dimension are dropped. For other channel types a warning is logged and all entries are returned.
        return_type
            How should the features be returned. **default**: return the stored object as-is (``split_name`` must not
            be set); **sparse**: return as a scipy sparse matrix (dense numpy arrays are converted to CSR);
            **numpy**: return as a numpy array; **torch**: return as a torch tensor.
        channel
            Return a particular channel as features. If ``channel_type`` is ``X`` or ``raw_X``, then return ``.X`` or
            the ``.raw.X`` attribute from the :class:`~anndata.AnnData` directly. If ``channel_type`` is ``obs``, return
            the column named by ``channel``, similarly for ``var``. Finally, if ``channel_type`` is ``obsm``, ``obsp``,
            ``varm``, ``varp``, ``layers``, or ``uns``, then return the value correspond to the ``channel`` in the
            dictionary. If ``channel`` is ``None`` (and ``channel_type`` is not ``X``/``raw_X``), ``.X`` is returned
            with a ``DeprecationWarning``.
        channel_type
            Channel type to use, default to ``obsm`` (will be changed to ``X`` in the near future).
        mod
            Modality to use, default to ``None``. Options other than ``None`` are only available when the underlying
            data object is :class:`~mudata.MuData`.

        Raises
        ------
        ValueError
            For an unknown ``return_type``, ``split_name`` combined with ``return_type="default"``, or
            ``return_type="sparse"`` on a feature that is neither dense numpy nor sparse.

        """
        # Validate cheap arguments before doing any (potentially expensive) lookup or conversion.
        if return_type not in ("default", "sparse", "numpy", "torch"):
            raise ValueError(f"Unknown return_type {return_type!r}")
        if return_type == "default" and split_name is not None:
            raise ValueError(f"split_name is not supported when return_type is 'default', got {split_name=!r}")

        feature = self._get_feature(self.data, channel, channel_type, mod)
        if return_type == "default":
            return feature

        # FIX: no longer allow channel_type=None, use channel_type='X' or 'raw_X' instead
        channel_type = channel_type or "obsm"

        if return_type == "sparse":
            if isinstance(feature, np.ndarray):
                feature = sp.csr_matrix(feature)
            elif not sp.issparse(feature):  # accepts both sparse matrices and sparse arrays
                raise ValueError(f"Feature is not sparse, got {type(feature)}")
        # Transform features to numpy array
        elif hasattr(feature, "toarray"):  # convert sparse array to dense numpy array
            feature = feature.toarray()
        elif hasattr(feature, "to_numpy"):  # convert dataframe to numpy array
            feature = feature.to_numpy()

        # Extract specific split
        if split_name is not None:
            if channel_type in ["X", "raw_X", "obs", "obsm", "obsp", "layers"]:
                idx = self.get_split_idx(split_name, error_on_miss=True)
                idx = [i for i in idx if i < feature.shape[0]]
                feature = feature[idx][:, idx] if channel_type == "obsp" else feature[idx]
            else:
                logger.warning(f"Indexing option for {channel_type!r} not implemented yet.")

        if return_type == "torch":
            feature = torch.from_numpy(feature)
        return feature

    def append(
        self,
        data,
        *,
        mode: Optional[Literal["merge", "rename", "new_split"]] = "merge",
        rename_dict: Optional[Dict[str, str]] = None,
        new_split_name: Optional[str] = None,
        label_batch: bool = False,
        **concat_kwargs,
    ):
        """Append another dance data object to the current data object.

        The cells of ``data`` are concatenated after the current cells (via :func:`anndata.concat`), so their split
        indices are shifted by the current number of cells. ``.uns`` entries of both objects are merged, with the
        current object's entries taking precedence.

        Parameters
        ----------
        data
            New dance data object to be added.
        mode
            How to combine the splits from the new data and the current data. (1) ``"merge"``: merge the splits from
            the data, e.g., the training indexes from both data are used as the training indexes in the new combined
            data. (2) ``"rename"``: rename the splits of the new data and add to the current split index dictionary,
            e.g., renaming 'train' to 'ref'. Requires passing the ``rename_dict``. Raise an error if the newly renamed
            key is already used in the current split index dictionary. (3) ``"new_split"``: assign the whole new data
            to a new split. Requires passing the ``new_split_name`` that is not already used as a split name in the
            current data. (4) ``None``: do not specify split index to the newly added data.
        rename_dict
            Optional argument that is only used when ``mode="rename"``. A dictionary to map the split names in the new
            data to other names.
        new_split_name
            Optional argument that is only used when ``mode="new_split"``. Name of the split to assign to the new data.
        label_batch
            Add "batch" column to ``.obs`` when set to True. Existing cells keep their ``batch`` value (0 if there was
            no such column); the new cells get one more than the current maximum.
        **concat_kwargs
            See :meth:`anndata.concat`.

        Returns
        -------
        self

        """
        offset = self.shape[0]
        # Shift the incoming split indices past the current cells (new cells are concatenated after them).
        new_split_idx_dict = {i: sorted((np.asarray(j) + offset).tolist()) for i, j in data._split_idx_dict.items()}
        if mode == "merge":
            # Splits present in both objects are concatenated; splits present in only one of them are kept.
            for split_name, split_idxs in self._split_idx_dict.items():
                # Convert to list so an ndarray split is concatenated rather than added element-wise.
                merged = list(split_idxs)
                if split_name in new_split_idx_dict:
                    merged += new_split_idx_dict[split_name]
                new_split_idx_dict[split_name] = merged
        elif mode == "rename":
            if rename_dict is None:
                raise ValueError("Mode 'rename' is selected but 'rename_dict' is not specified.")
            elif len(common_keys := set(self._split_idx_dict) & set(rename_dict.values())) > 0:
                raise ValueError(f"'rename_dict' cannot contain split keys present in current data: {common_keys}")
            elif len(missed_keys := [i for i in data._split_idx_dict if i not in rename_dict]) > 0:
                raise KeyError(f"Missing rename mapping for keys: {missed_keys}")
            new_split_idx_dict = {rename_dict[i]: j for i, j in new_split_idx_dict.items()}
            new_split_idx_dict.update(self._split_idx_dict)
        elif mode == "new_split":
            if new_split_name is None:
                raise ValueError("Mode 'new_split' is selected but 'new_split_name' is not specified.")
            elif not isinstance(new_split_name, str):
                raise TypeError(f"'new_split_name' must be a string, got {type(new_split_name)}: {new_split_name}.")
            elif new_split_name in self._split_idx_dict:
                raise ValueError(f"{new_split_name!r} is being used in the current splits. Please pick another name.")
            new_split_idx_dict = {new_split_name: list(range(offset, offset + data.shape[0]))}
            new_split_idx_dict.update(self._split_idx_dict)
        elif mode is None:
            new_split_idx_dict = self._split_idx_dict
        else:
            raise ValueError(f"Unknown mode {mode!r}. Available options are: 'merge', 'rename', 'new_split'")

        # NOTE: Manually merging uns cause AnnData is incapable of doing so, even with uns_merge set
        new_uns = dict(data.data.uns)
        new_uns.update(dict(self.data.uns))

        if label_batch:
            if "batch" in self.data.obs.columns:
                old_batch = self.data.obs["batch"].tolist()
            else:
                old_batch = np.zeros(self.shape[0]).tolist()
            # ``default=-1`` keeps an empty current object from crashing ``max``; new cells then get batch 0.
            next_batch = max(old_batch, default=-1) + 1
            batch = list(map(int, old_batch + [next_batch] * data.shape[0]))

        self._data = anndata.concat((self.data, data.data), **concat_kwargs)
        self._data.uns.update(new_uns)
        self._split_idx_dict = new_split_idx_dict

        if label_batch:
            self._data.obs["batch"] = pd.Series(batch, dtype="category", index=self._data.obs.index)

        # ``self._data`` was replaced; keep ``self.X``, ``self.obs``, ... in sync with it.
        self._refresh_data_channels()
        return self

    def pop(self, *, split_name: str):
        """Remove all cells of a split and re-index the remaining splits.

        Splits that become empty are dropped from the split index dictionary.

        Parameters
        ----------
        split_name
            Name of the split whose cells are removed.

        Raises
        ------
        KeyError
            If ``split_name`` is not a known split.

        """
        # TODO: add more option, e.g., index
        index_to_pop = self.get_split_idx(split_name, error_on_miss=True)
        index_to_preserve = sorted(set(range(self.shape[0])) - set(index_to_pop))
        oldidx_to_newidx = {old: new for new, old in enumerate(index_to_preserve)}

        new_split_idx_dict = {}
        for name, split_idx in self._split_idx_dict.items():
            # Removed cells map to None and are filtered out.
            new_split_idx = sorted(filter(partial(is_not, None), map(oldidx_to_newidx.get, split_idx)))
            if len(new_split_idx) > 0:
                new_split_idx_dict[name] = new_split_idx
                logger.info(f"Updating split index for {name!r}. {len(split_idx):,} -> {len(new_split_idx):,}")

        self._data = self._data[index_to_preserve]
        self._split_idx_dict = new_split_idx_dict
        # ``self._data`` was replaced; keep ``self.X``, ``self.obs``, ... in sync with it.
        self._refresh_data_channels()

    @deprecated("out of date")
    def filter_cells(self, **kwargs):
        """Apply cell filtering using scanpy.pp.filter_cells and update splits.

        Filters the cells in `self.data` based on the provided criteria, similar to `scanpy.pp.filter_cells`. The
        resulting boolean mask is applied through :meth:`filter_by_mask`, which also updates the internal split
        indices (`train_idx`, `val_idx`, etc.) to reflect the cells remaining after filtering.

        Parameters
        ----------
        **kwargs
            Arguments passed directly to `scanpy.pp.filter_cells`. Common arguments include `min_counts`,
            `max_counts`, `min_genes`, `max_genes`. Note: `inplace` is forced to `False` internally to get the filter
            mask, which is then applied effectively inplace.

        Returns
        -------
        self
            Returns the instance to allow method chaining.

        Raises
        ------
        NotImplementedError
            If the underlying `self.data` is not an `anndata.AnnData` object. Filtering `MuData` requires more careful
            consideration of modalities.

        """
        if not isinstance(self.data, anndata.AnnData):
            # Filtering MuData needs careful handling: filter which modality?
            # How to sync obs across modalities after filtering one?
            raise NotImplementedError("filter_cells is currently only implemented for AnnData objects. "
                                      "Filtering MuData requires specific modality handling.")

        logger.info(f"Applying filter_cells with parameters: {kwargs}")
        try:
            # With ``inplace=False`` scanpy returns ``(cells_subset, number_per_cell)``; only the mask is needed.
            cells_mask, _ = sc.pp.filter_cells(self.data, **{**kwargs, "inplace": False})
        except Exception as e:
            logger.error(f"Error during sc.pp.filter_cells execution: {e}")
            raise

        return self.filter_by_mask(cells_mask, update_splits=True)

    def _remap_split_idx(self, keep_mask: np.ndarray) -> Dict[str, List[int]]:
        """Translate every split's indices from pre-filter to post-filter positions, dropping removed cells.

        Works on positions rather than ``obs_names`` so it stays correct when cell names are not unique. Splits that
        end up empty are kept as empty lists.

        Raises
        ------
        IndexError
            If any split contains an index that is out of range for ``keep_mask``.

        """
        num_cells = keep_mask.shape[0]

        # Validate everything first so that nothing is logged for a filtering that will not happen.
        old_split_idx: Dict[str, np.ndarray] = {}
        for split_name, split_idx in self._split_idx_dict.items():
            if split_idx is None or len(split_idx) == 0:
                old_split_idx[split_name] = np.array([], dtype=np.intp)
                continue
            old_idx = np.asarray(split_idx, dtype=np.intp)
            if old_idx.max() >= num_cells:
                raise IndexError(f"Invalid index found in split '{split_name}' before filtering.")
            old_split_idx[split_name] = old_idx

        # Position each kept cell will occupy after filtering (only meaningful where ``keep_mask`` is True).
        new_positions = np.cumsum(keep_mask) - 1
        new_split_idx_dict: Dict[str, List[int]] = {}
        for split_name, old_idx in old_split_idx.items():
            kept_idx = old_idx[keep_mask[old_idx]]
            new_idx = sorted(new_positions[kept_idx].tolist())
            new_split_idx_dict[split_name] = new_idx
            if new_idx:
                logger.debug(f"Split '{split_name}': {len(old_idx)} -> {len(new_idx)} cells.")
            else:
                logger.warning(f"Split '{split_name}' is now empty after filtering.")
        return new_split_idx_dict

    def filter_by_mask(self, mask: Union[Sequence[bool], pd.Series, np.ndarray], update_splits: bool = True):
        """Filter cells based on a boolean mask and optionally update splits.

        Filters the cells in `self.data` using a provided boolean mask. If `update_splits` is True, this method also
        updates the internal split indices (`train_idx`, `val_idx`, etc.) to reflect the cells remaining after
        filtering. The filtered data is stored as an explicit copy, not a view.

        Parameters
        ----------
        mask
            A boolean mask (list, tuple, Series, or array) with the same length as the current number of cells
            (`self.data.shape[0]`). Cells where the mask is True will be kept. Non-boolean values are converted with
            ``np.asarray(mask, dtype=bool)``.
        update_splits
            Whether to update the internal split indices to align with the filtered data. Defaults to True. If set to
            False, the split indices will become invalid if any cells are removed.

        Returns
        -------
        self
            Returns the instance to allow method chaining.

        Raises
        ------
        TypeError
            If the mask is not a list, tuple, Series, or ndarray.
        ValueError
            If the mask cannot be converted to boolean or has an incorrect length.
        IndexError
            If ``update_splits`` is True and a split contains an out-of-range index (data is left unchanged).
        NotImplementedError
            If the underlying `self.data` is not an `anndata.AnnData` object (as filtering MuData requires more careful
            handling).

        """
        if not isinstance(self.data, anndata.AnnData):
            raise NotImplementedError("filter_by_mask is currently only implemented for AnnData objects.")

        # --- Input Validation ---
        if not isinstance(mask, (list, tuple, pd.Series, np.ndarray)):
            raise TypeError(f"Mask must be a sequence, Series, or ndarray, got {type(mask)}")
        if len(mask) != self.data.shape[0]:
            raise ValueError(f"Mask length ({len(mask)}) must match number of cells ({self.data.shape[0]})")
        try:
            # Ensure boolean type (handles potential integer 0/1 masks)
            mask = np.asarray(mask, dtype=bool)
        except (TypeError, ValueError) as e:
            raise ValueError("Mask could not be converted to boolean.") from e

        num_to_keep = int(mask.sum())
        num_to_filter = len(mask) - num_to_keep
        if num_to_filter == 0:
            logger.info("Provided mask keeps all cells. No filtering applied.")
            return self
        logger.info(f"Filtering cells based on mask: {self.data.shape[0]} -> {num_to_keep} ({num_to_filter} removed).")

        # Compute new split indices *before* touching the data so an invalid split leaves the object unchanged.
        new_split_idx_dict = self._remap_split_idx(mask) if update_splits else None

        # --- Apply Filtering ---
        # Slicing creates a view; make it an explicit copy.
        self._data = self.data[mask, :].copy()
        logger.debug(f"Data shape after filtering: {self.data.shape}")

        # --- Update Split Indices (if requested) ---
        if update_splits:
            total_kept_in_splits = sum(len(idx) for idx in new_split_idx_dict.values())
            if total_kept_in_splits != self.data.shape[0]:
                logger.warning(f"Total cells in updated splits ({total_kept_in_splits}) "
                               f"does not match total cells after filtering ({self.data.shape[0]}). "
                               "This may be expected if not all original cells were in a split.")
            self._split_idx_dict = new_split_idx_dict
            logger.info("Split indices updated.")
        else:
            logger.warning("Filtering applied, but split indices were *not* updated as requested. "
                           "Existing split indices are now likely invalid.")

        # ``self._data`` was replaced; keep ``self.X``, ``self.obs``, ... in sync with it.
        self._refresh_data_channels()
        logger.info("Filtering by mask complete.")
        return self
class Data(BaseData):
    """Dance data object whose features and labels are resolved through the ``dance_config``.

    The ``feature_mod`` / ``feature_channel`` / ``feature_channel_type`` config entries (and their ``label_*``
    counterparts) select what :meth:`get_x` (and :meth:`get_y`) fetch via :meth:`~BaseData.get_feature`. See
    :meth:`~BaseData.set_config`.

    """

    @property
    def x(self):
        """All cell features, returned as stored (``return_type="default"``)."""
        return self.get_x(return_type="default")

    @property
    def y(self):
        """All cell labels, returned as stored (``return_type="default"``)."""
        return self.get_y(return_type="default")

    def _get(self, config_keys: List[str], *, split_name: Optional[str] = None, return_type: FeatType = "numpy",
             **kwargs) -> Any:
        """Fetch the feature(s) described by the ``(mod, channel, channel_type)`` config triplet in ``config_keys``.

        Returns a single object when the config describes one feature, otherwise a list in config order.

        Raises
        ------
        RuntimeError
            Wrapping any error raised by :meth:`~BaseData.get_feature`, with the failing settings in the message.

        """
        info = list(map(self.config.get, config_keys))
        if all(i is None for i in info):
            # Nothing configured: one lookup with all-None settings (``zip`` over three endless repeats would never
            # terminate).
            mods = channels = channel_types = [None]
        else:
            mods, channels, channel_types = map(_ensure_iter, info)

        out = []
        # NOTE(review): ``zip`` truncates to the shortest list; lengths set across separate set_config calls are
        # not cross-validated.
        for mod, channel, channel_type in zip(mods, channels, channel_types):
            try:
                x = self.get_feature(split_name=split_name, return_type=return_type, mod=mod, channel=channel,
                                     channel_type=channel_type, **kwargs)
            except Exception as e:
                settings = {
                    "split_name": split_name,
                    "return_type": return_type,
                    "mod": mod,
                    "channel": channel,
                    "channel_type": channel_type,
                    "kwargs": kwargs,
                }
                raise RuntimeError(f"Failed to get features for the following settings:\n{pformat(settings)}") from e
            out.append(x)

        return out[0] if len(out) == 1 else out

    def get_x(self, split_name: Optional[str] = None, return_type: FeatType = "numpy", **kwargs) -> Any:
        """Retrieve cell features from a particular split (all cells if ``split_name`` is None)."""
        return self._get(self._FEATURE_CONFIGS, split_name=split_name, return_type=return_type, **kwargs)

    def get_y(self, split_name: Optional[str] = None, return_type: FeatType = "numpy", **kwargs) -> Any:
        """Retrieve cell labels from a particular split (all cells if ``split_name`` is None)."""
        return self._get(self._LABEL_CONFIGS, split_name=split_name, return_type=return_type, **kwargs)

    def get_data(
        self,
        split_name: Optional[str] = None,
        return_type: FeatType = "numpy",
        x_kwargs: Optional[Dict[str, Any]] = None,
        y_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, Any]:
        """Retrieve cell features and labels from a particular split.

        Parameters
        ----------
        split_name
            Name of the split to retrieve. If not set, return all.
        return_type
            How should the features be returned. **numpy**: return as a numpy array; **torch**: return as a torch
            tensor; **sparse**: return as a scipy sparse matrix; **default**: return as stored (requires
            ``split_name`` to be unset).
        x_kwargs
            Extra keyword arguments forwarded to :meth:`get_x`.
        y_kwargs
            Extra keyword arguments forwarded to :meth:`get_y`.

        Returns
        -------
        A ``(features, labels)`` tuple.

        """
        x = self.get_x(split_name, return_type, **(x_kwargs or {}))
        y = self.get_y(split_name, return_type, **(y_kwargs or {}))
        return x, y

    def get_train_data(self, return_type: FeatType = "numpy", x_kwargs: Optional[Dict[str, Any]] = None,
                       y_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
        """Retrieve cell features and labels from the 'train' split."""
        return self.get_data("train", return_type, x_kwargs, y_kwargs)

    def get_val_data(self, return_type: FeatType = "numpy", x_kwargs: Optional[Dict[str, Any]] = None,
                     y_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
        """Retrieve cell features and labels from the 'val' split."""
        return self.get_data("val", return_type, x_kwargs, y_kwargs)

    def get_test_data(self, return_type: FeatType = "numpy", x_kwargs: Optional[Dict[str, Any]] = None,
                      y_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
        """Retrieve cell features and labels from the 'test' split."""
        return self.get_data("test", return_type, x_kwargs, y_kwargs)