diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index d9d38935..472e7434 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -26,8 +26,9 @@ RUN uv pip install --no-build-isolation --no-cache --system "git+https://github. # b/404590350: Ray and torchtune have conflicting cli named `tune`. `ray` is not part of Colab's base image. Re-install `tune` to ensure the torchtune CLI is available by default. # b/468367647: Unpin protobuf, version greater than v5.29.5 causes issues with numerous packages +# grpcio-tools must be installed here (not in kaggle_requirements.txt) to stay version-compatible with protobuf. RUN uv pip install --system --force-reinstall --no-cache --no-deps torchtune -RUN uv pip install --system --force-reinstall --no-cache "protobuf==5.29.5" +RUN uv pip install --system --force-reinstall --no-cache "protobuf==5.29.5" "grpcio-tools>=1.60.0" # Adding non-package dependencies: ADD clean-layer.sh /tmp/clean-layer.sh diff --git a/tests/test_grpcio_tools.py b/tests/test_grpcio_tools.py new file mode 100644 index 00000000..cc0c6418 --- /dev/null +++ b/tests/test_grpcio_tools.py @@ -0,0 +1,60 @@ +import os +import subprocess +import sys +import tempfile +import unittest + +PROTO_CONTENT = """\ +syntax = "proto3"; + +package smoketest; + +message PingRequest { + string message = 1; +} + +message PingReply { + string message = 1; +} + +service PingService { + rpc Ping (PingRequest) returns (PingReply); +} +""" + + +class TestGrpcioTools(unittest.TestCase): + def test_compile_proto(self): + with tempfile.TemporaryDirectory() as tmpdir: + proto_path = os.path.join(tmpdir, "ping.proto") + with open(proto_path, "w") as f: + f.write(PROTO_CONTENT) + + subprocess.check_call( + [ + sys.executable, + "-m", + "grpc_tools.protoc", + f"--proto_path={tmpdir}", + f"--python_out={tmpdir}", + f"--grpc_python_out={tmpdir}", + f"--pyi_out={tmpdir}", + "ping.proto", + ] + ) + + pb2_path = os.path.join(tmpdir, "ping_pb2.py") + pb2_grpc_path = os.path.join(tmpdir, "ping_pb2_grpc.py") + pyi_path = os.path.join(tmpdir, "ping_pb2.pyi") + self.assertTrue(os.path.exists(pb2_path)) + self.assertTrue(os.path.exists(pb2_grpc_path)) + self.assertTrue(os.path.exists(pyi_path)) + + sys.path.insert(0, tmpdir) + import ping_pb2 + import ping_pb2_grpc + + req = ping_pb2.PingRequest(message="hello") + self.assertEqual(req.message, "hello") + self.assertTrue(hasattr(ping_pb2_grpc, "PingServiceStub")) + self.assertTrue(hasattr(ping_pb2_grpc, "PingServiceServicer"))