forked from facebookresearch/code-prediction-transformer
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgenerate_vocab.py
108 lines (89 loc) · 3.34 KB
/
generate_vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
import pickle
from collections import Counter
from utils import file_tqdm, get_dfs
# Emit INFO-level progress messages (file reading, coverage stats, output path).
logging.basicConfig(level=logging.INFO)
# Special tokens appended to every generated vocab: out-of-vocabulary marker
# and sequence-padding marker. Downstream models look these up by exact string.
UNK = "<unk_token>"
PAD = "<pad_token>"
def get_value(line, input_type):
    """Extract the tokens to count from one JSON-decoded example.

    Args:
        line: the parsed JSON object for one line of the input file
            (an AST for "ast"/"leaf"; for "source_code" only ``line[0]``
            is used — presumably the source text, TODO confirm format).
        input_type: one of "ast" (all AST nodes), "leaf" (leaf nodes
            only), or "source_code".

    Returns:
        An iterable of tokens suitable for ``Counter.update``.

    Raises:
        ValueError: if ``input_type`` is not recognized. (Previously this
            fell through and returned None, silently counting nothing.)
    """
    if input_type == "ast":
        return get_dfs(line)
    elif input_type == "leaf":
        return get_dfs(line, only_leaf=True)
    elif input_type == "source_code":
        return line[0]
    raise ValueError("Unknown input_type: {}".format(input_type))
def external(file_path, n_vocab, out_fp="output/vocab.pkl", input_type="ast"):
    """Build a vocab of the ``n_vocab`` most frequent tokens and pickle it.

    Reads a newline-delimited JSON file (one example per line), counts
    token frequencies, keeps the top ``n_vocab`` tokens plus UNK/PAD,
    and pickles the resulting list.

    Args:
        file_path: path to the newline-delimited JSON input file.
        n_vocab: number of most-frequent tokens to keep.
        out_fp: destination pickle path. Defaults to the previously
            hard-coded "output/vocab.pkl" for backward compatibility.
        input_type: "ast", "leaf", or "source_code" (see get_value).
            Defaults to the previously hard-coded "ast".
    """
    logging.info("Reading from: {}".format(file_path))
    vocab = Counter()
    with open(file_path, "r") as f:
        for line in file_tqdm(f):
            vocab.update(get_value(json.loads(line.strip()), input_type))
    # most_common() re-sorts the whole counter each call — compute it once.
    most_common = vocab.most_common(n_vocab)
    vocab_to_keep = [token for token, _ in most_common]
    top_total = sum(count for _, count in most_common)
    total = sum(vocab.values())
    logging.info("Total # of vocab: {}".format(len(vocab)))
    # Guard against an empty input file (total == 0 would divide by zero).
    coverage = 100 * top_total / total if total else 0.0
    logging.info(
        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
            n_vocab, coverage
        )
    )
    logging.info("Top 10 most common vocab:")
    for v, i in vocab.most_common(10):
        print(v, i)
    # add unk and pad tokens
    vocab_to_keep.append(UNK)
    vocab_to_keep.append(PAD)
    logging.info("Added {} and {}".format(UNK, PAD))
    # dump vocab to file
    with open(out_fp, "wb") as fout:
        pickle.dump(vocab_to_keep, fout)
    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), out_fp))
def main():
    """CLI entry point: count token frequencies in a newline-delimited JSON
    file and pickle the top ``--n_vocab`` tokens (plus UNK/PAD) to ``--out_fp``.
    """
    parser = argparse.ArgumentParser(description="Create vocab for py150 dataset")
    parser.add_argument("--n_vocab", "-n", type=int, default=100000)
    # Required: without it, open(None) crashed with a confusing TypeError.
    parser.add_argument("--input_fp", "-i", required=True)
    parser.add_argument("--out_fp", "-o", default="/tmp/vocab.pkl")
    parser.add_argument(
        "--input_type",
        "-t",
        choices=["ast", "leaf", "source_code"],
        # Default added: previously omitting this flag made get_value return
        # None, so nothing was counted and the script crashed dividing by zero.
        default="ast",
        help="Where to get the input from (all AST nodes, leaf nodes, or source code)",
    )
    args = parser.parse_args()
    logging.info("Reading from: {}".format(args.input_fp))
    logging.info("Input type: {}".format(args.input_type))
    vocab = Counter()
    with open(args.input_fp, "r") as f:
        for line in file_tqdm(f):
            vocab.update(get_value(json.loads(line.strip()), args.input_type))
    # most_common() re-sorts the whole counter each call — compute it once.
    most_common = vocab.most_common(args.n_vocab)
    vocab_to_keep = [token for token, _ in most_common]
    top_total = sum(count for _, count in most_common)
    total = sum(vocab.values())
    logging.info("Total # of vocab: {}".format(len(vocab)))
    # Guard against an empty input file (total == 0 would divide by zero).
    coverage = 100 * top_total / total if total else 0.0
    logging.info(
        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
            args.n_vocab, coverage
        )
    )
    logging.info("Top 10 most common vocab:")
    for v, i in vocab.most_common(10):
        print(v, i)
    # add unk and pad tokens
    vocab_to_keep.append(UNK)
    vocab_to_keep.append(PAD)
    logging.info("Added {} and {}".format(UNK, PAD))
    # dump vocab to file
    with open(args.out_fp, "wb") as fout:
        pickle.dump(vocab_to_keep, fout)
    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), args.out_fp))
if __name__ == "__main__":
    main()