[core] Integrate runtime_env with ray client (#14881)

* server side ready

* client side

* py

* fix

* up

* format

* add files

* add pyx

* up

* up

* up

* add keys

* format

* update

* format

* add unittests

* add files

* up

* up

* fix

* up

* fix thread issue

* format

* fix

* update proto

* Fix

* format

* fix

* more

* fix conflict

* fix

* fix order

* format

* add

* up

* compiling fix

* lint

* fix

* format

* fix some

* some fix

* fix comment

* test cases

* add test

* comments

* fix name

* format

* fix

* revert gcs-kv

* fix comments

* fix failure

* fix test

* format

* fix timeout

* fix

* fix

* fix

* format

* format

* fix flaky test

Co-authored-by: Yi Cheng <singye888@gmail.com>
Yi Cheng 2021-03-31 11:39:34 -07:00 committed by GitHub
parent 91cf272c2e
commit 4480132229
18 changed files with 340 additions and 146 deletions
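
The change is easiest to see end to end. Below is a minimal usage sketch modeled on the test driver in this diff; the address, port, and working_dir path are placeholders, not values from this commit.

```python
# Hypothetical sketch: assumes a Ray client server listening on
# localhost:10003 and a project directory containing test_module.
import ray
import ray.util
from ray.job_config import JobConfig

job_config = JobConfig(runtime_env={"working_dir": "/path/to/project"})

# After this commit, the job_config travels through the client handshake
# (InitRequest) and is applied by ray.init on the server side.
ray.util.connect("localhost:10003", job_config=job_config)

@ray.remote
def run_test():
    # test_module ships inside the uploaded working_dir.
    from test_module.test import one
    return one()

print(ray.get(run_test.remote()))
ray.util.disconnect()
```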

View file

@@ -6,6 +6,7 @@ from functools import wraps
RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__"
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
os.environ.update({"RAY_CLIENT_MODE": "0"})
_client_hook_enabled = True

View file

@@ -1,6 +1,5 @@
import hashlib
import logging
import inspect
from filelock import FileLock
from pathlib import Path
@@ -9,10 +8,10 @@ from ray.job_config import JobConfig
from enum import Enum
from ray.experimental.internal_kv import (_internal_kv_put, _internal_kv_get,
_internal_kv_exists)
_internal_kv_exists,
_internal_kv_initialized)
from typing import List, Tuple
from types import ModuleType
from urllib.parse import urlparse
import os
import sys
@@ -160,7 +159,7 @@ def _parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
# TODO(yic): Fix this later to handle big directories in a better way
def get_project_package_name(working_dir: str, modules: List[str]) -> str:
def get_project_package_name(working_dir: str, py_modules: List[str]) -> str:
"""Get the name of the package by working dir and modules.
This function will generate the name of the package by the working
@@ -180,7 +179,7 @@ def get_project_package_name(working_dir: str, modules: List[str]) -> str:
e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip
Args:
working_dir (str): The working directory.
modules (list[module]): The python module.
py_modules (list[str]): The paths of the python modules.
Returns:
Package name as a string.
@@ -191,14 +190,12 @@ def get_project_package_name(working_dir: str, modules: List[str]) -> str:
assert isinstance(working_dir, str)
assert Path(working_dir).exists()
hash_val = _xor_bytes(hash_val, _hash_modules(Path(working_dir)))
for module in modules or []:
assert inspect.ismodule(module)
hash_val = _xor_bytes(hash_val,
_hash_modules(Path(module.__file__).parent))
for py_module in py_modules or []:
hash_val = _xor_bytes(hash_val, _hash_modules(Path(py_module).parent))
return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None
def create_project_package(working_dir: str, modules: List[ModuleType],
def create_project_package(working_dir: str, py_modules: List[str],
output_path: str) -> None:
"""Create a pckage that will be used by workers.
@ -207,7 +204,8 @@ def create_project_package(working_dir: str, modules: List[ModuleType],
Args:
working_dir (str): The working directory.
modules (list[module]): The python modules to be included.
py_modules (list[str]): The list of paths of python modules to be
included.
output_path (str): The path of file to be created.
"""
pkg_file = Path(output_path)
@@ -216,20 +214,15 @@ def create_project_package(working_dir: str, modules: List[ModuleType],
# put all files in /path/working_dir into zip
working_path = Path(working_dir)
_zip_module(working_path, working_path, zip_handler)
for module in modules or []:
logger.info(module.__file__)
# we only take care of modules with path like this for now:
# /path/module_name/__init__.py
# module_path should be: /path/module_name
module_path = Path(module.__file__).parent
_zip_module(module_path, module_path.parent, zip_handler)
for py_module in py_modules or []:
_zip_module(Path(py_module), Path(py_module).parent, zip_handler)
def fetch_package(pkg_uri: str, pkg_file: Path) -> int:
def fetch_package(pkg_uri: str, pkg_file: Path = None) -> int:
"""Fetch a package from a given uri.
This function is used to fetch a package from the given uri to the local
filesystem.
filesystem. If the package already exists locally, it just returns 0.
Args:
pkg_uri (str): The uri of the package to download.
@@ -239,9 +232,15 @@ def fetch_package(pkg_uri: str, pkg_file: Path) -> int:
Returns:
The number of bytes downloaded.
"""
if pkg_file is None:
pkg_file = Path(_get_local_path(pkg_uri))
if pkg_file.exists():
return 0
(protocol, pkg_name) = _parse_uri(pkg_uri)
if protocol in (Protocol.GCS, Protocol.PIN_GCS):
code = _internal_kv_get(pkg_uri)
if code is None:
raise IOError("Fetch uri failed")
code = code or b""
pkg_file.write_bytes(code)
return len(code)
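
With the optional pkg_file argument and the early-exit above, fetching is now idempotent per node. A hedged sketch of the new contract; the URI is a made-up example in the gcs:// scheme that rewrite_working_dir_uri generates, and it must already exist in the internal KV store:

```python
# Illustrative only; fetch_package raises IOError if the URI is absent
# from the KV store.
from ray._private import runtime_env

pkg_uri = "gcs://_ray_pkg_0123456789abcdef.zip"

nbytes = runtime_env.fetch_package(pkg_uri)  # downloads to _get_local_path(pkg_uri)
assert runtime_env.fetch_package(pkg_uri) == 0  # file already exists locally
```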
@@ -284,6 +283,7 @@ def package_exists(pkg_uri: str) -> bool:
Return:
True for package existing and False for not.
"""
assert _internal_kv_initialized()
(protocol, pkg_name) = _parse_uri(pkg_uri)
if protocol in (Protocol.GCS, Protocol.PIN_GCS):
return _internal_kv_exists(pkg_uri)
@@ -303,11 +303,11 @@ def rewrite_working_dir_uri(job_config: JobConfig) -> None:
"""
# For now, we only support local directory and packages
working_dir = job_config.runtime_env.get("working_dir")
required_modules = job_config.runtime_env.get("local_modules")
py_modules = job_config.runtime_env.get("py_modules")
if (not job_config.runtime_env.get("working_dir_uri")) and (
working_dir or required_modules):
pkg_name = get_project_package_name(working_dir, required_modules)
if (not job_config.runtime_env.get("working_dir_uri")) and (working_dir
or py_modules):
pkg_name = get_project_package_name(working_dir, py_modules)
job_config.runtime_env[
"working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
@@ -323,18 +323,18 @@ def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
Args:
job_config (JobConfig): The job config of driver.
"""
assert _internal_kv_initialized()
pkg_uris = job_config.get_runtime_env_uris()
for pkg_uri in pkg_uris:
if not package_exists(pkg_uri):
file_path = _get_local_path(pkg_uri)
pkg_file = Path(file_path)
working_dir = job_config.runtime_env.get("working_dir")
required_modules = job_config.runtime_env.get("local_modules")
py_modules = job_config.runtime_env.get("py_modules")
logger.info(f"{pkg_uri} doesn't exist. Create new package with"
f" {working_dir} and {required_modules}")
f" {working_dir} and {py_modules}")
if not pkg_file.exists():
create_project_package(working_dir, required_modules,
file_path)
create_project_package(working_dir, py_modules, file_path)
# Push the data to remote storage
pkg_size = push_package(pkg_uri, pkg_file)
logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes")
@@ -349,12 +349,13 @@ def ensure_runtime_env_setup(pkg_uris: List[str]) -> None:
Args:
pkg_uris (list[str]): Package URIs of the working dir for the runtime env.
"""
assert _internal_kv_initialized()
for pkg_uri in pkg_uris:
pkg_file = Path(_get_local_path(pkg_uri))
# For each node, the package will only be downloaded once.
# Locking to avoid multiple processes downloading it concurrently.
lock = FileLock(str(pkg_file) + ".lock")
with lock:
with FileLock(str(pkg_file) + ".lock"):
# TODO(yic): checksum calculation is required
if pkg_file.exists():
logger.debug(

View file

@@ -69,14 +69,14 @@ def load_package(config_path: str) -> "_RuntimePackage":
# Autofill working directory by uploading to GCS storage.
if "working_dir" not in runtime_env:
pkg_name = runtime_support.get_project_package_name(
working_dir=base_dir, modules=[])
working_dir=base_dir, py_modules=[])
pkg_uri = runtime_support.Protocol.GCS.value + "://" + pkg_name
def do_register_package():
if not runtime_support.package_exists(pkg_uri):
tmp_path = os.path.join(_pkg_tmp(), "_tmp{}".format(pkg_name))
runtime_support.create_project_package(
working_dir=base_dir, modules=[], output_path=tmp_path)
working_dir=base_dir, py_modules=[], output_path=tmp_path)
# TODO(ekl) does this get garbage collected correctly with the
# current job id?
runtime_support.push_package(pkg_uri, tmp_path)

View file

@@ -15,6 +15,7 @@ class JobConfig:
`CLASSPATH` in Java and `PYTHONPATH` in Python.
runtime_env (dict): A runtime environment dictionary (see
``runtime_env.py`` for detailed documentation).
client_job (bool): Whether the job is submitted through the ray client.
"""
def __init__(self,
@@ -22,7 +23,8 @@ class JobConfig:
num_java_workers_per_process=1,
jvm_options=None,
code_search_path=None,
runtime_env=None):
runtime_env=None,
client_job=False):
if worker_env is None:
self.worker_env = dict()
else:
@@ -45,8 +47,15 @@ class JobConfig:
f"The type of code search path is incorrect: " \
f"{type(code_search_path)}"
self.runtime_env = runtime_env or dict()
self.client_job = client_job
def serialize(self):
"""Serialize the struct into protobuf string"""
job_config = self.get_proto_job_config()
return job_config.SerializeToString()
def get_proto_job_config(self):
"""Return the prototype structure of JobConfig"""
job_config = ray.gcs_utils.JobConfig()
for key in self.worker_env:
job_config.worker_env[key] = self.worker_env[key]
@@ -55,9 +64,10 @@ class JobConfig:
job_config.jvm_options.extend(self.jvm_options)
job_config.code_search_path.extend(self.code_search_path)
job_config.runtime_env.CopyFrom(self._get_proto_runtime())
return job_config.SerializeToString()
return job_config
def get_runtime_env_uris(self):
"""Get the uris of runtime environment"""
if self.runtime_env.get("working_dir_uri"):
return [self.runtime_env.get("working_dir_uri")]
return []
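
Taken together, a sketch of the refactored JobConfig surface with illustrative values (the URI is made up):

```python
from ray.job_config import JobConfig

cfg = JobConfig(
    runtime_env={"working_dir_uri": "gcs://_ray_pkg_abc123.zip"},
    client_job=True)

proto = cfg.get_proto_job_config()  # new helper: returns the protobuf message
blob = cfg.serialize()              # same bytes as proto.SerializeToString()
print(cfg.get_runtime_env_uris())   # ["gcs://_ray_pkg_abc123.zip"]
```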

View file

@@ -8,7 +8,7 @@ import sys
import time
import socket
import math
from typing import Dict
from contextlib import redirect_stdout, redirect_stderr
import yaml
@@ -178,11 +178,12 @@ def kill_process_by_name(name, SIGKILL=False):
p.terminate()
def run_string_as_driver(driver_script):
def run_string_as_driver(driver_script: str, env: Dict = None):
"""Run a driver as a separate process.
Args:
driver_script: A string to run as a Python script.
driver_script (str): A string to run as a Python script.
env (dict): The environment variables for the driver.
Returns:
The script's output.
@@ -191,7 +192,9 @@ def run_string_as_driver(driver_script):
[sys.executable, "-"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
stderr=subprocess.STDOUT,
env=env,
)
with proc:
output = proc.communicate(driver_script.encode("ascii"))[0]
if proc.returncode:
@@ -489,7 +492,7 @@ def new_scheduler_enabled():
def client_test_enabled() -> bool:
return os.environ.get("RAY_CLIENT_MODE") == "1"
return os.environ.get("RAY_CLIENT_MODE") is not None
def fetch_prometheus(prom_addresses):

View file

@@ -54,8 +54,8 @@ def init_and_serve_lazy():
cluster.add_node(num_cpus=1, num_gpus=0)
address = cluster.address
def connect():
ray.init(address=address)
def connect(job_config=None):
ray.init(address=address, job_config=job_config)
server_handle = ray_client_server.serve("localhost:50051", connect)
yield server_handle

View file

@@ -5,9 +5,9 @@ import unittest
import subprocess
import tempfile
from unittest import mock
from pathlib import Path
import ray
from ray.test_utils import run_string_as_driver
from ray._private.utils import get_conda_env_dir, get_conda_bin_executable
@@ -19,14 +19,19 @@ import logging
sys.path.insert(0, "{working_dir}")
import test_module
import ray
import ray.util
import os
job_config = ray.job_config.JobConfig(
runtime_env={runtime_env}
)
ray.init(address="{redis_address}",
job_config=job_config,
logging_level=logging.DEBUG)
if os.environ.get("USE_RAY_CLIENT"):
ray.util.connect("{address}", job_config=job_config)
else:
ray.init(address="{address}",
job_config=job_config,
logging_level=logging.DEBUG)
@ray.remote
def run_test():
@@ -40,15 +45,17 @@ class TestActor(object):
{execute_statement}
ray.shutdown()
if os.environ.get("USE_RAY_CLIENT"):
ray.util.disconnect()
else:
ray.shutdown()
from time import sleep
sleep(5)
sleep(10)
"""
@pytest.fixture(scope="session")
def working_dir():
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
path = Path(tmp_dir)
module_path = path / "test_module"
@@ -68,50 +75,59 @@ from test_module.test import one
yield tmp_dir
def start_client_server(cluster, client_mode):
from ray._private.runtime_env import PKG_DIR
if not client_mode:
return (cluster.address, None, PKG_DIR)
ray.worker._global_node._ray_params.ray_client_server_port = "10003"
ray.worker._global_node.start_ray_client_server()
return ("localhost:10003", {"USE_RAY_CLIENT": "1"}, PKG_DIR)
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_single_node(ray_start_cluster_head, working_dir):
@pytest.mark.parametrize("client_mode", [True, False])
def test_single_node(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
redis_address = cluster.address
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
script = driver_script.format(**locals())
out = run_string_as_driver(script)
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
from ray._private.runtime_env import PKG_DIR
assert len(list(Path(PKG_DIR).iterdir())) == 1
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_two_node(two_node_cluster, working_dir):
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node(two_node_cluster, working_dir, client_mode):
cluster, _ = two_node_cluster
redis_address = cluster.address
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
script = driver_script.format(**locals())
out = run_string_as_driver(script)
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
from ray._private.runtime_env import PKG_DIR
assert len(list(Path(PKG_DIR).iterdir())) == 1
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_two_node_module(two_node_cluster, working_dir):
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node_module(two_node_cluster, working_dir, client_mode):
cluster, _ = two_node_cluster
redis_address = cluster.address
runtime_env = """{ "local_modules": [test_module] }"""
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
runtime_env = """{ "py_modules": [test_module.__path__[0]] }"""
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
script = driver_script.format(**locals())
print(script)
out = run_string_as_driver(script)
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
from ray._private.runtime_env import PKG_DIR
assert len(list(Path(PKG_DIR).iterdir())) == 1
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_two_node_uri(two_node_cluster, working_dir):
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node_uri(two_node_cluster, working_dir, client_mode):
cluster, _ = two_node_cluster
redis_address = cluster.address
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
import ray._private.runtime_env as runtime_env
import tempfile
with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file:
@@ -122,41 +138,40 @@ def test_two_node_uri(two_node_cluster, working_dir):
runtime_env = f"""{{ "working_dir_uri": "{pkg_uri}" }}"""
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
script = driver_script.format(**locals())
out = run_string_as_driver(script)
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
from ray._private.runtime_env import PKG_DIR
assert len(list(Path(PKG_DIR).iterdir())) == 1
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_regular_actors(ray_start_cluster_head, working_dir):
@pytest.mark.parametrize("client_mode", [True, False])
def test_regular_actors(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
redis_address = cluster.address
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
execute_statement = """
test_actor = TestActor.options(name="test_actor").remote()
print(sum(ray.get([test_actor.one.remote()] * 1000)))
"""
script = driver_script.format(**locals())
out = run_string_as_driver(script)
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
from ray._private.runtime_env import PKG_DIR
assert len(list(Path(PKG_DIR).iterdir())) == 1
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
def test_detached_actors(ray_start_cluster_head, working_dir):
@pytest.mark.parametrize("client_mode", [True, False])
def test_detached_actors(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
redis_address = cluster.address
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
execute_statement = """
test_actor = TestActor.options(name="test_actor", lifetime="detached").remote()
print(sum(ray.get([test_actor.one.remote()] * 1000)))
"""
script = driver_script.format(**locals())
out = run_string_as_driver(script)
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
from ray._private.runtime_env import PKG_DIR
# It's a detached actor, so the package should still be there
assert len(list(Path(PKG_DIR).iterdir())) == 2
pkg = list(Path(PKG_DIR).glob("*.zip"))[0]

View file

@@ -1,5 +1,5 @@
from typing import List, Tuple, Dict, Any
from ray.job_config import JobConfig
import os
import sys
import logging
@@ -29,6 +29,7 @@ class RayAPIStub:
def connect(self,
conn_str: str,
job_config: JobConfig = None,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
connection_retries: int = 3,
@@ -38,6 +39,7 @@ class RayAPIStub:
Args:
conn_str: Connection string, in the form "[host]:port"
job_config: The job config to be used by the server.
secure: Whether to use a TLS secured gRPC channel
metadata: gRPC metadata to send on connect
connection_retries: number of connection attempts to make
@@ -67,6 +69,7 @@ class RayAPIStub:
metadata=metadata,
connection_retries=connection_retries)
self.api.worker = self.client_worker
self.client_worker._server_init(job_config)
conn_info = self.client_worker.connection_info()
self._check_versions(conn_info, ignore_version)
return conn_info

View file

@@ -247,6 +247,10 @@ class ClientAPI:
"""Hook for internal_kv._internal_kv_initialized."""
return self.is_initialized()
def _internal_kv_exists(self, key: bytes) -> bool:
"""Hook for internal_kv._internal_kv_exists."""
return self.worker.internal_kv_exists(as_bytes(key))
def _internal_kv_get(self, key: bytes) -> bytes:
"""Hook for internal_kv._internal_kv_get."""
return self.worker.internal_kv_get(as_bytes(key))

View file

@@ -117,6 +117,19 @@ class DataClient:
req.req_id = req_id
self.request_queue.put(req)
def Init(self, request: ray_client_pb2.InitRequest,
context=None) -> ray_client_pb2.InitResponse:
datareq = ray_client_pb2.DataRequest(init=request, )
resp = self._blocking_send(datareq)
return resp.init
def PrepRuntimeEnv(self,
request: ray_client_pb2.PrepRuntimeEnvRequest,
context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
datareq = ray_client_pb2.DataRequest(prep_runtime_env=request, )
resp = self._blocking_send(datareq)
return resp.prep_runtime_env
def ConnectionInfo(self,
context=None) -> ray_client_pb2.ConnectionInfoResponse:
datareq = ray_client_pb2.DataRequest(

View file

@@ -29,7 +29,7 @@ def ray_start_client_server_pair():
def ray_start_cluster_client_server_pair(address):
ray._inside_client_test = True
def ray_connect_handler():
def ray_connect_handler(job_config=None):
real_ray.init(address=address)
server = ray_client_server.serve(

View file

@@ -3,7 +3,7 @@ import logging
import grpc
import sys
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
from threading import Lock
import ray.core.generated.ray_client_pb2 as ray_client_pb2
@@ -20,12 +20,10 @@ logger = logging.getLogger(__name__)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, basic_service: "RayletServicer",
ray_connect_handler: Callable):
def __init__(self, basic_service: "RayletServicer"):
self.basic_service = basic_service
self.clients_lock = Lock()
self.num_clients = 0 # guarded by self.clients_lock
self.ray_connect_handler = ray_connect_handler
def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
@@ -36,55 +34,47 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
return
logger.debug(f"New data connection from client {client_id}: ")
try:
with self.clients_lock:
with disable_client_hook():
# It's important to keep the ray initialization call
# within this locked context or else Ray could hang.
if self.num_clients == 0 and not ray.is_initialized():
self.ray_connect_handler()
threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
if self.num_clients >= threshold:
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
logger.warning(
f"[Data Servicer]: Num clients {self.num_clients} "
f"has reached the threshold {threshold}. "
f"Rejecting client: {metadata['client_id']}. ")
if log_once("client_threshold"):
logger.warning(
"You can configure the client connection "
"threshold by setting the "
"RAY_CLIENT_SERVER_MAX_THREADS env var "
f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
return
self.num_clients += 1
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
accepted_connection = True
for req in request_iterator:
resp = None
req_type = req.WhichOneof("type")
if req_type == "get":
get_resp = self.basic_service._get_object(
req.get, client_id)
resp = ray_client_pb2.DataResponse(get=get_resp)
elif req_type == "put":
put_resp = self.basic_service._put_object(
req.put, client_id)
resp = ray_client_pb2.DataResponse(put=put_resp)
elif req_type == "release":
released = []
for rel_id in req.release.ids:
rel = self.basic_service.release(client_id, rel_id)
released.append(rel)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(ok=released))
elif req_type == "connection_info":
resp = ray_client_pb2.DataResponse(
connection_info=self._build_connection_response())
if req_type == "init":
resp = self._init(req.init, client_id)
if resp is None:
context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
return
logger.debug(f"Accepted data connection from {client_id}. "
f"Total clients: {self.num_clients}")
accepted_connection = True
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
assert accepted_connection
if req_type == "get":
get_resp = self.basic_service._get_object(
req.get, client_id)
resp = ray_client_pb2.DataResponse(get=get_resp)
elif req_type == "put":
put_resp = self.basic_service._put_object(
req.put, client_id)
resp = ray_client_pb2.DataResponse(put=put_resp)
elif req_type == "release":
released = []
for rel_id in req.release.ids:
rel = self.basic_service.release(client_id, rel_id)
released.append(rel)
resp = ray_client_pb2.DataResponse(
release=ray_client_pb2.ReleaseResponse(
ok=released))
elif req_type == "connection_info":
resp = ray_client_pb2.DataResponse(
connection_info=self._build_connection_response())
elif req_type == "prep_runtime_env":
with self.clients_lock:
resp_prep = self.basic_service.PrepRuntimeEnv(
req.prep_runtime_env)
resp = ray_client_pb2.DataResponse(
prep_runtime_env=resp_prep)
else:
raise Exception(f"Unreachable code: Request type "
f"{req_type} not handled in Datapath")
resp.req_id = req.req_id
yield resp
except grpc.RpcError as e:
@@ -106,6 +96,25 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
logger.debug("Shutting down ray.")
ray.shutdown()
def _init(self, req_init, client_id):
with self.clients_lock:
threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
if self.num_clients >= threshold:
logger.warning(
f"[Data Servicer]: Num clients {self.num_clients} "
f"has reached the threshold {threshold}. "
f"Rejecting client: {client_id}. ")
if log_once("client_threshold"):
logger.warning(
"You can configure the client connection "
"threshold by setting the "
"RAY_CLIENT_SERVER_MAX_THREADS env var "
f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
return None
resp_init = self.basic_service.Init(req_init)
self.num_clients += 1
return ray_client_pb2.DataResponse(init=resp_init, )
def _build_connection_response(self):
with self.clients_lock:
cur_num_clients = self.num_clients

View file

@@ -4,14 +4,16 @@ import grpc
import base64
from collections import defaultdict
from dataclasses import dataclass
import sys
import threading
from typing import Any
from typing import List
from typing import Dict
from typing import Set
from typing import Optional
from typing import Callable
from ray import cloudpickle
from ray.job_config import JobConfig
import ray
import ray.state
import ray.core.generated.ray_client_pb2 as ray_client_pb2
@@ -33,7 +35,12 @@ logger = logging.getLogger(__name__)
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self):
def __init__(self, ray_connect_handler: Callable):
"""Construct a raylet service
Args:
ray_connect_handler (Callable): Function to connect to ray cluster
"""
# Stores client_id -> (ref_id -> ObjectRef)
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
dict)
@@ -46,6 +53,42 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.registered_actor_classes = {}
self.named_actors = set()
self.state_lock = threading.Lock()
self.ray_connect_handler = ray_connect_handler
def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
import pickle
if request.job_config:
job_config = pickle.loads(request.job_config)
job_config.client_job = True
else:
job_config = None
current_job_config = None
with disable_client_hook():
if ray.is_initialized():
worker = ray.worker.global_worker
current_job_config = worker.core_worker.get_job_config()
else:
self.ray_connect_handler(job_config)
if job_config is None:
return ray_client_pb2.InitResponse()
job_config = job_config.get_proto_job_config()
# If the server has been initialized, we need to check that the
# runtime env is compatible.
if current_job_config and set(job_config.runtime_env.uris) != set(
current_job_config.runtime_env.uris):
raise grpc.RpcError(
"Runtime environment doesn't match "
f"request one {job_config.runtime_env.uris} "
f"current one {current_job_config.runtime_env.uris}")
return ray_client_pb2.InitResponse()
def PrepRuntimeEnv(self, request,
context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
job_config = ray.worker.global_worker.core_worker.get_job_config()
missing_uris = self._prepare_runtime_env(job_config.runtime_env)
if len(missing_uris) != 0:
raise grpc.RpcError(f"Missing uris: {missing_uris}")
return ray_client_pb2.PrepRuntimeEnvResponse()
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
with disable_client_hook():
@@ -69,6 +112,13 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
request.prefix)
return ray_client_pb2.KVListResponse(keys=keys)
def KVExists(self, request,
context=None) -> ray_client_pb2.KVExistsResponse:
with disable_client_hook():
exists = ray.experimental.internal_kv._internal_kv_exists(
request.key)
return ray_client_pb2.KVExistsResponse(exists=exists)
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
resp = ray_client_pb2.ClusterInfoResponse()
@@ -371,6 +421,21 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
kwargout[k] = convert_from_arg(kwarg_map[k], self)
return argout, kwargout
def _prepare_runtime_env(self, job_runtime_env) -> List[str]:
"""Download runtime environment to local node"""
missing_uris = []
uris = job_runtime_env.uris
from ray._private import runtime_env
with disable_client_hook():
for uri in uris:
try:
runtime_env.fetch_package(uri)
print("Adding!: ", runtime_env._get_local_path(uri))
sys.path.insert(0, str(runtime_env._get_local_path(uri)))
except IOError:
missing_uris.append(uri)
return missing_uris
def lookup_or_register_func(
self, id: bytes, client_id: str,
options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
@@ -453,10 +518,10 @@ class ClientServerHandle:
def serve(connection_str, ray_connect_handler=None):
def default_connect_handler():
def default_connect_handler(job_config: JobConfig = None):
with disable_client_hook():
if not ray.is_initialized():
return ray.init()
return ray.init(job_config=job_config)
ray_connect_handler = ray_connect_handler or default_connect_handler
server = grpc.server(
@@ -465,9 +530,8 @@ def serve(connection_str, ray_connect_handler=None):
("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
])
task_servicer = RayletServicer()
data_servicer = DataServicer(
task_servicer, ray_connect_handler=ray_connect_handler)
task_servicer = RayletServicer(ray_connect_handler)
data_servicer = DataServicer(task_servicer)
logs_servicer = LogstreamServicer()
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
@@ -491,13 +555,13 @@ def init_and_serve(connection_str, *args, **kwargs):
# Disable client mode inside the worker's environment
info = ray.init(*args, **kwargs)
def ray_connect_handler():
def ray_connect_handler(job_config=None):
# Ray client will disconnect from ray when
# num_clients == 0.
if ray.is_initialized():
return info
else:
return ray.init(*args, **kwargs)
return ray.init(job_config=job_config, *args, **kwargs)
server_handle = serve(
connection_str, ray_connect_handler=ray_connect_handler)
@@ -511,14 +575,17 @@ def shutdown_with_server(server, _exiting_interpreter=False):
def create_ray_handler(redis_address, redis_password):
def ray_connect_handler():
def ray_connect_handler(job_config: JobConfig = None):
if redis_address:
if redis_password:
ray.init(address=redis_address, _redis_password=redis_password)
ray.init(
address=redis_address,
_redis_password=redis_password,
job_config=job_config)
else:
ray.init(address=redis_address)
ray.init(address=redis_address, job_config=job_config)
else:
ray.init()
ray.init(job_config=job_config)
return ray_connect_handler
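
Every connect handler now accepts an optional job_config, so embedding a client server looks roughly like the conftest fixture earlier in this diff. A sketch; the import path and addresses are assumptions:

```python
import ray
from ray.util.client.server import server as ray_client_server

address = "127.0.0.1:6379"  # placeholder cluster address

def connect(job_config=None):
    # The job_config forwarded from the client's InitRequest lands here,
    # so runtime_env settings reach ray.init on the server side.
    ray.init(address=address, job_config=job_config)

server_handle = ray_client_server.serve("localhost:50051", connect)
```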

View file

@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
import grpc
from ray.job_config import JobConfig
import ray.cloudpickle as cloudpickle
# Use cloudpickle's version of pickle for UnpicklingError
from ray.cloudpickle.compat import pickle
@@ -145,6 +146,7 @@ class Worker:
self.log_client = LogstreamClient(self.channel, self.metadata)
self.log_client.set_logstream_level(logging.INFO)
self.closed = False
def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
@@ -394,6 +396,11 @@ class Worker:
resp = self.server.KVGet(req, metadata=self.metadata)
return resp.value
def internal_kv_exists(self, key: bytes) -> bool:
req = ray_client_pb2.KVExistsRequest(key=key)
resp = self.server.KVExists(req, metadata=self.metadata)
return resp.exists
def internal_kv_put(self, key: bytes, value: bytes,
overwrite: bool) -> bool:
req = ray_client_pb2.KVPutRequest(
@@ -431,6 +438,30 @@ class Worker:
def is_connected(self) -> bool:
return self._conn_state == grpc.ChannelConnectivity.READY
def _server_init(self, job_config: JobConfig):
"""Initialize the server"""
try:
if job_config is None:
init_req = ray_client_pb2.InitRequest()
self.data_client.Init(init_req)
return
import ray._private.runtime_env as runtime_env
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
if runtime_env.PKG_DIR is None:
runtime_env.PKG_DIR = tmp_dir
# Generate the uri for runtime env
runtime_env.rewrite_working_dir_uri(job_config)
init_req = ray_client_pb2.InitRequest(
job_config=pickle.dumps(job_config))
self.data_client.Init(init_req)
runtime_env.upload_runtime_env_package_if_needed(job_config)
prep_req = ray_client_pb2.PrepRuntimeEnvRequest()
self.data_client.PrepRuntimeEnv(prep_req)
except grpc.RpcError as e:
raise decode_exception(e.details())
def _convert_actor(self, actor: "ActorClass") -> str:
"""Register a ClientActorClass for the ActorClass and return a UUID"""
key = uuid.uuid4().hex

View file

@@ -1,5 +1,5 @@
from ray.util.client import ray
from ray.job_config import JobConfig
from ray._private.client_mode_hook import _enable_client_hook
from ray._private.client_mode_hook import _explicitly_enable_client_mode
@@ -10,6 +10,7 @@ def connect(conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
connection_retries: int = 3,
job_config: JobConfig = None,
*,
ignore_version: bool = False) -> Dict[str, Any]:
if ray.is_connected():
@@ -26,6 +27,7 @@ def connect(conn_str: str,
# the correct metadata
return ray.connect(
conn_str,
job_config=job_config,
secure=secure,
metadata=metadata,
connection_retries=3,

View file

@@ -1219,15 +1219,16 @@ def connect(node,
node.node_manager_port, node.raylet_ip_address, (mode == LOCAL_MODE),
driver_name, log_stdout_file_path, log_stderr_file_path,
serialized_job_config, node.metrics_agent_port)
# Notify raylet that the core worker is ready.
worker.core_worker.notify_raylet()
# Create an object for interfacing with the global state.
# Note, global state should be initialized after `CoreWorker`, because it
# will use glog, which is initialized in `CoreWorker`.
ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password)
if mode == SCRIPT_MODE:
# If it's a driver and it's not coming from the ray client, we'll prepare
# the environment here. If it's the ray client, the environment will be
# prepared on the server side.
if mode == SCRIPT_MODE and not job_config.client_job:
runtime_env.upload_runtime_env_package_if_needed(job_config)
elif mode == WORKER_MODE:
# TODO(ekl) get rid of the env var hack and get runtime env from the
@@ -1237,6 +1238,9 @@ def connect(node,
worker.core_worker.get_job_config().runtime_env.uris
runtime_env.ensure_runtime_env_setup(job_config)
# Notify raylet that the core worker is ready.
worker.core_worker.notify_raylet()
if driver_object_store_memory is not None:
worker.core_worker.set_object_store_client_options(
f"ray_driver_{os.getpid()}", driver_object_store_memory)

View file

@@ -162,11 +162,10 @@ cc_proto_library(
proto_library(
name = "ray_client_proto",
srcs = ["ray_client.proto"],
deps = [],
deps = [":common_proto"],
)
python_grpc_compile(
name = "ray_client_py_proto",
deps = [":ray_client_proto"]
deps = [":ray_client_proto"],
)

View file

@@ -196,6 +196,14 @@ message TerminateResponse {
bool ok = 1;
}
message KVExistsRequest {
bytes key = 1;
}
message KVExistsResponse {
bool exists = 1;
}
message KVGetRequest {
bytes key = 1;
}
@@ -229,7 +237,25 @@ message KVListResponse {
repeated bytes keys = 1;
}
message InitRequest {
// Serialized job_config passed to ray.init
bytes job_config = 1;
}
message InitResponse {
}
message PrepRuntimeEnvRequest {
}
message PrepRuntimeEnvResponse {
}
service RayletDriver {
rpc Init(InitRequest) returns (InitResponse) {
}
rpc PrepRuntimeEnv(PrepRuntimeEnvRequest) returns (PrepRuntimeEnvResponse) {
}
rpc GetObject(GetRequest) returns (GetResponse) {
}
rpc PutObject(PutRequest) returns (PutResponse) {
@@ -250,6 +276,8 @@ service RayletDriver {
}
rpc KVList(KVListRequest) returns (KVListResponse) {
}
rpc KVExists(KVExistsRequest) returns (KVExistsResponse) {
}
}
message ReleaseRequest {
@@ -288,6 +316,8 @@ message DataRequest {
PutRequest put = 3;
ReleaseRequest release = 4;
ConnectionInfoRequest connection_info = 5;
InitRequest init = 6;
PrepRuntimeEnvRequest prep_runtime_env = 7;
}
}
@@ -299,6 +329,8 @@ message DataResponse {
PutResponse put = 3;
ReleaseResponse release = 4;
ConnectionInfoResponse connection_info = 5;
InitResponse init = 6;
PrepRuntimeEnvResponse prep_runtime_env = 7;
}
}
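
For reference, the new oneof fields can be exercised directly from the generated Python bindings. An illustrative round trip, using the same import path as the data client above:

```python
import pickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.job_config import JobConfig

init_req = ray_client_pb2.InitRequest(
    job_config=pickle.dumps(JobConfig(runtime_env={})))
req = ray_client_pb2.DataRequest(init=init_req)
assert req.WhichOneof("type") == "init"

prep = ray_client_pb2.DataRequest(
    prep_runtime_env=ray_client_pb2.PrepRuntimeEnvRequest())
assert prep.WhichOneof("type") == "prep_runtime_env"
```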