Skip to content

Commit

Permalink
Breaking change: compiler_path as parameter of compile() instead of P…
Browse files Browse the repository at this point in the history
…rotoCollection constructor
  • Loading branch information
decitre committed Feb 5, 2024
1 parent 2807ece commit 8447185
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 84 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ commands =
[testenv:lint]
deps = pre-commit
commands = pre-commit run --all-files --show-diff-on-failure
commands = pre-commit run --all-files
[testenv:report]
deps =
Expand Down
75 changes: 51 additions & 24 deletions src/proto_topy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,26 @@ def __init__(self, file_path: Path, source: str):
self.py = None
self.py_source = None

def set_module(self, content: str, global_scope: dict = None):
def _set_module(self, content: str, global_scope: dict = None):
self.py_source = content
spec = importlib.util.spec_from_loader(self.name, loader=None)
compiled_content = compile(content, self.name, "exec")
self.py = importlib.util.module_from_spec(spec)
exec(compiled_content, self.py.__dict__)

def compiled(self, compiler_path: Path) -> "ProtoModule":
collection = ProtoCollection(compiler_path, self)
collection.compile()
def compiled(self, compiler_path: Path = None) -> "ProtoModule":
"""
Returns the ProtoModule instance in a compiled state:
- self.py_source contains the generated Python code
- self.py contains the loaded module
:param compiler_path: The Path to the protoc compiler (optional)
:return: self
"""

collection = ProtoCollection(self)
collection.compile(compiler_path=compiler_path)
return self


Expand All @@ -82,27 +92,25 @@ class CompilationFailed(Exception):


class ProtoCollection:
compiler_path: Path
"""
Encapsulates a protobuf `FileDescriptorSet` associated to a list of `ProtoModule` instances.
Important attributes:
- descriptor_set: a `FileDescriptorSet` instance, a compiled protobuf describing the messafe types in the collection
- descriptor_data: the serialized `FileDescriptorSet` instance. Suitable to a transmission over the wire
- messages: A dictionary of protobuf messages classes indexed by their proto names
"""

modules: Dict[Path, ProtoModule]
descriptor_data: bytes
descriptor_set: FileDescriptorSet
messages: dict

def __init__(self, compiler_path: Path, *protos: ProtoModule):
def __init__(self, *protos: ProtoModule):
self.modules = {}
self.compiler_path = compiler_path
self.descriptor_data = None
self.descriptor_set = None
self.messages = {}
self.pool = descriptor_pool.DescriptorPool()

if not self.compiler_path:
if "PROTOC" in os.environ and os.path.exists(os.environ["PROTOC"]):
self.compiler_path = Path(os.environ["PROTOC"])
else:
self.compiler_path or Path(which("protoc"))
if not self.compiler_path.is_file():
raise FileNotFoundError()

for proto in protos or []:
self.add_proto(proto)
Expand All @@ -112,7 +120,22 @@ def add_proto(self, proto: ProtoModule):
raise KeyError(f"{proto.file_path} already added")
self.modules[proto.file_path] = proto

def compile(self, global_scope: dict = None) -> "ProtoCollection":
@staticmethod
def _get_compiler_path() -> Path:
if "PROTOC" in os.environ and os.path.exists(os.environ["PROTOC"]):
compiler_path = Path(os.environ["PROTOC"])
else:
compiler_path = Path(which("protoc"))
if not compiler_path.is_file():
raise FileNotFoundError("protoc compiler not found")
return compiler_path

def compile(
self, compiler_path: Path = None, global_scope: dict = None
) -> "ProtoCollection":
if not compiler_path:
compiler_path = ProtoCollection._get_compiler_path()

with TemporaryDirectory() as dir:
protos_target_paths = {
Path(dir, proto.file_path): proto for proto in self.modules.values()
Expand All @@ -124,7 +147,7 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection":

compile_to_py_options = [f"--proto_path={dir}", f"--python_out={dir}"]
ProtoCollection._do_compile(
self.compiler_path,
compiler_path,
compile_to_py_options,
proto_source_files,
raise_exception=True,
Expand All @@ -137,18 +160,20 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection":
f"--descriptor_set_out={artifact_fds_path}",
]
ProtoCollection._do_compile(
self.compiler_path,
compiler_path,
compile_to_py_options,
proto_source_files,
raise_exception=False,
)
with open(str(artifact_fds_path), mode="rb") as f:
self.descriptor_data = f.read()
self.descriptor_set = FileDescriptorSet.FromString(self.descriptor_data)

pool = descriptor_pool.DescriptorPool()
for file_descriptor_proto in self.descriptor_set.file:
self.pool.Add(file_descriptor_proto)
pool.Add(file_descriptor_proto)
self.messages = GetMessageClassesForFiles(
[fdp.name for fdp in self.descriptor_set.file], self.pool
[fdp.name for fdp in self.descriptor_set.file], pool
)

self._add_init_files(dir)
Expand All @@ -158,13 +183,15 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection":
with open(
Path(dir, proto.package_path, f"{proto.name}_pb2.py")
) as module_path:
proto.set_module(module_path.read(), global_scope=global_scope)
proto._set_module(module_path.read(), global_scope=global_scope)
sys.path.pop()
return self

def version(self) -> str:
def compiler_version(self, compiler_path: Path = None) -> str:
if not compiler_path:
compiler_path = ProtoCollection._get_compiler_path()
outs = ProtoCollection._do_compile(
self.compiler_path,
compiler_path,
["--version"],
[],
raise_exception=True,
Expand Down
Loading

0 comments on commit 8447185

Please sign in to comment.