-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrie.py
156 lines (131 loc) · 4.02 KB
/
trie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# encoding = utf-8
class Trie(object):
    """
    Character trie built from nested dicts, with prefix enumeration and a
    simple nearest-word (edit-distance) spelling correction.
    """

    def __init__(self):
        """
        Initialize an empty trie.
        """
        # Each node is a dict mapping one character to its child node.
        self.root = {}
        # Sentinel key meaning "a word ends at this node". It is an int, so it
        # can never collide with the single-character string keys of children.
        self.end = -1

    def _find_node(self, prefix):
        """
        Follow ``prefix`` from the root.
        :type prefix: str
        :return: the node reached after consuming all of ``prefix``, or None
                 if some character has no matching edge.
        """
        node = self.root
        for c in prefix:
            node = node.get(c)
            if node is None:
                return None
        return node

    def insert(self, word):
        """
        Insert a word into the trie.
        :type word: str
        :rtype: None
        """
        node = self.root
        for c in word:
            node = node.setdefault(c, {})
        node[self.end] = True

    def search(self, word):
        """
        Return whether ``word`` was inserted as a complete word.
        :type word: str
        :rtype: bool
        """
        node = self._find_node(word)
        # Reaching the node is not enough: a word must also end there.
        return node is not None and self.end in node

    def startsWith(self, prefix):
        """
        Return whether any inserted word starts with ``prefix``.
        :type prefix: str
        :rtype: bool
        """
        return self._find_node(prefix) is not None

    def get_start(self, prefix):
        """
        Return every inserted word that starts with ``prefix``.
        Words are listed in depth-first pre-order, visiting children in
        insertion order (so a word precedes its own extensions).
        :param prefix: the prefix to match (str)
        :return: list of matching words; [] if nothing matches
        """
        node = self._find_node(prefix)
        if node is None:
            return []
        result = []
        # Explicit stack instead of recursion so very long words cannot hit
        # Python's recursion limit. Children are pushed in reverse so they are
        # popped in insertion order, reproducing a recursive pre-order walk.
        stack = [(prefix, node)]
        while stack:
            word, cur = stack.pop()
            if self.end in cur:
                result.append(word)
            children = [(word + key, child)
                        for key, child in cur.items() if key != self.end]
            stack.extend(reversed(children))
        return result

    @staticmethod
    def levenshtein_dp(s: str, t: str) -> int:
        """
        Compute the Levenshtein (edit) distance between two strings: the
        minimum number of single-character insertions, deletions and
        substitutions that turn ``s`` into ``t``. Smaller means more similar.
        :param s: first string
        :param t: second string
        :return: the edit distance (non-negative int)
        """
        # Dynamic programming keeping only the previous row: prev[j] is the
        # distance between the processed prefix of s and t[:j]. Row 0 is the
        # cost of building t[:j] from the empty string.
        prev = list(range(len(t) + 1))
        for i, sc in enumerate(s, start=1):
            cur = [i]  # turning s[:i] into "" takes i deletions
            for j, tc in enumerate(t, start=1):
                cur.append(min(prev[j] + 1,                # delete sc
                               cur[j - 1] + 1,             # insert tc
                               prev[j - 1] + (sc != tc)))  # substitute/match
            prev = cur
        return prev[-1]

    def get_all_words_of_trie(self):
        """
        Return every word stored in the trie (including "" if it was inserted).
        :return: list of words in depth-first pre-order
        """
        # Every word starts with the empty prefix. This also avoids passing the
        # int end-marker to get_start when "" has been inserted.
        return self.get_start("")

    def get_right_word(self, input_word):
        """
        Given a (possibly misspelled) word, return the stored word with the
        smallest Levenshtein distance to it.
        :param input_word: the word to correct (str)
        :return: the closest stored word; the first one found on ties, or
                 ``input_word`` itself when the trie is empty
        """
        words = self.get_all_words_of_trie()
        # min() keeps the first of several equally-close candidates.
        return min(words,
                   key=lambda item: self.levenshtein_dp(input_word, item),
                   default=input_word)
if __name__ == "__main__":
    # Small demo: build a trie from a mix of Chinese and English phrases, then
    # show the edit distance between a misspelling and its correct form.
    demo = Trie()
    sample_phrases = [
        "中",
        "中国",
        "中国人",
        "中华人民共和国",
        "Python",
        "Python 算法",
        "Python web",
        "Python web 开发",
        "Python web 开发 视频教程",
        "Python 算法 源码",
        "Perl 算法 源码",
    ]
    for phrase in sample_phrases:
        demo.insert(phrase)

    # Other calls worth trying on this trie: demo.search("Perl"),
    # demo.get_start("Python web"), demo.get_all_words_of_trie(),
    # or inspecting the raw node structure via demo.root.
    print(demo.levenshtein_dp("facbok", "facebook"))