From 6140fb0e2792adf40b61584c1c665a9ee263dd10 Mon Sep 17 00:00:00 2001
From: alya
Date: Tue, 26 Nov 2024 14:04:51 +0200
Subject: [PATCH] slips_metrics_getter.py: add a function to directly get the
 metrics of a given threshold

---
 scripts/slips_metrics_getter.py | 48 ++++++++++++++++++++++++++++++---
 1 file changed, 45 insertions(+), 3 deletions(-)

diff --git a/scripts/slips_metrics_getter.py b/scripts/slips_metrics_getter.py
index fef2501..394ac3e 100644
--- a/scripts/slips_metrics_getter.py
+++ b/scripts/slips_metrics_getter.py
@@ -16,7 +16,7 @@
 from metrics.calculator import Calculator
 
 
-THRESHOLDS_TO_BRUTEFORCE = range(1, 20)
+THRESHOLDS_TO_BRUTEFORCE = range(1, 400)
 
 def print_line():
     print("-"*20)
@@ -270,7 +270,19 @@ def get_metrics_for_each_threshold_for_all_experiments(
 
 
 def print_sum_of_metrics_for_each_threshold(
-    sum_of_confusion_matrix_per_threshold: dict):
+    sum_of_confusion_matrix_per_threshold: Dict[int, Dict[str, float]]) \
+    -> Dict[int, Dict[str, float]]:
+    """
+    Calculates and prints the rest of the metrics
+    for each threshold in the given dict
+    :param sum_of_confusion_matrix_per_threshold: Dict with the sum of TP
+    FP TN FN for each threshold
+    returns something like this
+    {1: {'FPR': ., 'FNR': ., 'TPR': ., 'TNR': ., 'precision': ., 'F1': .,
+    'accuracy': ., 'MCC': .},
+    2: {'FPR': ., 'FNR': ., 'TPR': ., 'TNR': ., 'precision': ., 'F1': .,
+    'accuracy': ., 'MCC': .}}
+    """
     print("Printing total error metrics for all thresholds")
     # example of error_rates = {1: {'MCC': 0, 'FPR': 0.2, etc..}}
     error_rates: Dict[int, Dict[str, float]] = {}
@@ -297,8 +309,38 @@ def print_sum_of_metrics_for_each_threshold(
 
     print_line()
     return error_rates
 
 
-def main():
+
+def print_error_rates_for_threshold(threshold):
+    """use this when you want to know the metrics of a specific threshold
+    directly without checking the plots"""
+    print(f"Error Rates for threshold: {threshold}")
+    print_line()
+    # for each experiment, what are the TPR, FPR, MCC, FP, etc. that
+    # slips would detect if used the given threshold?
+    metrics: Dict[str, Dict[str, float]]
+    metrics = get_experiments_metrics_for_threshold(threshold)
+
+    # is something like
+    # {threshold: {confusion_matrix_sum}}
+    # e.g.
+    # {1: {'TP': sum, 'TN': 'FP': 'FN': ..},
+    # 2: {'TP':sum, 'TN': 'FP': , 'FN': },}
+    sum_of_confusion_matrix: Dict[int, Dict[str, float]] = \
+        get_sum_of_metrics_per_threshold(
+            {threshold: metrics},
+            [
+                'TP',
+                'TN',
+                'FP',
+                'FN',
+            ]
+        )
+    print_sum_of_metrics_for_each_threshold(sum_of_confusion_matrix)
+
+
+
+def main():
     args = parse_args()
     expirements_number = len(extracted_threat_levels)