Commit
slips_metrics_getter.py: add a function to directly get the metrics of a given threshold
AlyaGomaa committed Nov 26, 2024
1 parent 592e182 commit 6140fb0
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions scripts/slips_metrics_getter.py
@@ -16,7 +16,7 @@
from metrics.calculator import Calculator


-THRESHOLDS_TO_BRUTEFORCE = range(1, 20)
+THRESHOLDS_TO_BRUTEFORCE = range(1, 400)

def print_line():
print("-"*20)
@@ -270,7 +270,19 @@ def get_metrics_for_each_threshold_for_all_experiments(


def print_sum_of_metrics_for_each_threshold(
-    sum_of_confusion_matrix_per_threshold: dict):
+    sum_of_confusion_matrix_per_threshold: Dict[int, Dict[str, float]]) \
+        -> Dict[int, Dict[str, float]]:
"""
Calculates and prints the rest of the metrics
for each threshold in the given dict
:param sum_of_confusion_matrix_per_threshold: Dict with the sum of TP
FP TN FN for each threshold
returns somehting like this
{1: {'FPR': ., 'FNR': ., 'TPR': ., 'TNR': ., 'precision': ., 'F1': .,
'accuracy': ., 'MCC': .},
2: {'FPR': ., 'FNR': ., 'TPR': ., 'TNR': ., 'precision': ., 'F1': .,
'accuracy': ., 'MCC': .}
"""
print("Printing total error metrics for all thresholds")
# example of error_rates = {1: {'MCC': 0, 'FPR': 0.2, etc..}}
error_rates: Dict[int, Dict[str, float]] = {}
@@ -297,8 +309,38 @@ def print_sum_of_metrics_for_each_threshold(
print_line()
return error_rates

-def main():

+def print_error_rates_for_threshold(threshold):
+    """Use this when you want to know the metrics of a specific threshold
+    directly, without checking the plots."""
+    print(f"Error Rates for threshold: {threshold}")
+    print_line()
+
+    # for each experiment, what are the TPR, FPR, MCC, FP, etc. that
+    # Slips would get if it used the given threshold?
+    metrics: Dict[str, Dict[str, float]]
+    metrics = get_experiments_metrics_for_threshold(threshold)
+
+    # the result is something like
+    # {threshold: {confusion_matrix_sum}}
+    # e.g.
+    # {1: {'TP': .., 'TN': .., 'FP': .., 'FN': ..},
+    #  2: {'TP': .., 'TN': .., 'FP': .., 'FN': ..}}
+    sum_of_confusion_matrix: Dict[int, Dict[str, float]] = \
+        get_sum_of_metrics_per_threshold(
+            {threshold: metrics},
+            [
+                'TP',
+                'TN',
+                'FP',
+                'FN',
+            ]
+        )
+    print_sum_of_metrics_for_each_threshold(sum_of_confusion_matrix)
+
+
+def main():
args = parse_args()

expirements_number = len(extracted_threat_levels)
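The diff above delegates the actual metric math to metrics.calculator.Calculator. For quick reference, here is a minimal, self-contained sketch of the standard confusion-matrix formulas the new docstring names (FPR, FNR, TPR, TNR, precision, F1, accuracy, MCC), followed by a hypothetical call to the new entry point. The helper derive_metrics, its epsilon guard, and the sample counts are illustrative assumptions, not code from this commit.

import math
from typing import Dict


def derive_metrics(cm: Dict[str, float]) -> Dict[str, float]:
    """Standard error metrics from a summed confusion matrix.
    Illustrative sketch; the commit itself uses metrics.calculator.Calculator."""
    tp, tn, fp, fn = cm["TP"], cm["TN"], cm["FP"], cm["FN"]
    eps = 1e-12  # guards against division by zero
    precision = tp / (tp + fp + eps)
    tpr = tp / (tp + fn + eps)  # recall / true positive rate
    mcc_denominator = math.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) or eps
    return {
        "FPR": fp / (fp + tn + eps),
        "FNR": fn / (fn + tp + eps),
        "TPR": tpr,
        "TNR": tn / (tn + fp + eps),
        "precision": precision,
        "F1": 2 * precision * tpr / (precision + tpr + eps),
        "accuracy": (tp + tn) / (tp + tn + fp + fn + eps),
        "MCC": (tp * tn - fp * fn) / mcc_denominator,
    }


if __name__ == "__main__":
    # hypothetical summed confusion matrix for a single threshold
    print(derive_metrics({"TP": 90.0, "TN": 80.0, "FP": 20.0, "FN": 10.0}))
    # with this commit applied, the equivalent one-threshold report would be:
    # print_error_rates_for_threshold(5)

The point of the new function is exactly this shortcut: report the derived metrics for one chosen threshold directly, instead of bruteforcing all thresholds in THRESHOLDS_TO_BRUTEFORCE and reading the answer off the plots.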
