Skip to content

Commit

Permalink
Fix the UnicodeDecodeError when reading code
Browse files Browse the repository at this point in the history
Signed-off-by: Xiaochao Dong (@damnever) <[email protected]>
  • Loading branch information
damnever committed Jun 30, 2023
1 parent 7b03956 commit 2563db8
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions pigar/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ def parse_imports(
return imported_modules, annotations


def parse_file_comment_annotations(fpath: str, code: str) -> List[Annotation]:
def parse_file_comment_annotations(fpath: str,
code: bytes) -> List[Annotation]:
"""Parse annotations in comments, the valid format is as follows:
import foo # pigar: required-packages=pkg-bar
import foo # pigar: required-distributions=pkg-bar # package name
import foo # pigar: required-imports=bar # top level import name
"""
annotations: List[Annotation] = []
try:
for token in tokenize.generate_tokens(io.StringIO(code).readline):
for token in tokenize.tokenize(io.BytesIO(code).readline):
if token.type != tokenize.COMMENT:
continue
lineno, offset = token.start
Expand Down Expand Up @@ -154,7 +155,7 @@ def parse_file_comment_annotations(fpath: str, code: str) -> List[Annotation]:
)


def _read_code(fpath: str) -> Optional[str]:
def _read_code(fpath: str) -> Optional[bytes]:
if fpath.endswith(".ipynb"):
nb = nbformat.read(fpath, as_version=4)
code = ""
Expand All @@ -166,17 +167,17 @@ def _read_code(fpath: str) -> Optional[str]:
if not (match and match.group(0) == line):
code += line
code += "\n"
return code
return code.encode(encoding="utf-8")
elif fpath.endswith(".py"):
with open(fpath, 'r') as f:
with open(fpath, 'rb') as f:
return f.read()
return None


def parse_file_imports(fpath: str,
content: str,
visit_doc_str: bool = False) -> List[Module]:
py_codes: Deque[Tuple[str, int]] = collections.deque([(content, 1)])
def parse_file_imports(
fpath: str, content: bytes, visit_doc_str: bool = False
) -> List[Module]:
py_codes: Deque[Tuple[bytes, int]] = collections.deque([(content, 1)])
parser = ImportsParser(
lambda code, lineno: py_codes.append((code, lineno)), # noqa
doc_str_enabled=visit_doc_str,
Expand All @@ -196,14 +197,14 @@ class ImportsParser(object):

def __init__(
self,
rawcode_callback: Optional[Callable[[str, int], None]] = None,
rawcode_callback: Optional[Callable[[bytes, int], None]] = None,
doc_str_enabled: bool = False,
):
self._modules: List[Module] = []
self._rawcode_callback = rawcode_callback
self._doc_str_enabled = doc_str_enabled

def parse(self, content: str, fpath: str, lineno: int):
def parse(self, content: bytes, fpath: str, lineno: int):
parsed = ast.parse(content)
self._fpath = fpath
self._mods = fpath[:-3].split("/")
Expand All @@ -215,7 +216,7 @@ def _add_module(self, name: str, try_: bool, lineno: int):
Module(name=name, try_=try_, file=self._fpath, lineno=lineno)
)

def _add_rawcode(self, code: str, lineno: int):
def _add_rawcode(self, code: bytes, lineno: int):
if self._rawcode_callback:
self._rawcode_callback(code, lineno)

Expand Down Expand Up @@ -360,7 +361,7 @@ def visit(self, node: ast.AST):
def _parse_docstring(
node: Union[ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef,
ast.Module]
) -> Optional[str]:
) -> Optional[bytes]:
"""Extract code from docstring."""
docstring = ast.get_docstring(node)
if docstring:
Expand All @@ -372,7 +373,8 @@ def _parse_docstring(
pass
else:
examples = dt.examples
return '\n'.join([example.source for example in examples])
return '\n'.join([example.source for example in examples]
).encode(encoding="utf-8")
return None

@property
Expand Down

0 comments on commit 2563db8

Please sign in to comment.