Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openproblems/tasks/label_projection/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .scvi_tools import scanvi_hvg
from .scvi_tools import scarches_scanvi_all_genes
from .scvi_tools import scarches_scanvi_hvg
from .scvi_tools import scarches_scanvi_xgb_all_genes
from .scvi_tools import scarches_scanvi_xgb_hvg
from .seurat import seurat
from .xgboost import xgboost_log_cp10k
from .xgboost import xgboost_scran
68 changes: 64 additions & 4 deletions openproblems/tasks/label_projection/methods/scvi_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ....tools.decorators import method
from ....tools.utils import check_version
from .xgboost import _xgboost
from typing import Optional

import functools

Expand Down Expand Up @@ -104,7 +106,15 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
return preds


def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
def _scanvi_scarches(
adata,
test=False,
n_hidden=None,
n_latent=None,
n_layers=None,
prediction_method="scanvi",
):
import numpy as np
import scvi

if test:
Expand All @@ -116,11 +126,14 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
n_layers = n_layers or 2
n_hidden = n_hidden or 128

unlabeled_category = "Unknown"

# new obs labels to mask test set
adata.obs["scanvi_labels"] = np.where(
adata.obs["is_train"], adata.obs["labels"], unlabeled_category
)
adata_train = adata[adata.obs["is_train"]].copy()
adata_train.obs["scanvi_labels"] = adata_train.obs["labels"].copy()
adata_test = adata[~adata.obs["is_train"]].copy()
adata_test.obs["scanvi_labels"] = "Unknown"
scvi.model.SCVI.setup_anndata(
adata_train, batch_key="batch", labels_key="scanvi_labels"
)
Expand All @@ -145,7 +158,9 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
train_kwargs["limit_train_batches"] = 10
train_kwargs["limit_val_batches"] = 10
scvi_model.train(**train_kwargs)
model = scvi.model.SCANVI.from_scvi_model(scvi_model, unlabeled_category="Unknown")
model = scvi.model.SCANVI.from_scvi_model(
scvi_model, unlabeled_category=unlabeled_category
)
model.train(**train_kwargs)

query_model = scvi.model.SCANVI.load_query_data(adata_test, model)
Expand All @@ -156,6 +171,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
train_kwargs["limit_val_batches"] = 10
query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs)

if prediction_method == "scanvi":
preds = _pred_scanvi(adata, query_model)
elif prediction_method == "xgboost":
preds = _pred_xgb(adata, query_model, test=test)

return preds


def _pred_scanvi(adata, query_model):
# this is temporary and won't be used
adata.obs["scanvi_labels"] = "Unknown"
preds = query_model.predict(adata)
Expand All @@ -164,6 +188,20 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
return preds


# note: could extend test option
def _pred_xgb(
adata,
query_model,
test=False,
num_round: Optional[int] = None,
):
adata.obsm["X_emb"] = query_model.get_latent_representation(adata)
adata = _xgboost(
adata, test=test, obsm="X_emb", num_round=num_round, tree_method="hist"
)
return adata.obs["labels_pred"]


@_scanvi_method(method_name="scANVI (All genes)")
def scanvi_all_genes(adata, test=False):
adata.obs["labels_pred"] = _scanvi(adata, test=test)
Expand Down Expand Up @@ -194,3 +232,25 @@ def scarches_scanvi_hvg(adata, test=False):
adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test)
adata.uns["method_code_version"] = check_version("scvi-tools")
return adata


@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)")
def scarches_scanvi_xgb_all_genes(adata, test=False):
adata.obs["labels_pred"] = _scanvi_scarches(
adata, test=test, prediction_method="xgboost"
)

adata.uns["method_code_version"] = check_version("scvi-tools")
return adata


@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)")
def scarches_scanvi_xgb_hvg(adata, test=False):
hvg_df = _hvg(adata, test)
bdata = adata[:, hvg_df.highly_variable].copy()
adata.obs["labels_pred"] = _scanvi_scarches(
bdata, test=test, prediction_method="xgboost"
)

adata.uns["method_code_version"] = check_version("scvi-tools")
return adata
19 changes: 16 additions & 3 deletions openproblems/tasks/label_projection/methods/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
)


def _xgboost(adata, test: bool = False, num_round: Optional[int] = None):
def _xgboost(
adata,
test: bool = False,
obsm: Optional[str] = None,
num_round: Optional[int] = None,
**kwargs,
):
import xgboost as xgb

if test:
Expand All @@ -36,12 +42,19 @@ def _xgboost(adata, test: bool = False, num_round: Optional[int] = None):
adata_train = adata[adata.obs["is_train"]]
adata_test = adata[~adata.obs["is_train"]].copy()

xg_train = xgb.DMatrix(adata_train.X, label=adata_train.obs["labels_int"])
xg_test = xgb.DMatrix(adata_test.X, label=adata_test.obs["labels_int"])
xg_train = xgb.DMatrix(
adata_train.obsm[obsm] if obsm else adata_train.X,
label=adata_train.obs["labels_int"],
)
xg_test = xgb.DMatrix(
adata_test.obsm[obsm] if obsm else adata_test.X,
label=adata_test.obs["labels_int"],
)

param = dict(
objective="multi:softmax",
num_class=len(categories),
**kwargs,
)

watchlist = [(xg_train, "train")]
Expand Down