-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtree_parser.py
220 lines (196 loc) · 8.37 KB
/
tree_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import torch
from torch.nn.utils.rnn import pad_sequence
def tarjan(sequence):
    r"""
    Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph.

    Two placeholder entries (``-1``) are prepended to ``sequence`` before the
    search, so the node described by ``sequence[k]`` is reported as index
    ``k + 2`` and head values are interpreted in that shifted index space.
    Indices ``0`` and ``1`` have no head and can never be part of a cycle.

    The search is iterative (explicit work stack), so long head chains do not
    hit Python's recursion limit, and the head -> dependents map is built once
    instead of re-scanning ``sequence`` for every visited node.

    Args:
        sequence (list):
            List of head indices.

    Yields:
        A list of indices that make up a SCC. All self-loops are ignored.

    Examples:
        >>> next(tarjan([2, 5, 0, 3, 1]))  # (3 -> 5 -> 3) is a cycle
        [5, 3]
    """
    sequence = [-1, -1] + sequence
    size = len(sequence)

    # dependents of every head, listed in increasing index order so that the
    # visiting order matches a left-to-right scan of `sequence`
    children = {}
    for node, head in enumerate(sequence):
        children.setdefault(head, []).append(node)

    # record the search order, i.e., the timestep
    dfn = [-1] * size
    # record the smallest timestep reachable inside a SCC
    low = [-1] * size
    # nodes visited but not yet assigned to a finished SCC
    stack, onstack = [], [False] * size
    timestep = 0

    for start in range(size):
        if dfn[start] != -1:
            continue
        dfn[start] = low[start] = timestep
        timestep += 1
        stack.append(start)
        onstack[start] = True
        # each frame is (node, iterator over the dependents still to visit)
        work = [(start, iter(children.get(start, ())))]

        while work:
            i, pending = work[-1]
            for j in pending:
                if dfn[j] == -1:
                    # descend into an unvisited dependent; `i` is resumed later
                    dfn[j] = low[j] = timestep
                    timestep += 1
                    stack.append(j)
                    onstack[j] = True
                    work.append((j, iter(children.get(j, ()))))
                    break
                if onstack[j]:
                    low[i] = min(low[i], dfn[j])
            else:
                # every dependent of `i` has been handled
                work.pop()
                # a SCC is completed
                if low[i] == dfn[i]:
                    cycle = [stack.pop()]
                    while cycle[-1] != i:
                        onstack[cycle[-1]] = False
                        cycle.append(stack.pop())
                    onstack[i] = False
                    # ignore the self-loop
                    if len(cycle) > 1:
                        yield cycle
                # propagate the low-link to the node that led here
                if work:
                    parent = work[-1][0]
                    low[parent] = min(low[parent], low[i])
def chuliu_edmonds(s, device):
    r"""
    ChuLiu/Edmonds algorithm for non-projective decoding.

    Some code is borrowed from `tdozat's implementation`_.
    Descriptions of notations and formulas can be found in
    `Non-projective Dependency Parsing using Spanning Tree Algorithms`_.

    Notes:
        The algorithm does not guarantee to parse a single-root tree.

        ``s`` is masked in place after ``s.to(device)``: row ``0`` (all columns
        but the first) and the diagonal (all entries but the first) are set to
        ``-inf``. If ``s`` already lives on ``device`` the caller's tensor is
        the one being modified.

        Cycle detection skips the heads at positions ``0`` and ``1``
        (``tree.tolist()[2:]``); ``tarjan`` pads its input with two
        placeholders, so the reported cycle indices line up with rows of ``s``.

    References:
        - Ryan McDonald, Fernando Pereira, Kiril Ribarov and Jan Hajic. 2005.
          `Non-projective Dependency Parsing using Spanning Tree Algorithms`_.

    Args:
        s (~torch.Tensor): ``[seq_len, seq_len]``.
            Scores of all dependent-head pairs.
        device:
            Device that ``s`` and all intermediate index tensors are moved to.

    Returns:
        ~torch.Tensor:
            A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree.

    .. _tdozat's implementation:
        https://github.com/tdozat/Parser-v3
    .. _Non-projective Dependency Parsing using Spanning Tree Algorithms:
        https://www.aclweb.org/anthology/H05-1066/
    """
    s = s.to(device)
    s[0, 1:] = float('-inf')
    # prevent self-loops
    s.diagonal()[1:].fill_(float('-inf'))
    # select heads with highest scores
    tree = s.argmax(-1)
    # lazily fetch the first cycle found by the tarjan algorithm (computed once)
    cycle = next(tarjan(tree.tolist()[2:]), None)
    # if the tree has no cycles, then it is a MST
    if not cycle:
        return tree
    # indices of cycle in the original tree
    cycle = torch.tensor(cycle)
    # indices of noncycle in the original tree
    noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0)
    noncycle = torch.where(noncycle.gt(0))[0]
    noncycle = noncycle.to(device)

    def contract(s, device):
        # heads of cycle in original tree
        cycle_heads = tree[cycle]
        # scores of cycle in original tree
        s_cycle = s[cycle, cycle_heads]

        # calculate the scores of cycle's potential dependents
        # s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle
        s_dep = s[noncycle][:, cycle]
        # find the best cycle head for each noncycle dependent
        deps = s_dep.argmax(1)
        # calculate the scores of cycle's potential heads
        # s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle
        #     a(v) is the predecessor of v in cycle
        #     s(cycle) = sum(s(a(v)->v))
        s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum()
        # find the best noncycle head for each cycle dependent
        heads = s_head.argmax(0)

        # the trailing -1 picks the last row/column of `s` as the slot that
        # is overwritten below with the scores of the contracted cycle node
        contracted = torch.cat((noncycle, torch.tensor([-1]).to(device)))
        # calculate the scores of contracted graph
        s = s[contracted][:, contracted]
        # set the contracted graph scores of cycle's potential dependents
        s[:-1, -1] = s_dep[range(len(deps)), deps]
        # set the contracted graph scores of cycle's potential heads
        s[-1, :-1] = s_head[heads, range(len(heads))]

        return s.to(device), heads.to(device), deps.to(device)

    # keep track of the endpoints of the edges into and out of cycle for reconstruction later
    s, heads, deps = contract(s, device)

    # y is the contracted tree
    y = chuliu_edmonds(s, device)
    # exclude head of cycle from y
    y, cycle_head = y[:-1], y[-1]

    # fix the subtree with no heads coming from the cycle
    # len(y) denotes heads coming from the cycle
    subtree = y < len(y)
    subtree = subtree.to(device)
    # add the nodes to the new tree
    tree[noncycle[subtree]] = noncycle[y[subtree]]
    # fix the subtree with heads coming from the cycle
    subtree = ~subtree
    subtree = subtree.to(device)
    # `cycle` was built on the default device; move it before it is indexed
    # with tensors that live on `device`
    cycle = cycle.to(device)
    # add the nodes to the tree
    tree[noncycle[subtree]] = cycle[deps[subtree]]
    # fix the root of the cycle
    cycle_root = heads[cycle_head]
    # break the cycle and add the root of the cycle to the tree
    tree[cycle[cycle_root]] = noncycle[cycle_head]

    return tree.to(device)
def mst(scores, mask, device, multiroot=False):
    r"""
    MST algorithm for decoding non-projective trees.
    This is a wrapper for ChuLiu/Edmonds algorithm.

    The algorithm first runs ChuLiu/Edmonds to parse a tree and then checks for multi-roots.
    If ``multiroot=False`` and there indeed exist multi-roots, the algorithm seeks to find
    the best single-root tree by iterating all possible single-root trees parsed by ChuLiu/Edmonds.
    Otherwise the resulting trees are directly taken as the final outputs.

    Args:
        scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
            Scores of all dependent-head pairs.
        mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
            The mask to avoid parsing over padding tokens. Each sentence is cut
            to ``mask[i].sum() + 2`` rows/columns, so presumably the two leading
            pseudo-token columns should be ``False`` (confirm against the caller).
        device:
            Device on which the per-sentence score matrices are decoded.
        multiroot (bool):
            Ensures to parse a single-root tree if ``False``.

    Returns:
        ~torch.Tensor:
            A 1-D tensor with the predicted heads of all sentences concatenated in
            batch order; the first two positions of every tree are dropped.
            An empty batch yields an empty ``long`` tensor.

    Examples:
        >>> scores = torch.tensor([[[-11.9436, -13.1464,  -6.4789, -13.8917],
        ...                         [-60.6957, -60.2866, -48.6457, -63.8125],
        ...                         [-38.1747, -49.9296, -45.2733, -49.5571],
        ...                         [-19.7504, -23.9066,  -9.9139, -16.2088]]])
        >>> mask = torch.tensor([[False, False, True, True]])
        >>> mst(scores, mask, 'cpu')
        tensor([0, 2])
    """
    preds = []
    for i, length in enumerate(mask.sum(1).tolist()):
        # keep the score matrix on the same device as the index tensors and
        # the trees returned by chuliu_edmonds; a no-op when already there
        s = scores[i][:length+2, :length+2].to(device)
        tree = chuliu_edmonds(s, device)
        # positions (from index 2 on) whose predicted head is 0
        roots = torch.where(tree[2:].eq(0))[0] + 2
        if not multiroot and len(roots) > 1:
            # scores towards head 0, taken before that column is blanked out
            s_root = s[:, 0]
            s_best = float('-inf')
            # out-of-place fill: `s_root` keeps viewing the untouched scores
            s = s.index_fill(1, torch.tensor(0).to(device), float('-inf'))
            for root in roots:
                # allow only `root` to attach to head 0 and re-decode
                s[:, 0] = float('-inf')
                s[root, 0] = s_root[root]
                t = chuliu_edmonds(s, device)
                s_tree = s[2:].gather(1, t[2:].unsqueeze(-1)).sum()
                if s_tree > s_best:
                    s_best, tree = s_tree, t
        # drop the two leading positions of the tree
        tree = tree[2:]
        preds.append(tree)

    if not preds:
        # torch.cat rejects an empty list; return an empty tensor of heads instead
        return torch.empty(0, dtype=torch.long, device=device)
    return torch.cat(preds)