From c6514f15a8b162855aa77907b6a1135e4dd843fb Mon Sep 17 00:00:00 2001 From: LuckyMD Date: Mon, 18 Jul 2022 19:27:55 +0200 Subject: [PATCH 1/3] added scvi-scarches-xgb to label projection --- .../label_projection/methods/scvi_tools.py | 88 ++++++++++++++++++- 1 file changed, 87 insertions(+), 1 deletion(-) diff --git a/openproblems/tasks/label_projection/methods/scvi_tools.py b/openproblems/tasks/label_projection/methods/scvi_tools.py index 05cd51825d..5191e89180 100644 --- a/openproblems/tasks/label_projection/methods/scvi_tools.py +++ b/openproblems/tasks/label_projection/methods/scvi_tools.py @@ -86,7 +86,7 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None): return preds -def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None): +def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None, prediction_method='scanvi'): import scvi if test: @@ -138,6 +138,15 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N train_kwargs["limit_val_batches"] = 10 query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs) + if prediction_method == 'scanvi': + preds = _pred_scanvi(adata, query_model) + elif prediction_method == 'xgboost': + preds = _pred_xgb(adata, adata_train, adata_test, query_model, test=test) + + return preds + + +def _pred_scanvi(adata, query_model): # this is temporary and won't be used adata.obs["scanvi_labels"] = "Unknown" preds = query_model.predict(adata) @@ -145,6 +154,58 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N # predictions for train and test return preds +# note: add test option here +def _pred_xgb( + adata, + adata_train, + adata_test, + query_model, + label_col='labels', + test=False, + num_round: Optional[int] = None +): + import xgboost as xgb + + df = _classif_df(adata_train, query_model, label_col) + + X_train = df.drop(columns='labels') + y_train = df['labels'] + + X_test = query_model.get_latent_representation(adata_test) + + if test: + num_round = num_round or 2 + else: + num_round = num_round or 5 + + xgbc = xgb.XGBClassifier( + tree_method = 'hist', + objective = 'multi:softprob' + ) + + + xgbc.fit(X_train, y_train) + + adata_test.obs['preds_test'] = xgbc.predict(X_test) + + preds = [adata_test.obs['preds_test'][idx] if idx in adata_test.obs_names else np.nan + for idx in adata.obs_names] + + return preds + + +def _classif_df(adata, trained_model, label_col): + emb_data = trained_model.get_latent_representation(adata) + + df = pd.DataFrame( + data = emb_data, + index = adata.obs_names + ) + + df['labels'] = adata.obs[label_col] + + return df + @_scanvi_method(method_name="scANVI (All genes)") def scanvi_all_genes(adata, test=False): @@ -176,3 +237,28 @@ def scarches_scanvi_hvg(adata, test=False): adata.obs["labels_pred"] = _scanvi_scarches(bdata, test=test) adata.uns["method_code_version"] = check_version("scvi-tools") return adata + +@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)") +def scarches_scanvi_all_genes(adata, test=False): + adata.obs["labels_pred"] = _scanvi_scarches( + adata, + test=test, + prediction_method='xgboost' + ) + + adata.uns["method_code_version"] = check_version("scvi-tools") + return adata + + +@_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (Seurat v3 2000 HVG)") +def scarches_scanvi_hvg(adata, test=False): + hvg_df = _hvg(adata, test) + bdata = adata[:, hvg_df.highly_variable].copy() + adata.obs["labels_pred"] = _scanvi_scarches( + bdata, + test=test, + prediction_method='xgboost' + ) + + adata.uns["method_code_version"] = check_version("scvi-tools") + return adata From 999d0a633e2d422a5d839916c1186f2db223307b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 18 Jul 2022 17:31:48 +0000 Subject: [PATCH 2/3] pre-commit --- .../label_projection/methods/scvi_tools.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/openproblems/tasks/label_projection/methods/scvi_tools.py b/openproblems/tasks/label_projection/methods/scvi_tools.py index 5191e89180..437c588019 100644 --- a/openproblems/tasks/label_projection/methods/scvi_tools.py +++ b/openproblems/tasks/label_projection/methods/scvi_tools.py @@ -86,7 +86,14 @@ def _scanvi(adata, test=False, n_hidden=None, n_latent=None, n_layers=None): return preds -def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=None, prediction_method='scanvi'): +def _scanvi_scarches( + adata, + test=False, + n_hidden=None, + n_latent=None, + n_layers=None, + prediction_method="scanvi", +): import scvi if test: @@ -138,11 +145,11 @@ def _scanvi_scarches(adata, test=False, n_hidden=None, n_latent=None, n_layers=N train_kwargs["limit_val_batches"] = 10 query_model.train(plan_kwargs=dict(weight_decay=0.0), **train_kwargs) - if prediction_method == 'scanvi': + if prediction_method == "scanvi": preds = _pred_scanvi(adata, query_model) - elif prediction_method == 'xgboost': + elif prediction_method == "xgboost": preds = _pred_xgb(adata, adata_train, adata_test, query_model, test=test) - + return preds @@ -154,22 +161,23 @@ def _pred_scanvi(adata, query_model): # predictions for train and test return preds + # note: add test option here def _pred_xgb( - adata, - adata_train, - adata_test, - query_model, - label_col='labels', - test=False, - num_round: Optional[int] = None + adata, + adata_train, + adata_test, + query_model, + label_col="labels", + test=False, + num_round: Optional[int] = None, ): import xgboost as xgb - + df = _classif_df(adata_train, query_model, label_col) - X_train = df.drop(columns='labels') - y_train = df['labels'] + X_train = df.drop(columns="labels") + y_train = df["labels"] X_test = query_model.get_latent_representation(adata_test) @@ -178,32 +186,27 @@ def _pred_xgb( else: num_round = num_round or 5 - xgbc = xgb.XGBClassifier( - tree_method = 'hist', - objective = 'multi:softprob' - ) + xgbc = xgb.XGBClassifier(tree_method="hist", objective="multi:softprob") - xgbc.fit(X_train, y_train) - adata_test.obs['preds_test'] = xgbc.predict(X_test) + adata_test.obs["preds_test"] = xgbc.predict(X_test) + + preds = [ + adata_test.obs["preds_test"][idx] if idx in adata_test.obs_names else np.nan + for idx in adata.obs_names + ] - preds = [adata_test.obs['preds_test'][idx] if idx in adata_test.obs_names else np.nan - for idx in adata.obs_names] - return preds def _classif_df(adata, trained_model, label_col): emb_data = trained_model.get_latent_representation(adata) - - df = pd.DataFrame( - data = emb_data, - index = adata.obs_names - ) - - df['labels'] = adata.obs[label_col] - + + df = pd.DataFrame(data=emb_data, index=adata.obs_names) + + df["labels"] = adata.obs[label_col] + return df @@ -238,12 +241,11 @@ def scarches_scanvi_hvg(adata, test=False): adata.uns["method_code_version"] = check_version("scvi-tools") return adata + @_scanvi_scarches_method(method_name="scArches+scANVI+xgboost (All genes)") def scarches_scanvi_all_genes(adata, test=False): adata.obs["labels_pred"] = _scanvi_scarches( - adata, - test=test, - prediction_method='xgboost' + adata, test=test, prediction_method="xgboost" ) adata.uns["method_code_version"] = check_version("scvi-tools") @@ -255,10 +257,8 @@ def scarches_scanvi_hvg(adata, test=False): hvg_df = _hvg(adata, test) bdata = adata[:, hvg_df.highly_variable].copy() adata.obs["labels_pred"] = _scanvi_scarches( - bdata, - test=test, - prediction_method='xgboost' + bdata, test=test, prediction_method="xgboost" ) - + adata.uns["method_code_version"] = check_version("scvi-tools") return adata From 7b60383777ad19ae3d7fdc3817f91c553de3757c Mon Sep 17 00:00:00 2001 From: LuckyMD Date: Mon, 18 Jul 2022 19:32:10 +0200 Subject: [PATCH 3/3] added methods to init and renamed --- openproblems/tasks/label_projection/methods/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openproblems/tasks/label_projection/methods/__init__.py b/openproblems/tasks/label_projection/methods/__init__.py index a5b099a5ae..1330b09e12 100644 --- a/openproblems/tasks/label_projection/methods/__init__.py +++ b/openproblems/tasks/label_projection/methods/__init__.py @@ -10,5 +10,7 @@ from .scvi_tools import scanvi_hvg from .scvi_tools import scarches_scanvi_all_genes from .scvi_tools import scarches_scanvi_hvg +from .scvi_tools import scarches_scanvi_xgb_all_genes +from .scvi_tools import scarches_scanvi_xgb_hvg from .xgboost import xgboost_log_cpm from .xgboost import xgboost_scran