From 8447185311a2441c4143679bbb9b04b74aafc653 Mon Sep 17 00:00:00 2001 From: decitre <590094+decitre@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:59:41 +0100 Subject: [PATCH] Breaking change: compiler_path as parameter of compile() instead of ProtoCollection constructor --- pyproject.toml | 2 +- src/proto_topy.py | 75 +++++++++++++++++--------- tests/test_proto_topy.py | 114 +++++++++++++++++++-------------------- 3 files changed, 107 insertions(+), 84 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c0f29b..a964480 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ commands = [testenv:lint] deps = pre-commit -commands = pre-commit run --all-files --show-diff-on-failure +commands = pre-commit run --all-files [testenv:report] deps = diff --git a/src/proto_topy.py b/src/proto_topy.py index cd20b29..d620275 100644 --- a/src/proto_topy.py +++ b/src/proto_topy.py @@ -60,16 +60,26 @@ def __init__(self, file_path: Path, source: str): self.py = None self.py_source = None - def set_module(self, content: str, global_scope: dict = None): + def _set_module(self, content: str, global_scope: dict = None): self.py_source = content spec = importlib.util.spec_from_loader(self.name, loader=None) compiled_content = compile(content, self.name, "exec") self.py = importlib.util.module_from_spec(spec) exec(compiled_content, self.py.__dict__) - def compiled(self, compiler_path: Path) -> "ProtoModule": - collection = ProtoCollection(compiler_path, self) - collection.compile() + def compiled(self, compiler_path: Path = None) -> "ProtoModule": + """ + Returns the ProtoModule instance in a compiled state: + + - self.py_source contains the generated Python code + - self.py contains the loaded module + + :param compiler_path: The Path to the protoc compiler (optional) + :return: self + """ + + collection = ProtoCollection(self) + collection.compile(compiler_path=compiler_path) return self @@ -82,27 +92,25 @@ class CompilationFailed(Exception): class ProtoCollection: - compiler_path: Path + """ + Encapsulates a protobuf `FileDescriptorSet` associated to a list of `ProtoModule` instances. + + Important attributes: + - descriptor_set: a `FileDescriptorSet` instance, a compiled protobuf describing the messafe types in the collection + - descriptor_data: the serialized `FileDescriptorSet` instance. Suitable to a transmission over the wire + - messages: A dictionary of protobuf messages classes indexed by their proto names + """ + modules: Dict[Path, ProtoModule] descriptor_data: bytes descriptor_set: FileDescriptorSet messages: dict - def __init__(self, compiler_path: Path, *protos: ProtoModule): + def __init__(self, *protos: ProtoModule): self.modules = {} - self.compiler_path = compiler_path self.descriptor_data = None self.descriptor_set = None self.messages = {} - self.pool = descriptor_pool.DescriptorPool() - - if not self.compiler_path: - if "PROTOC" in os.environ and os.path.exists(os.environ["PROTOC"]): - self.compiler_path = Path(os.environ["PROTOC"]) - else: - self.compiler_path or Path(which("protoc")) - if not self.compiler_path.is_file(): - raise FileNotFoundError() for proto in protos or []: self.add_proto(proto) @@ -112,7 +120,22 @@ def add_proto(self, proto: ProtoModule): raise KeyError(f"{proto.file_path} already added") self.modules[proto.file_path] = proto - def compile(self, global_scope: dict = None) -> "ProtoCollection": + @staticmethod + def _get_compiler_path() -> Path: + if "PROTOC" in os.environ and os.path.exists(os.environ["PROTOC"]): + compiler_path = Path(os.environ["PROTOC"]) + else: + compiler_path = Path(which("protoc")) + if not compiler_path.is_file(): + raise FileNotFoundError("protoc compiler not found") + return compiler_path + + def compile( + self, compiler_path: Path = None, global_scope: dict = None + ) -> "ProtoCollection": + if not compiler_path: + compiler_path = ProtoCollection._get_compiler_path() + with TemporaryDirectory() as dir: protos_target_paths = { Path(dir, proto.file_path): proto for proto in self.modules.values() @@ -124,7 +147,7 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection": compile_to_py_options = [f"--proto_path={dir}", f"--python_out={dir}"] ProtoCollection._do_compile( - self.compiler_path, + compiler_path, compile_to_py_options, proto_source_files, raise_exception=True, @@ -137,7 +160,7 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection": f"--descriptor_set_out={artifact_fds_path}", ] ProtoCollection._do_compile( - self.compiler_path, + compiler_path, compile_to_py_options, proto_source_files, raise_exception=False, @@ -145,10 +168,12 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection": with open(str(artifact_fds_path), mode="rb") as f: self.descriptor_data = f.read() self.descriptor_set = FileDescriptorSet.FromString(self.descriptor_data) + + pool = descriptor_pool.DescriptorPool() for file_descriptor_proto in self.descriptor_set.file: - self.pool.Add(file_descriptor_proto) + pool.Add(file_descriptor_proto) self.messages = GetMessageClassesForFiles( - [fdp.name for fdp in self.descriptor_set.file], self.pool + [fdp.name for fdp in self.descriptor_set.file], pool ) self._add_init_files(dir) @@ -158,13 +183,15 @@ def compile(self, global_scope: dict = None) -> "ProtoCollection": with open( Path(dir, proto.package_path, f"{proto.name}_pb2.py") ) as module_path: - proto.set_module(module_path.read(), global_scope=global_scope) + proto._set_module(module_path.read(), global_scope=global_scope) sys.path.pop() return self - def version(self) -> str: + def compiler_version(self, compiler_path: Path = None) -> str: + if not compiler_path: + compiler_path = ProtoCollection._get_compiler_path() outs = ProtoCollection._do_compile( - self.compiler_path, + compiler_path, ["--version"], [], raise_exception=True, diff --git a/tests/test_proto_topy.py b/tests/test_proto_topy.py index 13bfb3f..b470ebf 100644 --- a/tests/test_proto_topy.py +++ b/tests/test_proto_topy.py @@ -25,76 +25,72 @@ def address_book(): def test_compiler_version(): - version = ProtoCollection(compiler_path=protoc_path).version() + version = ProtoCollection().compiler_version(compiler_path=protoc_path) assert version is not None and tuple(map(int, version.split("."))) > (3, 0, 0) -def unlink_proto(path_str: str) -> Path: - proto = Path(path_str) - if proto.exists(): - proto.unlink() - return proto +def unlink_proto_file(path_str: str) -> Path: + proto_path = Path(path_str) + if proto_path.exists(): + proto_path.unlink() + return proto_path def test_add_proto(): - test1_proto = unlink_proto("test1.proto") + test1_proto = unlink_proto_file("test1.proto") proto = ProtoModule(file_path=test1_proto, source="") - modules = ProtoCollection(compiler_path=protoc_path) + modules = ProtoCollection() modules.add_proto(proto) assert test1_proto in modules.modules - unlink_proto("test1.proto") + unlink_proto_file("test1.proto") def test_add_proto2(): - test2_proto = unlink_proto("test2.proto") - test3_proto = unlink_proto("test3.proto") + test2_proto = unlink_proto_file("test2.proto") + test3_proto = unlink_proto_file("test3.proto") modules = ProtoCollection( - protoc_path, *( ProtoModule(file_path=test2_proto, source=""), ProtoModule(file_path=test3_proto, source=""), - ), + ) ) assert test2_proto in modules.modules assert test3_proto in modules.modules - unlink_proto("test2.proto") - unlink_proto("test3.proto") + unlink_proto_file("test2.proto") + unlink_proto_file("test3.proto") def test_bad_protoc(): - dummy = unlink_proto("dummy") + dummy = unlink_proto_file("dummy") with pytest.raises(FileNotFoundError): - ProtoCollection(dummy).compile() - unlink_proto("dummy") - - -def test_no_protoc(): - with pytest.raises(TypeError): - ProtoCollection() + ProtoCollection().compile(compiler_path=dummy) + unlink_proto_file("dummy") def test_compile_invalid_source(): - test4_proto = unlink_proto("test4.proto") + test4_proto = unlink_proto_file("test4.proto") with pytest.raises(CompilationFailed): - ProtoModule(file_path=test4_proto, source="foo").compiled(protoc_path) - unlink_proto("test4.proto") + ProtoModule(file_path=test4_proto, source="foo").compiled( + compiler_path=protoc_path + ) + unlink_proto_file("test4.proto") def test_compile_redundant_proto(): - testr_proto = unlink_proto("testr.proto") + testr_proto = unlink_proto_file("testr.proto") proto_source = 'syntax = "proto3"; message TestR { int32 foo = 1; };' proto1 = ProtoModule(file_path=testr_proto, source=proto_source) proto2 = ProtoModule(file_path=testr_proto, source=proto_source) with pytest.raises(KeyError, match=r"testr.proto already added"): - ProtoCollection(protoc_path, proto1, proto2).compile() - unlink_proto("testr.proto") + ProtoCollection(proto1, proto2).compile(compiler_path=protoc_path) + unlink_proto_file("testr.proto") def test_compile_minimal_proto(): from google.protobuf.timestamp_pb2 import Timestamp - test5_proto = unlink_proto("test5.proto") + test5_proto = unlink_proto_file("test5.proto") proto = ProtoModule( file_path=test5_proto, source=""" @@ -109,13 +105,13 @@ def test_compile_minimal_proto(): atest5 = proto.py.Test5() assert isinstance(atest5.created, Timestamp) del sys.modules["test5"] - unlink_proto("test5.proto") + unlink_proto_file("test5.proto") def test_compile_minimal_proto_in_a_package(): from google.protobuf.timestamp_pb2 import Timestamp - thing_proto = unlink_proto("p1/p2/p3/thing.proto") + thing_proto = unlink_proto_file("p1/p2/p3/thing.proto") proto = ProtoModule( file_path=thing_proto, source=""" @@ -134,21 +130,21 @@ def test_compile_minimal_proto_in_a_package(): sys.modules["thing"] = proto.py athing = proto.py.Thing() assert isinstance(athing.created, Timestamp) - unlink_proto("p1/p2/p3/thing.proto") + unlink_proto_file("p1/p2/p3/thing.proto") def test_compile_missing_dependency(): - test_proto = unlink_proto("test.proto") + test_proto = unlink_proto_file("test.proto") with pytest.raises(CompilationFailed, match=r"other.proto: File not found.*"): ProtoModule( file_path=test_proto, source='syntax = "proto3"; import "other.proto";', ).compiled(protoc_path) - unlink_proto("test.proto") + unlink_proto_file("test.proto") def test_compile_ununsed_dependency(): - test_proto = unlink_proto("test.proto") + test_proto = unlink_proto_file("test.proto") proto_module = ProtoModule( file_path=test_proto, source=""" @@ -157,7 +153,7 @@ def test_compile_ununsed_dependency(): """, ) - other_proto = unlink_proto("other.proto") + other_proto = unlink_proto_file("other.proto") other_proto_module = ProtoModule( file_path=other_proto, source=""" @@ -168,19 +164,19 @@ def test_compile_ununsed_dependency(): } """, ) - modules = ProtoCollection(protoc_path, proto_module, other_proto_module) + modules = ProtoCollection(proto_module, other_proto_module) try: - modules.compile() + modules.compile(compiler_path=protoc_path) except CompilationFailed: pytest.fail("Unexpected CompilationFailed ..") - unlink_proto("test.proto") - unlink_proto("other.proto") + unlink_proto_file("test.proto") + unlink_proto_file("other.proto") def test_compile_simple_dependency(): from google.protobuf.timestamp_pb2 import Timestamp - test_proto = unlink_proto("p3/p4/test6.proto") + test_proto = unlink_proto_file("p3/p4/test6.proto") proto_module = ProtoModule( file_path=test_proto, source=""" @@ -192,7 +188,7 @@ def test_compile_simple_dependency(): """, ) - other_proto = unlink_proto("p1/p2/other2.proto") + other_proto = unlink_proto_file("p1/p2/other2.proto") other_proto_module = ProtoModule( file_path=other_proto, source=""" @@ -203,38 +199,38 @@ def test_compile_simple_dependency(): } """, ) - modules = ProtoCollection(protoc_path, proto_module, other_proto_module) - modules.compile() + modules = ProtoCollection(proto_module, other_proto_module) + modules.compile(compiler_path=protoc_path) sys.modules.update({proto.name: proto.py for proto in modules.modules.values()}) atest6 = modules.modules[test_proto].py.Test6() assert isinstance(atest6.foo.created, Timestamp) for proto_module in modules.modules.values(): del sys.modules[proto_module.name] - unlink_proto("p3/p4/test6.proto") - unlink_proto("p1/p2/other2.proto") + unlink_proto_file("p3/p4/test6.proto") + unlink_proto_file("p1/p2/other2.proto") def test_encode_message(): proto_source = 'syntax = "proto3"; message Test{n} {{ int32 foo = 1; }};' - test7_proto = unlink_proto("test7.proto") - test8_proto = unlink_proto("test8.proto") + test7_proto = unlink_proto_file("test7.proto") + test8_proto = unlink_proto_file("test8.proto") proto1 = ProtoModule(file_path=test7_proto, source=proto_source.format(n=7)) proto2 = ProtoModule(file_path=test8_proto, source=proto_source.format(n=8)) - ProtoCollection(protoc_path, proto1, proto2).compile() + ProtoCollection(proto1, proto2).compile(compiler_path=protoc_path) assert array("B", proto1.py.Test7(foo=124).SerializeToString()) == array( "B", [8, 124] ) assert array("B", proto2.py.Test8(foo=123).SerializeToString()) == array( "B", [8, 123] ) - unlink_proto("test7.proto") - unlink_proto("test8.proto") + unlink_proto_file("test7.proto") + unlink_proto_file("test8.proto") def test_decode_message(): - test9_proto = unlink_proto("test9.proto") + test9_proto = unlink_proto_file("test9.proto") proto = ProtoModule( file_path=test9_proto, source='syntax = "proto3"; message Test9 { int32 foo = 1; };', @@ -242,11 +238,11 @@ def test_decode_message(): aTest9 = proto.py.Test9() aTest9.ParseFromString(bytes(array("B", [8, 124]))) assert aTest9.foo == 124 - unlink_proto("test9.proto") + unlink_proto_file("test9.proto") def test_decode_messages_stream(): - test10_proto = unlink_proto("test10.proto") + test10_proto = unlink_proto_file("test10.proto") proto = ProtoModule( file_path=test10_proto, source='syntax = "proto3"; message Test10 { int32 foo = 1; };', @@ -256,11 +252,11 @@ def test_decode_messages_stream(): ) factory.stream.seek(0) assert [thing.foo for _, thing in factory.message_read(proto.py.Test10)] == [1, 12] - unlink_proto("test10.proto") + unlink_proto_file("test10.proto") def test_decode_messages_stream2(): - test11_proto = unlink_proto("test11.proto") + test11_proto = unlink_proto_file("test11.proto") proto = ProtoModule( file_path=test11_proto, source='syntax = "proto3"; message Test11 { int32 foo = 1; };', @@ -277,13 +273,13 @@ def test_decode_messages_stream2(): aTest11.ParseFromString(offset_data[1]) foos.append(aTest11.foo) assert foos == [1, 12] - unlink_proto("test11.proto") + unlink_proto_file("test11.proto") @pytest.mark.vcr def test_google_addressbook_example(address_book): - adressbook_proto = unlink_proto( + adressbook_proto = unlink_proto_file( "protocolbuffers/protobuf/blob/main/examples/addressbook.proto" ) proto = ProtoModule(