Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openproblems/tasks/label_projection/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@
from .scvi_tools import scanvi_hvg
from .scvi_tools import scarches_scanvi_all_genes
from .scvi_tools import scarches_scanvi_hvg
from .scvi_tools import scarches_scanvi_xgb_all_genes
from .scvi_tools import scarches_scanvi_xgb_hvg
from .xgboost import xgboost_log_cpm
from .xgboost import xgboost_scran
88 changes: 87 additions & 1 deletion openproblems/tasks/label_projection/methods/scvi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,14 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
return preds


def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None):
def _scanvi_scarches(
adata,
test=False,
n_hidden=None,
n_latent=None,
n_layers=None,
prediction_method="scanvi",
):
import scvi

if test:
Expand Down Expand Up @@ -138,6 +145,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
train_kwargs["limit_val_batches"] = 10
query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs)

if prediction_method == "scanvi":
preds = _pred_scanvi(adata, query_model)
elif prediction_method == "xgboost":
preds = _pred_xgb(adata, adata_train, adata_test, query_model, test=test)

return preds


def _pred_scanvi(adata, query_model):
# this is temporary and won't be used
adata.obs["scanvi_labels"] = "Unknown"
preds = query_model.predict(adata)
Expand All @@ -146,6 +162,54 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N
return preds


# note: add test option here
def _pred_xgb(
adata,
adata_train,
adata_test,
query_model,
label_col="labels",
test=False,
num_round: Optional[int] = None,
):
import xgboost as xgb

df = _classif_df(adata_train, query_model, label_col)

X_train = df.drop(columns="labels")
y_train = df["labels"]

X_test = query_model.get_latent_representation(adata_test)

if test:
num_round = num_round or 2
else:
num_round = num_round or 5

xgbc = xgb.XGBClassifier(tree_method="hist", objective="multi:softprob")

xgbc.fit(X_train, y_train)

adata_test.obs["preds_test"] = xgbc.predict(X_test)

preds = [
adata_test.obs["preds_test"][idx] if idx in adata_test.obs_names else np.nan
for idx in adata.obs_names
]

return preds


def _classif_df(adata, trained_model, label_col):
emb_data = trained_model.get_latent_representation(adata)

df = pd.DataFrame(data=emb_data, index=adata.obs_names)

df["labels"] = adata.obs[label_col]

return df


@_scanvi_method(method_name="scANVI (All genes)")
def scanvi_all_genes(adata, test=False):
adata.obs["labels_pred"] = _scanvi(adata, test=test)
Expand Down Expand Up @@ -176,3 +240,25 @@ def scarches_scanvi_hvg(adata, test=False):
adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test)
adata.uns["method_code_version"] = check_version("scvi-tools")
return adata


@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)")
def scarches_scanvi_all_genes(adata, test=False):
adata.obs["labels_pred"] = _scanvi_scarches(
adata, test=test, prediction_method="xgboost"
)

adata.uns["method_code_version"] = check_version("scvi-tools")
return adata


@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)")
def scarches_scanvi_hvg(adata, test=False):
hvg_df = _hvg(adata, test)
bdata = adata[:, hvg_df.highly_variable].copy()
adata.obs["labels_pred"] = _scanvi_scarches(
bdata, test=test, prediction_method="xgboost"
)

adata.uns["method_code_version"] = check_version("scvi-tools")
return adata