diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 9528f1d20..6b6ce8a42 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ -8,7 +8,7 @@ import sys import ray.util.client.server.server as ray_client_server import ray.core.generated.ray_client_pb2 as ray_client_pb2 -from ray.util.client import RayAPIStub +from ray.util.client import RayAPIStub, CURRENT_PROTOCOL_VERSION import ray @@ -109,6 +109,45 @@ def test_python_version(): python_version="2.7.12", ray_version="", ray_commit="", + protocol_version=CURRENT_PROTOCOL_VERSION, + ) + + # inject mock connection function + server_handle.data_servicer._build_connection_response = \ + mock_connection_response + + ray = RayAPIStub() + with pytest.raises(RuntimeError): + _ = ray.connect("localhost:50051") + + ray = RayAPIStub() + info3 = ray.connect("localhost:50051", ignore_version=True) + assert info3["num_clients"] == 1, info3 + ray.disconnect() + finally: + ray_client_server.shutdown_with_server(server_handle.grpc_server) + time.sleep(2) + + +def test_protocol_version(): + + server_handle, _ = ray_client_server.init_and_serve("localhost:50051") + try: + ray = RayAPIStub() + info1 = ray.connect("localhost:50051") + local_py_version = ".".join( + [str(x) for x in list(sys.version_info)[:3]]) + assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1 + ray.disconnect() + time.sleep(1) + + def mock_connection_response(): + return ray_client_pb2.ConnectionInfoResponse( + num_clients=1, + python_version=local_py_version, + ray_version="", + ray_commit="", + protocol_version="2050-01-01", # from the future ) # inject mock connection function diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 9a2d14877..3fdcd4f88 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -5,6 +5,10 @@ import logging logger = logging.getLogger(__name__) +# This version string is incremented to indicate breaking changes in the +# protocol that require upgrading the client version. +CURRENT_PROTOCOL_VERSION = "2020-02-01" + class RayAPIStub: """This class stands in as the replacement API for the `import ray` module. @@ -35,6 +39,9 @@ class RayAPIStub: conn_str: Connection string, in the form "[host]:port" secure: Whether to use a TLS secured gRPC channel metadata: gRPC metadata to send on connect + connection_retries: number of connection attempts to make + ignore_version: whether to ignore Python or Ray version mismatches. + This should only be used for debugging purposes. Returns: Dictionary of connection info, e.g., {"num_clients": 1}. @@ -66,7 +73,8 @@ class RayAPIStub: self.disconnect() raise - def _check_versions(self, conn_info, ignore_version: bool) -> None: + def _check_versions(self, conn_info: Dict[str, Any], + ignore_version: bool) -> None: local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}" if not conn_info["python_version"].startswith(local_major_minor): version_str = f"{local_major_minor}.{sys.version_info[2]}" @@ -77,6 +85,14 @@ class RayAPIStub: logger.warning(msg) else: raise RuntimeError(msg) + if CURRENT_PROTOCOL_VERSION < conn_info["protocol_version"]: + msg = "Client Ray installation out of date:" + \ + f" client is {CURRENT_PROTOCOL_VERSION}," + \ + f" server is {conn_info['protocol_version']}" + if ignore_version: + logger.warning(msg) + else: + raise RuntimeError(msg) def disconnect(self): """Disconnect the Ray Client. diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 709147820..82ddc85c6 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -8,16 +8,13 @@ from threading import Lock import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray.util.client import CURRENT_PROTOCOL_VERSION if TYPE_CHECKING: from ray.util.client.server.server import RayletServicer logger = logging.getLogger(__name__) -# This version string is incremented to indicate breaking changes in the -# protocol that require upgrading the client version. -CURRENT_PROTOCOL_VERSION = "2020-02-01" - class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): def __init__(self, basic_service: "RayletServicer"):