-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcodet5.py
124 lines (101 loc) · 4.65 KB
/
codet5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from transformers import RobertaTokenizer, T5ForConditionalGeneration
import torch
import ast
import json
import re
class CodeBlockVisitor(ast.NodeVisitor):
    """Collect the source text of block-level constructs in a Python AST.

    Every function definition, ``for`` loop, ``while`` loop and ``if``
    statement (including the ``async`` forms of functions and ``for`` loops)
    is recorded in ``code_blocks`` in the order the nodes are visited, i.e.
    an outer construct comes before the constructs nested inside it.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, source_code):
        """Create a visitor for the tree parsed from *source_code*.

        Args:
            source_code: The exact text the visited tree was parsed from;
                it is sliced to recover each node's source segment.
        """
        # Original text, needed by ast.get_source_segment().
        self.source_code = source_code
        # Source segments of the visited constructs, in visiting order.
        self.code_blocks = []

    def _record(self, node):
        """Store *node*'s source segment, then descend into its children."""
        segment = ast.get_source_segment(self.source_code, node)
        # get_source_segment() returns None when the node carries no
        # location information; there is nothing useful to store then.
        if segment is not None:
            self.code_blocks.append(segment)
        # Keep walking so nested constructs are recorded as well.
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Record a ``def`` statement."""
        self._record(node)

    def visit_AsyncFunctionDef(self, node):
        """Record an ``async def`` statement."""
        self._record(node)

    def visit_For(self, node):
        """Record a ``for`` loop."""
        self._record(node)

    def visit_AsyncFor(self, node):
        """Record an ``async for`` loop."""
        self._record(node)

    def visit_While(self, node):
        """Record a ``while`` loop."""
        self._record(node)

    def visit_If(self, node):
        """Record an ``if`` statement."""
        self._record(node)
class CodeT5:
    """Natural-language summaries of source code using the CodeT5 model.

    Loads the ``Salesforce/codet5-base-multi-sum`` tokenizer and model once
    per instance and offers several granularities of summarisation: a single
    snippet, line by line, or per code block found by ``CodeBlockVisitor``.
    """

    # Hugging Face identifier shared by the tokenizer and the model.
    MODEL_NAME = 'Salesforce/codet5-base-multi-sum'

    # JSON list of language entries; the code below reads each entry's
    # 'name' and 'extensions' keys. Resolved relative to the working directory.
    LANGUAGES_PATH = 'static/json/coding_languages.json'

    # Matches a line whose first non-blank characters open a comment in one
    # of several languages. Compiled once and shared by all instances; still
    # reachable as ``self.comment_pattern``.
    comment_pattern = re.compile(r'''
        ^\s*(
            //|      # C, C++, Java, JavaScript, C#, Go, Swift, Kotlin, Dart, TypeScript
            \#|      # Python, Ruby, Perl, Shell Script
            %|       # MATLAB
            ;|       # Assembly
            --|      # SQL, Ada
            /\*|     # Start of C-style block comment
            \(\*|    # Start of Pascal-style block comment
            <!--|    # HTML, XML
            '        # Visual Basic, Haskell
        )
    ''', re.VERBOSE)

    def __init__(self, use_mps=False):
        """Select a device and load the tokenizer and model.

        Args:
            use_mps: When true and the MPS backend is available, run the
                model on MPS. Defaults to False, which keeps the model on
                the CPU (the previous hard-coded development override).
        """
        # Guarded lookup: torch builds without an MPS backend do not expose
        # torch.backends.mps at all.
        mps_backend = getattr(torch.backends, 'mps', None)
        if use_mps and mps_backend is not None and mps_backend.is_available():
            self.device = torch.device('mps')
            print('[INFO] Using MPS')
        else:
            self.device = torch.device('cpu')
            print('[INFO] Using CPU')

        self.tokenizer = RobertaTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = T5ForConditionalGeneration.from_pretrained(self.MODEL_NAME).to(self.device)

    def summarize_code(self, source_code):
        """Return the model's summary of *source_code*.

        Args:
            source_code: Code snippet to summarise, as a string.

        Returns:
            The decoded summary with special tokens removed and every
            ``' .'`` collapsed to ``'.'``. Generation is capped at
            ``max_length=100``.
        """
        input_ids = self.tokenizer(source_code, return_tensors='pt').input_ids.to(self.device)
        generated_ids = self.model.generate(input_ids, max_length=100)
        summary = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True).replace(' .', '.')
        return summary

    def extract_code_blocks(self, source_code):
        """Return the source segments collected by ``CodeBlockVisitor``.

        Args:
            source_code: Python source text.

        Returns:
            A list of source strings, in visiting order.

        Raises:
            SyntaxError: If *source_code* is not valid Python
                (propagated from ``ast.parse``).
        """
        tree = ast.parse(source_code)
        visitor = CodeBlockVisitor(source_code)
        visitor.visit(tree)
        return visitor.code_blocks

    def summarize_by_line(self, code):
        """Summarise every non-blank line of the string *code*.

        Each line and its summary are also printed.

        Returns:
            One ``'Summary: <text>\\n'`` entry per non-blank line, joined
            with newlines into a single string.
        """
        output = []
        for line in code.splitlines():
            if not line.strip():
                continue
            summary = self.summarize_code(line)
            print(f'Code: {line}\nSummary: {summary}\n')
            output.append(f'Summary: {summary}\n')
        return '\n'.join(output)

    def summarize_line(self, code):
        """Summarise *code* line by line, keeping one output entry per line.

        Args:
            code: An iterable of lines. A single string is also accepted
                and is split into lines first; iterating it directly would
                visit it character by character.

        Returns:
            A list with one entry per input line: ``'\\n'`` for blank lines
            and for lines recognised by ``is_comment``, otherwise the
            line's summary followed by a newline.
        """
        if isinstance(code, str):
            code = code.splitlines()

        output = []
        for line in code:
            # Blank lines and comments are not sent to the model; a bare
            # newline keeps the output aligned with the input lines.
            if not line.strip() or self.is_comment(line):
                output.append('\n')
                continue
            summary = self.summarize_code(line)
            print(f'Code: {line}\nSummary: {summary}\n')
            output.append(f'{summary}\n')
        return output

    def summarize_by_chunks(self, code):
        """Summarise each block returned by ``extract_code_blocks``.

        Each block and its summary are also printed.

        Args:
            code: Python source text.

        Returns:
            One ``'<summary>\\n'`` entry per block, joined with newlines
            into a single string.

        Raises:
            SyntaxError: If *code* is not valid Python.
        """
        output = []
        for block in self.extract_code_blocks(code):
            summary = self.summarize_code(block)
            print(f'Code: {block}\nSummary: {summary}\n')
            output.append(f'{summary}\n')
        return '\n'.join(output)

    def _detect_language(self, filename):
        """Return the lower-cased language name for *filename*'s extension.

        Returns ``'unknown'`` when the name has no extension or the
        extension is not listed in ``LANGUAGES_PATH``.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.LANGUAGES_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # NOTE(review): only the first extension of each entry is mapped, so
        # secondary extensions resolve to 'unknown' -- confirm this matches
        # the intended schema of the JSON file.
        ext_to_name = {
            entry['extensions'][0].strip('.'): entry['name']
            for entry in data
            if 'extensions' in entry and entry['extensions']
        }

        if '.' not in filename:
            return 'unknown'
        extension = filename.rsplit('.', 1)[-1]
        return ext_to_name.get(extension, 'Unknown').lower()

    def summarize(self, code, filename):
        """Summarise *code* line by line, reporting non-Python input.

        Args:
            code: Lines of the file to summarise (see ``summarize_line``).
            filename: Name of the file; its extension selects the language.

        Returns:
            The list produced by ``summarize_line``.
        """
        # Every language currently takes the line-by-line path; the
        # chunk-based path (summarize_by_chunks) is not wired in here.
        if self._detect_language(filename) != 'python':
            print('not python')
        return self.summarize_line(code)

    def is_comment(self, line):
        """Return True if *line* starts, after leading whitespace, with a comment marker."""
        return bool(self.comment_pattern.match(line.strip()))