-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext_knowledge_graph.py
238 lines (196 loc) · 7.44 KB
/
text_knowledge_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#!/usr/bin/python3
'''
Author: Ambareesh Ravi
Date: 26 July, 2021
File: text_knowledge_graph.py
Description:
Creates and visualizes a knowledge graph from textual data using Natural Language Processing.
Has applications in medicine, finance, recommendation systems, fraud detection, trading etc.
'''
# Library imports
import argparse
import numpy as np
import pandas as pd
import spacy
from spacy.matcher import Matcher
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
# Module imports
from data import *
# Global variables
# Change to different language pack as required
nlp = spacy.load('en_core_web_sm')
class TextKnowledgeGraph:
    '''
    Creates and visualizes a knowledge graph from textual data
    '''
    def __init__(self, data):
        '''
        Initializes the class and builds the knowledge graph

        Args:
            data - the text data as <pandas.DataFrame>; must contain a
                   "sentence" column (see build())
        Returns:
            -
        Exception:
            -
        '''
        self.data = data
        # Define pattern matching params: a sentence ROOT, optionally
        # followed by a preposition, an agent and/or an adjective — used by
        # extract_relations() to pull the relation phrase out of a sentence.
        self.matcher = Matcher(nlp.vocab)
        pattern = [
            {'DEP': 'ROOT'},
            {'DEP': 'prep', 'OP': "?"},
            {'DEP': 'agent', 'OP': "?"},
            {'POS': 'ADJ', 'OP': "?"}
        ]
        # spaCy >= 3.0 expects Matcher.add(key, [pattern]); the legacy
        # (key, callback, *patterns) signature raises there. Try the modern
        # call first and fall back so spaCy 2.x installs keep working.
        try:
            self.matcher.add("matching_1", [pattern])
        except (TypeError, ValueError):
            self.matcher.add("matching_1", None, pattern)
        # Build the knowledge graph
        self.build()

    def extract_entities(self, sentence):
        '''
        Extracts a (subject, object) entity pair from a sentence using the
        spaCy dependency parser

        Args:
            sentence - the input sentence as <str>
        Returns:
            pair of entities as <list> [subject, object]; either entry may be
            "" when no matching token was found
        Exception:
            -
        '''
        entity1, entity2 = "", ""
        prefix, modifier = "", ""
        prev_token_dep, prev_token_text = "", ""
        for token in nlp(sentence):
            # Skip punctuation
            if token.dep_ == "punct":
                continue
            # Check for compound words; chain consecutive compounds together
            if token.dep_ == "compound":
                prefix = token.text
                if prev_token_dep == "compound":
                    prefix = "%s %s" % (prev_token_text, token.text)
            # Check if token is a modifier (amod, nmod, advmod, ...)
            if token.dep_.endswith("mod"):
                modifier = token.text
                if prev_token_dep == "compound":
                    modifier = "%s %s" % (prev_token_text, token.text)
            # Subject token -> first entity.
            # BUGFIX: the original used `token.dep_.find("subj") == True`,
            # which compares a substring index to True (i.e. 1) and only
            # matched labels where "subj" happened to start at index 1.
            if "subj" in token.dep_:
                entity1 = "%s %s %s" % (modifier, prefix, token.text)
                prefix, modifier, prev_token_dep, prev_token_text = "", "", "", ""
            # Object token -> second entity. Same fix: a bare "obj" label was
            # previously missed because "obj".find("obj") == 0 != True.
            if "obj" in token.dep_:
                entity2 = "%s %s %s" % (modifier, prefix, token.text)
            # Remember this token for the compound-chaining checks above
            prev_token_dep, prev_token_text = token.dep_, token.text
        # Return results, stripping padding left by empty prefix/modifier
        return [entity1.strip(), entity2.strip()]

    def extract_relations(self, sentence):
        '''
        Extracts the relationship phrase in the sentence via the ROOT-based
        dependency pattern registered in __init__

        Args:
            sentence - the input sentence as <str>
        Returns:
            relationship as <str> ("" when the pattern does not match)
        Exception:
            -
        '''
        doc = nlp(sentence)
        matches = self.matcher(doc)
        # Guard against sentences with no match instead of raising IndexError
        if not matches:
            return ""
        # Use the last (longest, due to optional pattern slots) match
        _, start, end = matches[-1]
        return doc[start:end].text

    def get_knowledge_graph_data(self, entity_pairs, relations):
        '''
        Creates and returns a dataframe for knowledge graph creation

        Args:
            entity_pairs - <list> of all [subject, object] pairs in the dataset
            relations - <list> of all relationships between the entity pairs
                        in the dataset (same length as entity_pairs)
        Returns:
            data as <pandas.DataFrame> with columns source/target/edge
        Exception:
            -
        '''
        # An empty dataset would make the [:, 0] slice below fail; return an
        # empty frame with the expected schema instead
        if len(entity_pairs) == 0:
            return pd.DataFrame(columns=["source", "target", "edge"])
        ep_array = np.array(entity_pairs)
        # subject [source] -> object [target]
        kd_df = pd.DataFrame(
            {
                "source": ep_array[:, 0],
                "target": ep_array[:, 1],
                "edge": relations
            }
        )
        return kd_df

    def create_network(self, kd_df, key_relation=None):
        '''
        Creates a directed graph from the knowledge graph dataframe

        Args:
            kd_df - knowledge graph data as <pandas.DataFrame>
            key_relation - a particular relationship to filter on as <str>;
                           None keeps all edges
        Returns:
            graph as <nx.MultiDiGraph>
        Exception:
            -
        '''
        dir_graph = nx.from_pandas_edgelist(
            df=kd_df[kd_df['edge'] == key_relation] if key_relation else kd_df,
            source='source',
            target='target',
            edge_attr=True,
            create_using=nx.MultiDiGraph()
        )
        return dir_graph

    def plot_graph(self, dir_graph, figsize=(12, 12), node_spacing=0.5, node_size=1000, node_color='skyblue'):
        '''
        Plots and displays the knowledge graph using matplotlib.pyplot

        Args:
            dir_graph - knowledge graph as <nx.MultiDiGraph>
            figsize - size of the figure as a <tuple>
            node_spacing - spring-layout k parameter controlling the distance
                           between nodes in the graph as <float>
            node_size - drawing size of each node as <int>
            node_color - colour for the nodes as <str> [correspondingly
                         color map has to be changed]
        Returns:
            -
        Exception:
            -
        '''
        plt.figure(figsize=figsize)
        pos = nx.spring_layout(dir_graph, k=node_spacing)
        nx.draw(dir_graph, with_labels=True, node_color=node_color, node_size=node_size, edge_cmap=plt.cm.Blues, pos=pos)
        plt.show()

    def build(self):
        '''
        Builds the knowledge graph internally and stores it in self.kd_df

        Args:
            -
        Returns:
            -
        Exception:
            -
        '''
        entity_pairs = [self.extract_entities(sent) for sent in tqdm(self.data["sentence"])]
        relations = [self.extract_relations(sent) for sent in tqdm(self.data['sentence'])]
        self.kd_df = self.get_knowledge_graph_data(entity_pairs, relations)

    def get_by_relationship(self, relationship):
        '''
        Dynamically generates and visualizes the part of the graph based on
        the given relationship

        Args:
            relationship - key relationship to look for as <str> (None shows
                           the entire graph)
        Returns:
            -
        Exception:
            -
        '''
        dir_graph = self.create_network(self.kd_df, relationship)
        self.plot_graph(dir_graph)
if __name__ == '__main__':
    # Command-line interface
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--data_path", type=str, default="data/wikipedia_sentences.csv",
        help="Path to the data csv file"
    )
    arg_parser.add_argument(
        "--relationship", type=str, default=None,
        help="A relationship between entities to be observed. If left empty, the tool will show EVERYTHING!"
    )
    cli_args = arg_parser.parse_args()

    # Load data from the csv file
    dataset = Dataset(cli_args.data_path)

    # Build the knowledge graph; calling the dataset object returns its
    # underlying dataframe (dataset() is the same as dataset.df)
    graph = TextKnowledgeGraph(dataset())

    # Visualize the sub-graph for the requested relationship
    # (e.g. "written by", "directed by", "includes", "composed by")
    graph.get_by_relationship(cli_args.relationship)