Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
plot_confusion_matrix deprecated from scikit-learn 1.0 and removed in 1.2
  • Loading branch information
edwenger authored Jun 6, 2023
1 parent b07a0e6 commit 5a3d305
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -29,5 +29,5 @@
outfile.write(metrics)

# Plot it
disp = plot_confusion_matrix(clf, X_test, y_test, normalize="true", cmap=plt.cm.Blues)
disp = ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test, normalize="true", cmap=plt.cm.Blues)
plt.savefig("plot.png")

0 comments on commit 5a3d305

Please sign in to comment.