diff --git a/tests/test_predictors.py b/tests/test_predictors.py index 4850615..8311a7b 100644 --- a/tests/test_predictors.py +++ b/tests/test_predictors.py @@ -148,6 +148,14 @@ def test_classifiers_x(X, classifier, key): kmeans = cfg_kmeans.create() kmeans(adata) + if classifier == "XGBClassifier": + kwargs = { + 'n_estimators': 1, + 'max_depth': 1, + } + else: + kwargs = {} + cfg = OmegaConf.create( { "_target_": f"src.grinch.{classifier}.Config", @@ -155,6 +163,7 @@ def test_classifiers_x(X, classifier, key): "y_key": f"obs.{OBS.KMEANS}", "seed": 42, "labels_key": f"obs.{key}", + **kwargs, } ) # Need to start using convert all for lists and dicts