-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgallina.py
153 lines (125 loc) · 5.85 KB
/
gallina.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Utilities for reconstructing Gallina terms from their serialized S-expressions in CoqGym
from io import StringIO
from vernac_types import Constr__constr
from lark import Lark, Transformer, Visitor, Discard
from lark.lexer import Token
from lark.tree import Tree
from serutils import SerAPIWrapper
import logging
logging.basicConfig(level=logging.DEBUG)
from collections import defaultdict
from syntax import SyntaxConfig
import re
import pdb
from utils import log
def traverse_postorder(node, callback, ancestor_info=None, get_ancestor_info=None):
    """Call ``callback`` on every ``Tree`` in the subtree rooted at ``node``, in post-order.

    Children that are not ``Tree`` instances (e.g. lark ``Token``s) are skipped;
    ``node`` itself is always visited.

    Args:
        node: root of the traversal; must have a ``children`` attribute.
        callback: called as ``callback(tree)`` when ``get_ancestor_info`` is
            None, otherwise as ``callback(tree, info)`` where ``info`` is the
            value inherited from the parent (``ancestor_info`` for the root).
        ancestor_info: inherited value handed to the root.
        get_ancestor_info: optional ``f(tree, inherited) -> value``. It is
            called when a tree is first entered (pre-order) to compute the
            value passed down to that tree's children.

    The walk uses an explicit stack instead of recursion, so deeply nested
    terms cannot exceed Python's recursion limit. The order of calls to
    ``get_ancestor_info`` and ``callback`` matches a recursive pre/post-order
    walk. Each node's ``children`` is read right after ``get_ancestor_info``
    runs on it. A callback may therefore reassign ``children`` on the node it
    receives, or on that node's descendants, because those are fully traversed.
    """
    track_ancestors = get_ancestor_info is not None

    def enter(tree, inherited):
        # Pre-order step: compute what this tree's children inherit, then start
        # iterating its children.
        passed_down = get_ancestor_info(tree, inherited) if track_ancestors else inherited
        return tree, inherited, passed_down, iter(tree.children)

    stack = [enter(node, ancestor_info)]
    while stack:
        tree, inherited, passed_down, pending = stack[-1]
        for child in pending:
            if isinstance(child, Tree):
                # Descend; this frame resumes from the next child afterwards.
                stack.append(enter(child, passed_down))
                break
        else:
            # Every child is finished, so run the post-order visit.
            stack.pop()
            if track_ancestors:
                callback(tree, inherited)
            else:
                callback(tree)
class GallinaTermParser:
    """Parse serialized Gallina terms (``constr`` S-expressions) into lark trees.

    The lark grammar is generated from ``Constr__constr().to_ebnf(recursive=True)``,
    plus a few imports from lark's ``common`` grammar. It is parsed with an LALR
    parser whose start rule is ``constr__constr``. Parsed trees are then
    post-processed according to ``syntax_config`` (see ``parse_no_cache``).
    """

    def __init__(self, coq_projects_path, syntax_config, caching=True, use_serapi=True):
        """Build the grammar and parser, plus an optional cache and SerAPI wrapper.

        Args:
            coq_projects_path: forwarded to ``SerAPIWrapper``; only used when
                ``use_serapi`` is true.
            syntax_config: object whose flags control the post-processing in
                ``parse_no_cache``. The flags read are ``include_defs``,
                ``include_paths``, ``include_locals``,
                ``include_constructor_names`` and ``merge_vocab``.
            caching: when true, ``parse`` memoizes results in ``self.cache``.
                That attribute only exists in this case.
            use_serapi: when false, ``self.serapi`` is left as ``None``.
        """
        self.caching = caching
        self.syntax_config = syntax_config
        t = Constr__constr()
        # Generated grammar plus terminals imported from lark's common grammar;
        # whitespace between tokens is ignored.
        self.grammar = t.to_ebnf(recursive=True) + '''
%import common.STRING_INNER
%import common.ESCAPED_STRING
%import common.SIGNED_INT
%import common.WS
%ignore WS
'''
        self.parser = Lark(StringIO(self.grammar), start='constr__constr', parser='lalr')
        if caching:
            self.cache = {}
        self.serapi = None
        if use_serapi:
            self.serapi = SerAPIWrapper(coq_projects_path, timeout=30)

    def load_project(self, proj):
        """Forward ``proj`` to ``self.serapi.load_project``."""
        # NOTE(review): self.serapi is None when the parser was built with
        # use_serapi=False, so this raises AttributeError in that case.
        self.serapi.load_project(proj)

    def parse_no_cache(self, term_str):
        """Parse ``term_str`` and post-process the tree, bypassing the cache.

        The returned root ``Tree`` additionally carries:
          * ``quantified_idents``: a list, built from a set so its order is not
            meaningful. It holds the identifier found in the first child of
            each non-empty ``constructor_prod`` node whose first child
            satisfies ``SyntaxConfig.is_name``, with surrounding double
            quotes stripped.
          * ``height``: set on every subtree by the post-processing pass.

        Post-processing runs in post-order and rewrites each node's children:
          * subtrees are kept;
          * terminal tokens are dropped unless a ``syntax_config`` flag keeps
            them (labels, paths, locals, or anything below a constructor when
            ``include_constructor_names`` is set);
          * constructor names are looked up through ``self.serapi`` when
            ``include_constructor_names`` is set.

        Returns:
            The root lark ``Tree``, mutated in place.

        Raises:
            Whatever ``self.parser.parse`` raises on input the grammar rejects.
        """
        syn_conf = self.syntax_config
        ast = self.parser.parse(term_str)

        ast.quantified_idents = set()

        def get_quantified_idents(node):
            # Record the name held in the first child of a non-empty constructor_prod node.
            if node.data == 'constructor_prod' and node.children != [] and SyntaxConfig.is_name(node.children[0]):
                ident = node.children[0].children[0].value
                # Strip the surrounding double quotes of the serialized name.
                if ident.startswith('"') and ident.endswith('"'):
                    ident = ident[1:-1]
                ast.quantified_idents.add(ident)

        traverse_postorder(ast, get_quantified_idents)
        # Collected in a set to deduplicate; exposed as a list (order not meaningful).
        ast.quantified_idents = list(ast.quantified_idents)

        # Postprocess: compute height, remove some tokens, make identifiers explicit
        # Make everything nonterminal for compatibility
        def postprocess(node, is_constructor_desc):
            # is_constructor_desc is True when some strict ancestor of ``node``
            # is a constructor node (computed by get_is_construct_desc below).
            children = []
            node.height = 0
            # Below a constructor, raw tokens are kept when constructor names are wanted.
            # NOTE(review): presumably so get_constr_name can read them. They are
            # stripped when the enclosing constructor node is processed later
            # (post-order) — confirm.
            keep_constructor_desc = is_constructor_desc and syn_conf.include_constructor_names
            for c in node.children:
                if isinstance(c, Tree):
                    node.height = max(node.height, c.height + 1)
                    children.append(c)
                # Don't erase fully-qualified definition & theorem names
                elif (((syn_conf.include_defs or keep_constructor_desc) and SyntaxConfig.is_label(node)) or
                      ((syn_conf.include_paths or keep_constructor_desc) and SyntaxConfig.is_path(node)) or
                      (syn_conf.merge_vocab and syn_conf.include_locals and SyntaxConfig.is_local(node))):
                    ident_wrapper = SyntaxConfig.singleton_ident(c.value)
                    # NOTE(review): fixed height, presumably the depth of the wrapper
                    # built by singleton_ident. It also overwrites any height
                    # computed from earlier subtree children of this node.
                    node.height = 2
                    children.append(ident_wrapper)
                # Don't erase local variable names
                elif ((syn_conf.include_locals or keep_constructor_desc) and SyntaxConfig.is_local(node)):
                    var_value = SyntaxConfig.nonterminal_value(c.value)
                    node.height = 1
                    children.append(var_value)
                elif keep_constructor_desc:
                    children.append(c)
                # Any other token is dropped.
            # Recover constructor names
            if syn_conf.include_constructor_names and SyntaxConfig.is_constructor(node):
                # Queried before the path/label subtrees are cleared just below.
                # NOTE(review): raises AttributeError if use_serapi=False (self.serapi is None).
                constructor_name = self.serapi.get_constr_name(node)
                if not syn_conf.include_paths:
                    for path_node in node.find_data('constructor_dirpath'):
                        path_node.children = []
                if not syn_conf.include_defs:
                    # NOTE(review): raises StopIteration if there is no names__label__t subtree.
                    def_node = next(node.find_data('names__label__t'))
                    def_node.children = []
                if constructor_name:
                    children.append(SyntaxConfig.singleton_ident(constructor_name))
            if SyntaxConfig.is_constructor(node):
                # Remove Token children from every subtree yielded by iter_subtrees();
                # node.children is then overwritten with the locally built list.
                for parent in node.iter_subtrees():
                    parent.children = [c for c in parent.children if not isinstance(c, Token)]
            node.children = children

        def get_is_construct_desc(node, is_construct_desc):
            # Value inherited by node's children: True once ``node`` or any of its
            # ancestors is a constructor.
            return is_construct_desc or SyntaxConfig.is_constructor(node)

        # The root has no ancestors, hence the initial False.
        traverse_postorder(ast, postprocess, False, get_is_construct_desc)
        return ast

    def parse(self, term_str):
        """Return ``parse_no_cache(term_str)``, memoized by input string when caching is on.

        NOTE: with caching on, repeated calls with the same string return the
        same Tree object, so mutations made by one caller are visible to later
        callers. This class never evicts cache entries.
        """
        if self.caching:
            if term_str not in self.cache:
                self.cache[term_str] = self.parse_no_cache(term_str)
            return self.cache[term_str]
        else:
            return self.parse_no_cache(term_str)

    def print_grammar(self):
        """Print the generated lark grammar text to stdout."""
        print(self.grammar)
class Counter(Visitor):
    """lark Visitor that tallies rule names and terminal token strings.

    ``__default__`` populates the two tallies:
      * ``counts_nonterminal`` maps a subtree's ``data`` to its number of occurrences;
      * ``counts_terminal`` maps a direct Token child's ``value`` to its number
        of occurrences.
    """

    def __init__(self):
        super().__init__()
        # Unseen keys read as 0.
        self.counts_nonterminal = defaultdict(int)
        self.counts_terminal = defaultdict(int)

    def __default__(self, tree):
        # Count this subtree's rule name, then the values of its Token children.
        self.counts_nonterminal[tree.data] += 1
        direct_tokens = (child for child in tree.children if isinstance(child, Token))
        for token in direct_tokens:
            self.counts_terminal[token.value] += 1
class TreeHeight(Transformer):
    """lark Transformer that collapses each subtree into its integer height."""

    def __default__(self, symbol, children, meta):
        # A Token counts as height 0. Every other child is expected to be the
        # height already computed for that subtree.
        heights = [0 if isinstance(child, Token) else child for child in children]
        # The -1 sentinel gives a node with no children a height of 0.
        heights.append(-1)
        return 1 + max(heights)
class TreeNumTokens(Transformer):
    """lark Transformer that collapses each subtree into the number of Tokens it contains."""

    def __default__(self, symbol, children, meta):
        # Each Token contributes 1. Every other child is expected to be the
        # count already computed for that subtree.
        total = 0
        for child in children:
            total += 1 if isinstance(child, Token) else child
        return total