Review notes (text before the first "diff --git" header is skipped by git apply):
  * _scanvi_scarches now rejects an unsupported prediction_method with a
    ValueError on entry instead of failing with UnboundLocalError on
    `return preds` after all training has finished.
  * _xgboost takes the new `obsm` parameter after `num_round`, so existing
    positional calls `_xgboost(adata, test, num_round)` keep their meaning.
  * NOTE(review): this patch was rebuilt from a copy whose newlines had been
    collapsed; indentation and the position of blank context lines are
    inferred from the hunk line counts -- confirm against the target tree.
  * NOTE(review): the `index` blob hashes are carried over from the original
    patch and no longer describe the post-image of scvi_tools.py/xgboost.py.

diff --git a/openproblems/tasks/label_projection/methods/__init__.py b/openproblems/tasks/label_projection/methods/__init__.py
index 066ba83e8e..eeff54495a 100644
--- a/openproblems/tasks/label_projection/methods/__init__.py
+++ b/openproblems/tasks/label_projection/methods/__init__.py
@@ -11,6 +11,8 @@
 from .scvi_tools import scanvi_hvg
 from .scvi_tools import scarches_scanvi_all_genes
 from .scvi_tools import scarches_scanvi_hvg
+from .scvi_tools import scarches_scanvi_xgb_all_genes
+from .scvi_tools import scarches_scanvi_xgb_hvg
 from .seurat import seurat
 from .xgboost import xgboost_log_cp10k
 from .xgboost import xgboost_scran
diff --git a/openproblems/tasks/label_projection/methods/scvi_tools.py b/openproblems/tasks/label_projection/methods/scvi_tools.py
index 7b68324439..5348b54435 100644
--- a/openproblems/tasks/label_projection/methods/scvi_tools.py
+++ b/openproblems/tasks/label_projection/methods/scvi_tools.py
@@ -1,5 +1,7 @@
 from ....tools.decorators import method
 from ....tools.utils import check_version
+from .xgboost import _xgboost
+from typing import Optional
 
 import functools
 
@@ -104,7 +106,19 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
     return preds
 
 
-def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
+def _scanvi_scarches(
+    adata,
+    test=False,
+    n_hidden=None,
+    n_latent=None,
+    n_layers=None,
+    prediction_method="scanvi",
+):
+    import numpy as np
     import scvi
+
+    # reject unsupported backends before the expensive training steps below
+    if prediction_method not in ("scanvi", "xgboost"):
+        raise ValueError(f"Unknown prediction_method: {prediction_method!r}")
 
     if test:
@@ -116,11 +130,14 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
         n_layers = n_layers or 2
         n_hidden = n_hidden or 128
 
+    unlabeled_category = "Unknown"
+    # new obs labels to mask test set
+    adata.obs["scanvi_labels"] = np.where(
+        adata.obs["is_train"], adata.obs["labels"], unlabeled_category
+    )
     adata_train = adata[adata.obs["is_train"]].copy()
-    adata_train.obs["scanvi_labels"] = adata_train.obs["labels"].copy()
 
     adata_test = adata[~adata.obs["is_train"]].copy()
-    adata_test.obs["scanvi_labels"] = "Unknown"
     scvi.model.SCVI.setup_anndata(
         adata_train, batch_key="batch", labels_key="scanvi_labels"
     )
@@ -145,7 +162,9 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
         train_kwargs["limit_train_batches"] = 10
         train_kwargs["limit_val_batches"] = 10
     scvi_model.train(**train_kwargs)
-    model = scvi.model.SCANVI.from_scvi_model(scvi_model, unlabeled_category="Unknown")
+    model = scvi.model.SCANVI.from_scvi_model(
+        scvi_model, unlabeled_category=unlabeled_category
+    )
     model.train(**train_kwargs)
 
     query_model = scvi.model.SCANVI.load_query_data(adata_test, model)
@@ -156,6 +175,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
         train_kwargs["limit_val_batches"] = 10
     query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs)
 
+    if prediction_method == "scanvi":
+        preds = _pred_scanvi(adata, query_model)
+    else:  # "xgboost" -- prediction_method was validated on entry
+        preds = _pred_xgb(adata, query_model, test=test)
+
+    return preds
+
+
+def _pred_scanvi(adata, query_model):
     # this is temporary and won't be used
     adata.obs["scanvi_labels"] = "Unknown"
     preds = query_model.predict(adata)
@@ -164,6 +192,21 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
     return preds
 
 
+# note: could extend test option
+def _pred_xgb(
+    adata,
+    query_model,
+    test=False,
+    num_round: Optional[int] = None,
+):
+    """Predict labels with _xgboost on the query model's latent space."""
+    adata.obsm["X_emb"] = query_model.get_latent_representation(adata)
+    adata = _xgboost(
+        adata, test=test, obsm="X_emb", num_round=num_round, tree_method="hist"
+    )
+    return adata.obs["labels_pred"]
+
+
 @_scanvi_method(method_name="scANVI (All genes)")
 def scanvi_all_genes(adata, test=False):
     adata.obs["labels_pred"] = _scanvi(adata, test=test)
@@ -194,3 +237,25 @@ def scarches_scanvi_hvg(adata, test=False):
     adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test)
     adata.uns["method_code_version"] = check_version("scvi-tools")
     return adata
+
+
+@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)")
+def scarches_scanvi_xgb_all_genes(adata, test=False):
+    adata.obs["labels_pred"] = _scanvi_scarches(
+        adata, test=test, prediction_method="xgboost"
+    )
+
+    adata.uns["method_code_version"] = check_version("scvi-tools")
+    return adata
+
+
+@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)")
+def scarches_scanvi_xgb_hvg(adata, test=False):
+    hvg_df = _hvg(adata, test)
+    bdata = adata[:, hvg_df.highly_variable].copy()
+    adata.obs["labels_pred"] = _scanvi_scarches(
+        bdata, test=test, prediction_method="xgboost"
+    )
+
+    adata.uns["method_code_version"] = check_version("scvi-tools")
+    return adata
diff --git a/openproblems/tasks/label_projection/methods/xgboost.py b/openproblems/tasks/label_projection/methods/xgboost.py
index 6efb895063..290a636dd0 100644
--- a/openproblems/tasks/label_projection/methods/xgboost.py
+++ b/openproblems/tasks/label_projection/methods/xgboost.py
@@ -22,7 +22,13 @@
 )
 
 
-def _xgboost(adata, test: bool = False, num_round: Optional[int] = None):
+def _xgboost(
+    adata,
+    test: bool = False,
+    num_round: Optional[int] = None,
+    obsm: Optional[str] = None,
+    **kwargs,
+):
     import xgboost as xgb
 
     if test:
@@ -36,12 +42,19 @@ def _xgboost(adata, test: bool = False, num_round: Optional[int] = None):
 
     adata_train = adata[adata.obs["is_train"]]
     adata_test = adata[~adata.obs["is_train"]].copy()
-    xg_train = xgb.DMatrix(adata_train.X, label=adata_train.obs["labels_int"])
-    xg_test = xgb.DMatrix(adata_test.X, label=adata_test.obs["labels_int"])
+    xg_train = xgb.DMatrix(
+        adata_train.obsm[obsm] if obsm else adata_train.X,
+        label=adata_train.obs["labels_int"],
+    )
+    xg_test = xgb.DMatrix(
+        adata_test.obsm[obsm] if obsm else adata_test.X,
+        label=adata_test.obs["labels_int"],
+    )
 
     param = dict(
         objective="multi:softmax",
         num_class=len(categories),
+        **kwargs,
     )
 
     watchlist = [(xg_train, "train")]