-
Notifications
You must be signed in to change notification settings - Fork 178
/
catboost_pruning.py
81 lines (61 loc) · 2.58 KB
/
catboost_pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Optuna example that demonstrates a pruner for CatBoost.
In this example, we optimize the validation accuracy of cancer detection using CatBoost.
We optimize the choice of loss objective, boosting type and bootstrap type, together with
their hyperparameters. Throughout training of models, a pruner observes intermediate
results and stops unpromising trials.
You can run this example as follows:
$ python catboost_pruning.py
"""
import numpy as np
import optuna
from optuna.integration import CatBoostPruningCallback
import catboost as cb
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def objective(trial: optuna.Trial) -> float:
    """Train one CatBoost classifier for ``trial`` and return its validation accuracy.

    The breast-cancer dataset is split into training and validation sets
    (25% held out for validation). A hyperparameter configuration is then
    sampled from ``trial``. The model is fit with a ``CatBoostPruningCallback``
    that reports the "Accuracy" metric on the validation set, so the study's
    pruner can stop unpromising trials.

    Args:
        trial: The Optuna trial used to sample hyperparameters and to receive
            intermediate values from the pruning callback.

    Returns:
        Accuracy of the fitted model on the validation split.

    Raises:
        optuna.TrialPruned: Raised via ``pruning_callback.check_pruned()`` when
            the pruner decided to stop this trial during training.
    """
    data, target = load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)

    param = {
        "objective": trial.suggest_categorical("objective", ["Logloss", "CrossEntropy"]),
        "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.01, 0.1, log=True),
        "depth": trial.suggest_int("depth", 1, 12),
        "boosting_type": trial.suggest_categorical("boosting_type", ["Ordered", "Plain"]),
        "bootstrap_type": trial.suggest_categorical(
            "bootstrap_type", ["Bayesian", "Bernoulli", "MVS"]
        ),
        "used_ram_limit": "3gb",
        "eval_metric": "Accuracy",
    }

    # Conditional search space: these parameters are sampled only for the
    # bootstrap type selected above (MVS gets no extra parameter here).
    if param["bootstrap_type"] == "Bayesian":
        param["bagging_temperature"] = trial.suggest_float("bagging_temperature", 0, 10)
    elif param["bootstrap_type"] == "Bernoulli":
        param["subsample"] = trial.suggest_float("subsample", 0.1, 1, log=True)

    gbm = cb.CatBoostClassifier(**param)

    # Reports the validation "Accuracy" (the eval_metric above) to the trial
    # while the model trains on eval_set.
    pruning_callback = CatBoostPruningCallback(trial, "Accuracy")
    gbm.fit(
        train_x,
        train_y,
        eval_set=[(valid_x, valid_y)],
        verbose=0,
        early_stopping_rounds=100,
        callbacks=[pruning_callback],
    )

    # Evoke pruning manually: surfaces the pruner's decision as
    # optuna.TrialPruned after fit() returns.
    pruning_callback.check_pruned()

    # predict() already yields class labels (default prediction_type="Class"),
    # so the predictions are scored directly without any rounding step.
    preds = gbm.predict(valid_x)
    accuracy = accuracy_score(valid_y, preds)
    return accuracy
if __name__ == "__main__":
    # Median pruning, ignoring the first 5 reported steps of each trial;
    # higher validation accuracy is better.
    median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=5)
    study = optuna.create_study(pruner=median_pruner, direction="maximize")
    study.optimize(objective, n_trials=100, timeout=600)

    # Summarize the run and the best trial found.
    n_finished = len(study.trials)
    print(f"Number of finished trials: {n_finished}")

    print("Best trial:")
    trial = study.best_trial
    print(f" Value: {trial.value}")

    print(" Params: ")
    for param_name, param_value in trial.params.items():
        print(f" {param_name}: {param_value}")