Skip to content

Commit

Permalink
more coverage tests for train_more and plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Bourne227 committed Jul 2, 2024
1 parent d64ff43 commit 189c097
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions pyod/test/test_alad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import os
import sys
import unittest
from unittest.mock import patch

# noinspection PyProtectedMember
from numpy.testing import assert_equal
from numpy.testing import assert_raises
from sklearn.base import clone
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
Expand Down Expand Up @@ -38,7 +40,6 @@ def setUp(self):
dropout_rate=0.2,
add_recon_loss=False,
lambda_recon_loss=0.05,
# only important when add_recon_loss = True
add_disc_zz_loss=True,
dec_layers=[75, 100],
enc_layers=[100, 75],
Expand Down Expand Up @@ -101,17 +102,14 @@ def test_prediction_proba_parameter(self):
self.clf.predict_proba(self.X_test, method='something')

def test_prediction_labels_confidence(self):
pred_labels, confidence = self.clf.predict(self.X_test,
return_confidence=True)
pred_labels, confidence = self.clf.predict(self.X_test, return_confidence=True)
assert_equal(pred_labels.shape, self.y_test.shape)
assert_equal(confidence.shape, self.y_test.shape)
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_proba_linear_confidence(self):
pred_proba, confidence = self.clf.predict_proba(self.X_test,
method='linear',
return_confidence=True)
pred_proba, confidence = self.clf.predict_proba(self.X_test, method='linear', return_confidence=True)
assert (pred_proba.min() >= 0)
assert (pred_proba.max() <= 1)

Expand All @@ -125,24 +123,21 @@ def test_fit_predict(self):

def test_fit_predict_score(self):
self.clf.fit_predict_score(self.X_test, self.y_test)
self.clf.fit_predict_score(self.X_test, self.y_test,
scoring='roc_auc_score')
self.clf.fit_predict_score(self.X_test, self.y_test,
scoring='prc_n_score')
self.clf.fit_predict_score(self.X_test, self.y_test, scoring='roc_auc_score')
self.clf.fit_predict_score(self.X_test, self.y_test, scoring='prc_n_score')
with assert_raises(NotImplementedError):
self.clf.fit_predict_score(self.X_test, self.y_test,
scoring='something')
self.clf.fit_predict_score(self.X_test, self.y_test, scoring='something')

def test_prediction_scores_with_sigmoid(self):
self.alad = ALAD(activation_hidden_gen='sigmoid', activation_hidden_disc='sigmoid')
self.alad.fit(self.X_train)

pred_scores = self.alad.predict(self.X_test)
roc_auc = roc_auc_score(self.y_test, pred_scores)
print(f"ROC AUC Score with Sigmoid: {roc_auc}")
self.assertGreaterEqual(roc_auc, 0)
self.alad = ALAD(activation_hidden_gen='sigmoid', activation_hidden_disc='sigmoid')
self.alad.fit(self.X_train)
pred_scores = self.alad.predict(self.X_test)

roc_auc = roc_auc_score(self.y_test, pred_scores)
print(f"ROC AUC Score with Sigmoid: {roc_auc}")

self.assertGreaterEqual(roc_auc, 0)

def test_prediction_scores_with_relu(self):
self.alad = ALAD(activation_hidden_gen='relu', activation_hidden_disc='relu')
Expand All @@ -155,14 +150,25 @@ def test_prediction_scores_with_relu(self):

self.assertGreaterEqual(roc_auc, 0)


def test_model_clone(self):
# for deep models this may not apply
clone_clf = clone(self.clf)

def test_train_more(self):
initial_scores = self.clf.decision_function(self.X_test)
self.clf.train_more(self.X_train, epochs=50)
new_scores = self.clf.decision_function(self.X_test)
assert (roc_auc_score(self.y_test, new_scores) >= self.roc_floor)
self.assertNotEqual(initial_scores.tolist(), new_scores.tolist(), "Scores should change after training more")

def test_plot_learning_curves(self):
with patch('matplotlib.pyplot.show'):
self.clf.plot_learning_curves()
plt.close('all')

def tearDown(self):
pass


if __name__ == '__main__':
unittest.main()
unittest.main()

0 comments on commit 189c097

Please sign in to comment.