mirror of
https://github.com/vale981/ray
synced 2025-03-09 04:46:38 -04:00
parent
407302f93a
commit
e8fce9f1f3
3 changed files with 58 additions and 6 deletions
|
@ -8,7 +8,7 @@ import sys
|
||||||
import ray.util.client.server.server as ray_client_server
|
import ray.util.client.server.server as ray_client_server
|
||||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
|
|
||||||
from ray.util.client import RayAPIStub
|
from ray.util.client import RayAPIStub, CURRENT_PROTOCOL_VERSION
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
|
@ -109,6 +109,45 @@ def test_python_version():
|
||||||
python_version="2.7.12",
|
python_version="2.7.12",
|
||||||
ray_version="",
|
ray_version="",
|
||||||
ray_commit="",
|
ray_commit="",
|
||||||
|
protocol_version=CURRENT_PROTOCOL_VERSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
# inject mock connection function
|
||||||
|
server_handle.data_servicer._build_connection_response = \
|
||||||
|
mock_connection_response
|
||||||
|
|
||||||
|
ray = RayAPIStub()
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
_ = ray.connect("localhost:50051")
|
||||||
|
|
||||||
|
ray = RayAPIStub()
|
||||||
|
info3 = ray.connect("localhost:50051", ignore_version=True)
|
||||||
|
assert info3["num_clients"] == 1, info3
|
||||||
|
ray.disconnect()
|
||||||
|
finally:
|
||||||
|
ray_client_server.shutdown_with_server(server_handle.grpc_server)
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_protocol_version():
|
||||||
|
|
||||||
|
server_handle, _ = ray_client_server.init_and_serve("localhost:50051")
|
||||||
|
try:
|
||||||
|
ray = RayAPIStub()
|
||||||
|
info1 = ray.connect("localhost:50051")
|
||||||
|
local_py_version = ".".join(
|
||||||
|
[str(x) for x in list(sys.version_info)[:3]])
|
||||||
|
assert info1["protocol_version"] == CURRENT_PROTOCOL_VERSION, info1
|
||||||
|
ray.disconnect()
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def mock_connection_response():
|
||||||
|
return ray_client_pb2.ConnectionInfoResponse(
|
||||||
|
num_clients=1,
|
||||||
|
python_version=local_py_version,
|
||||||
|
ray_version="",
|
||||||
|
ray_commit="",
|
||||||
|
protocol_version="2050-01-01", # from the future
|
||||||
)
|
)
|
||||||
|
|
||||||
# inject mock connection function
|
# inject mock connection function
|
||||||
|
|
|
@ -5,6 +5,10 @@ import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# This version string is incremented to indicate breaking changes in the
|
||||||
|
# protocol that require upgrading the client version.
|
||||||
|
CURRENT_PROTOCOL_VERSION = "2020-02-01"
|
||||||
|
|
||||||
|
|
||||||
class RayAPIStub:
|
class RayAPIStub:
|
||||||
"""This class stands in as the replacement API for the `import ray` module.
|
"""This class stands in as the replacement API for the `import ray` module.
|
||||||
|
@ -35,6 +39,9 @@ class RayAPIStub:
|
||||||
conn_str: Connection string, in the form "[host]:port"
|
conn_str: Connection string, in the form "[host]:port"
|
||||||
secure: Whether to use a TLS secured gRPC channel
|
secure: Whether to use a TLS secured gRPC channel
|
||||||
metadata: gRPC metadata to send on connect
|
metadata: gRPC metadata to send on connect
|
||||||
|
connection_retries: number of connection attempts to make
|
||||||
|
ignore_version: whether to ignore Python or Ray version mismatches.
|
||||||
|
This should only be used for debugging purposes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of connection info, e.g., {"num_clients": 1}.
|
Dictionary of connection info, e.g., {"num_clients": 1}.
|
||||||
|
@ -66,7 +73,8 @@ class RayAPIStub:
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _check_versions(self, conn_info, ignore_version: bool) -> None:
|
def _check_versions(self, conn_info: Dict[str, Any],
|
||||||
|
ignore_version: bool) -> None:
|
||||||
local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}"
|
local_major_minor = f"{sys.version_info[0]}.{sys.version_info[1]}"
|
||||||
if not conn_info["python_version"].startswith(local_major_minor):
|
if not conn_info["python_version"].startswith(local_major_minor):
|
||||||
version_str = f"{local_major_minor}.{sys.version_info[2]}"
|
version_str = f"{local_major_minor}.{sys.version_info[2]}"
|
||||||
|
@ -77,6 +85,14 @@ class RayAPIStub:
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(msg)
|
raise RuntimeError(msg)
|
||||||
|
if CURRENT_PROTOCOL_VERSION < conn_info["protocol_version"]:
|
||||||
|
msg = "Client Ray installation out of date:" + \
|
||||||
|
f" client is {CURRENT_PROTOCOL_VERSION}," + \
|
||||||
|
f" server is {conn_info['protocol_version']}"
|
||||||
|
if ignore_version:
|
||||||
|
logger.warning(msg)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Disconnect the Ray Client.
|
"""Disconnect the Ray Client.
|
||||||
|
|
|
@ -8,16 +8,13 @@ from threading import Lock
|
||||||
|
|
||||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
|
from ray.util.client import CURRENT_PROTOCOL_VERSION
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.util.client.server.server import RayletServicer
|
from ray.util.client.server.server import RayletServicer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# This version string is incremented to indicate breaking changes in the
|
|
||||||
# protocol that require upgrading the client version.
|
|
||||||
CURRENT_PROTOCOL_VERSION = "2020-02-01"
|
|
||||||
|
|
||||||
|
|
||||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||||
def __init__(self, basic_service: "RayletServicer"):
|
def __init__(self, basic_service: "RayletServicer"):
|
||||||
|
|
Loading…
Add table
Reference in a new issue