-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinlegalbert-classification.py
63 lines (48 loc) · 2.59 KB
/
inlegalbert-classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python
# coding: utf-8
"""
Script to classify each input sentence in a CSV file as either regulatory or not, adding a new column
('regulatory_according_to_inlegalbert') to the CSV data that records the classification results.
The output of this script is a CSV file.
"""
import pandas as pd
import argparse
import sys
import os
from krippendorff import alpha
import xgboost
from train_inlegalbert_xgboost import create_features
from classify_text_with_inlegal_bert_xgboost import classify_texts
def parse_args(argv=None):
    """Parse the command-line arguments of the classification script.

    :param argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    :return: ``argparse.Namespace`` with attributes ``input``, ``model``,
        ``output`` and ``column`` (``None`` when ``--column`` is not given).
    """
    arg_parser = argparse.ArgumentParser(description='InLegalBERT classification of sentences from EU legislation as either regulatory or not')
    required = arg_parser.add_argument_group('required arguments')
    required.add_argument("-in", "--input", required=True, help="Path to input CSV file generated by https://github.com/nature-of-eu-rules/data-preprocessing/blob/main/extract_sentences.py or https://github.com/nature-of-eu-rules/regulatory-statement-classification/blob/main/rule-based-classification.py. Must contain at least one column 'sent' with English sentences")
    required.add_argument("-m", "--model", required=True, help="Path to JSON input InLegalBERT model file trained using this script: https://github.com/nature-of-eu-rules/regulatory-statement-classification/blob/main/train_inlegalbert_xgboost.py")
    required.add_argument("-out", "--output", required=True, help="Path to output CSV file")
    # Optional
    arg_parser.add_argument("-col", "--column", help="Column name in input CSV which holds the texts to classify (default: 'sent')")
    return arg_parser.parse_args(argv)


def main(argv=None):
    """Classify the texts of one CSV column and write the augmented CSV.

    Reads the input CSV, classifies every value of the text column with
    ``classify_texts`` and stores the results in a new column
    ``regulatory_according_to_inlegalbert``. It then writes the whole table,
    without the index, to the output path.

    :param argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    :raises SystemExit: if the text column is missing from the input CSV,
        or if argument parsing fails.
    """
    args = parse_args(argv)
    in_fname = args.input      # Input filename
    model_fname = args.model   # Model filename
    out_fname = args.output    # Output filename
    # An absent or empty --column falls back to 'sent', the column the --input help requires.
    col_name = args.column if args.column else 'sent'

    # Import data
    data_df = pd.read_csv(in_fname)
    # If column does not exist, exit
    if col_name not in data_df.columns:
        sys.exit('Column does not exist in input file!')

    # Classify sentences.
    # NOTE(review): empty CSV cells arrive here as float NaN; confirm classify_texts tolerates them.
    sentences = data_df[col_name].tolist()
    sent_classif = classify_texts(sentences, model_fname)

    # Append classification results for all sentences as a new column to input dataframe.
    # The earlier (commented-out) comparison against 'regulatory_according_to_rule' is gone.
    # Selecting that column raised KeyError for inputs that lack it, which aborted the run
    # before the output file was written.
    data_df['regulatory_according_to_inlegalbert'] = sent_classif

    # Save the new dataframe to file
    data_df.to_csv(out_fname, index=False)


if __name__ == "__main__":
    main()