-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathweat_score_replication_glove.py
122 lines (94 loc) · 3.78 KB
/
weat_score_replication_glove.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import os
import unittest

from ..constants import WEAT_TEST_TOLERANCE
from ..word_vectors import WordVectors
from ..weat_test import weat_score
class TestWeatGloVe(unittest.TestCase):
# Load test data from file
with open("ddo/tests/weat_tests.json", "r") as f:
weat_tests = json.load(f)
# Loading word vectors
vectors = WordVectors("glove")
def _calculate_weat_score(self, test_data):
"""Simple test helper function to calculate and return the WEAT score for given data."""
# Make the out-of-vocabulary tokens easier to access
oovs = test_data["out_of_vocabularies"]["glove"]
# Generate and return test scores
return weat_score(
target_words_X=[x for x in test_data["X"] if x not in oovs["X"]],
target_words_Y=[y for y in test_data["Y"] if y not in oovs["Y"]],
attribute_words_a=[a for a in test_data["A"] if a not in oovs["A"]],
attribute_words_b=[b for b in test_data["B"] if b not in oovs["B"]],
word_vector_getter=self.__class__.vectors)[0]
def test_one(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test1"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_two(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test2"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_three(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test3"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_four(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test4"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_five(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test5"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_six(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test6"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_seven(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test7"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_eight(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test8"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_nine(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test9"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
def test_ten(self):
# Retrieving test data
test_data = self.__class__.weat_tests["test10"]
self.assertAlmostEqual(
self._calculate_weat_score(test_data),
test_data["glove_result"],
delta=WEAT_TEST_TOLERANCE)
if __name__ == "__main__":
    # Run the tests with the standard unittest runner. Because the module uses
    # relative imports (from ..constants etc.), it must be launched as a module
    # (python -m <package>.<module>), not as a plain script path.
    unittest.main()