mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
parent
407302f93a
commit
e8fce9f1f3
3 changed files with 58 additions and 6 deletions
|
@ -8,7 +8,7 @@ import sys
|
|||
import ray.util.client.server.server as ray_client_server
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
from ray.util.client import RayAPIStub
|
||||
from ray.util.client import RayAPIStub, CURRENT_PROTOCOL_VERSION
|
||||
|
||||
import ray
|
||||
|
||||
|
@ -109,6 +109,45 @@ def test_python_version():
|
|||
python_version="2.7.12",
|
||||
ray_version="",
|
||||
ray_commit="",
|
||||
protocol_version=CURRENT_PROTOCOL_VERSION,
|
||||
)
|
||||
|
||||
# inject mock connection function
|
||||
server_handle.data_servicer._build_connection_response = \
|
||||
mock_connection_response
|
||||
|
||||
ray = RayAPIStub()
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = ray.connect("localhost:50051")
|
||||
|
||||
ray = RayAPIStub()
|
||||
info3 = ray.connect("localhost:50051", ignore_version=True)
|
||||
assert info3["num_clients"] == 1, info3
|
||||
ray.disconnect()
|
||||
finally:
|
||||
ray_client_server.shutdown_with_server(server_handle.grpc_server)
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def test_protocol_version():
|
||||
|
||||
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||
try:
|
||||
ray = RayAPIStub()
|
||||
info1 = ray.connect("localhost:50051")
|
||||
local_py_version = ".".join(
|
||||
[str(x) for x in list(sys.version_info)[:3]])
|
||||
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
|
||||
ray.disconnect()
|
||||
time.sleep(1)
|
||||
|
||||
def mock_connection_response():
|
||||
return ray_client_pb2.ConnectionInfoResponse(
|
||||
num_clients=1,
|
||||
python_version=local_py_version,
|
||||
ray_version="",
|
||||
ray_commit="",
|
||||
protocol_version="2050-01-01", # from the future
|
||||
)
|
||||
|
||||
# inject mock connection function
|
||||
|
|
|
@ -5,6 +5,10 @@ import logging
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This version string is incremented to indicate breaking changes in the
|
||||
# protocol that require upgrading the client version.
|
||||
CURRENT_PROTOCOL_VERSION = "2020-02-01"
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
"""This class stands in as the replacement API for the `import ray` module.
|
||||
|
@ -35,6 +39,9 @@ class RayAPIStub:
|
|||
conn_str: Connection string, in the form "[host]:port"
|
||||
secure: Whether to use a TLS secured gRPC channel
|
||||
metadata: gRPC metadata to send on connect
|
||||
connection_retries: number of connection attempts to make
|
||||
ignore_version: whether to ignore Python or Ray version mismatches.
|
||||
This should only be used for debugging purposes.
|
||||
|
||||
Returns:
|
||||
Dictionary of connection info, e.g., {"num_clients": 1}.
|
||||
|
@ -66,7 +73,8 @@ class RayAPIStub:
|
|||
self.disconnect()
|
||||
raise
|
||||
|
||||
def _check_versions(self, conn_info, ignore_version: bool) -> None:
|
||||
def _check_versions(self, conn_info: Dict[str, Any],
|
||||
ignore_version: bool) -> None:
|
||||
local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}"
|
||||
if not conn_info["python_version"].startswith(local_major_minor):
|
||||
version_str = f"{local_major_minor}.{sys.version_info[2]}"
|
||||
|
@ -77,6 +85,14 @@ class RayAPIStub:
|
|||
logger.warning(msg)
|
||||
else:
|
||||
raise RuntimeError(msg)
|
||||
if CURRENT_PROTOCOL_VERSION < conn_info["protocol_version"]:
|
||||
msg = "Client Ray installation out of date:" + \
|
||||
f" client is {CURRENT_PROTOCOL_VERSION}," + \
|
||||
f" server is {conn_info['protocol_version']}"
|
||||
if ignore_version:
|
||||
logger.warning(msg)
|
||||
else:
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect the Ray Client.
|
||||
|
|
|
@ -8,16 +8,13 @@ from threading import Lock
|
|||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.util.client import CURRENT_PROTOCOL_VERSION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.client.server.server import RayletServicer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This version string is incremented to indicate breaking changes in the
|
||||
# protocol that require upgrading the client version.
|
||||
CURRENT_PROTOCOL_VERSION = "2020-02-01"
|
||||
|
||||
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
def __init__(self, basic_service: "RayletServicer"):
|
||||
|
|
Loading…
Add table
Reference in a new issue