Check Ray client protocol version (#13886)

* wip

* wip

* fix tests
This commit is contained in:
Eric Liang 2021-02-03 16:44:09 -08:00 committed by GitHub
parent 407302f93a
commit e8fce9f1f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 6 deletions

View file

@ -8,7 +8,7 @@ import sys
import ray.util.client.server.server as ray_client_server
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.util.client import RayAPIStub
from ray.util.client import RayAPIStub, CURRENT_PROTOCOL_VERSION
import ray
@ -109,6 +109,45 @@ def test_python_version():
python_version="2.7.12",
ray_version="",
ray_commit="",
protocol_version=CURRENT_PROTOCOL_VERSION,
)
# inject mock connection function
server_handle.data_servicer._build_connection_response = \
mock_connection_response
ray = RayAPIStub()
with pytest.raises(RuntimeError):
_ = ray.connect("localhost:50051")
ray = RayAPIStub()
info3 = ray.connect("localhost:50051", ignore_version=True)
assert info3["num_clients"] == 1, info3
ray.disconnect()
finally:
ray_client_server.shutdown_with_server(server_handle.grpc_server)
time.sleep(2)
def test_protocol_version():
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
try:
ray = RayAPIStub()
info1 = ray.connect("localhost:50051")
local_py_version = ".".join(
[str(x) for x in list(sys.version_info)[:3]])
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
ray.disconnect()
time.sleep(1)
def mock_connection_response():
return ray_client_pb2.ConnectionInfoResponse(
num_clients=1,
python_version=local_py_version,
ray_version="",
ray_commit="",
protocol_version="2050-01-01", # from the future
)
# inject mock connection function

View file

@ -5,6 +5,10 @@ import logging
logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2020-02-01"
class RayAPIStub:
"""This class stands in as the replacement API for the `import ray` module.
@ -35,6 +39,9 @@ class RayAPIStub:
conn_str: Connection string, in the form "[host]:port"
secure: Whether to use a TLS secured gRPC channel
metadata: gRPC metadata to send on connect
connection_retries: number of connection attempts to make
ignore_version: whether to ignore Python or Ray version mismatches.
This should only be used for debugging purposes.
Returns:
Dictionary of connection info, e.g., {"num_clients": 1}.
@ -66,7 +73,8 @@ class RayAPIStub:
self.disconnect()
raise
def _check_versions(self, conn_info, ignore_version: bool) -> None:
def _check_versions(self, conn_info: Dict[str, Any],
ignore_version: bool) -> None:
local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}"
if not conn_info["python_version"].startswith(local_major_minor):
version_str = f"{local_major_minor}.{sys.version_info[2]}"
@ -77,6 +85,14 @@ class RayAPIStub:
logger.warning(msg)
else:
raise RuntimeError(msg)
if CURRENT_PROTOCOL_VERSION < conn_info["protocol_version"]:
msg = "Client Ray installation out of date:" + \
f" client is {CURRENT_PROTOCOL_VERSION}," + \
f" server is {conn_info['protocol_version']}"
if ignore_version:
logger.warning(msg)
else:
raise RuntimeError(msg)
def disconnect(self):
"""Disconnect the Ray Client.

View file

@ -8,16 +8,13 @@ from threading import Lock
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import CURRENT_PROTOCOL_VERSION
if TYPE_CHECKING:
from ray.util.client.server.server import RayletServicer
logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
CURRENT_PROTOCOL_VERSION = "2020-02-01"
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer"):