From 448013222978734615aaebc8e312f946fdce94fe Mon Sep 17 00:00:00 2001 From: Yi Cheng <74173148+iycheng@users.noreply.github.com> Date: Wed, 31 Mar 2021 11:39:34 -0700 Subject: [PATCH] [core] Integration runtime_env with ray client (#14881) * server side ready * client size * py * fix * up * format * add files * add pyx * up * up * up * add keys * format * update * format * add unittests * add files * up * up * fix * up * fix thread issue * format * fix * update proto * Fix * format * fix * more * fix conflict * fix * fix order * format * add * up * compiling fix * lint * fix * format * fix some * some fix * fix comment * test cases * add test * comments * fix name * format * fix * revert gcs-kv * fix comments * fix failure * fix test * format * fix timeout * fix * fix * fix * format * format * fix flaky test Co-authored-by: Yi Cheng --- python/ray/_private/client_mode_hook.py | 1 + python/ray/_private/runtime_env.py | 61 +++++----- .../experimental/packaging/load_package.py | 4 +- python/ray/job_config.py | 14 ++- python/ray/test_utils.py | 13 ++- python/ray/tests/test_client_init.py | 4 +- python/ray/tests/test_runtime_env.py | 81 +++++++------ python/ray/util/client/__init__.py | 5 +- python/ray/util/client/api.py | 4 + python/ray/util/client/dataclient.py | 13 +++ python/ray/util/client/ray_client_helpers.py | 2 +- python/ray/util/client/server/dataservicer.py | 107 ++++++++++-------- python/ray/util/client/server/server.py | 95 +++++++++++++--- python/ray/util/client/worker.py | 31 +++++ python/ray/util/client_connect.py | 4 +- python/ray/worker.py | 10 +- src/ray/protobuf/BUILD | 5 +- src/ray/protobuf/ray_client.proto | 32 ++++++ 18 files changed, 340 insertions(+), 146 deletions(-) diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py index 74682f1cf..48e9ac394 100644 --- a/python/ray/_private/client_mode_hook.py +++ b/python/ray/_private/client_mode_hook.py @@ -6,6 +6,7 @@ from functools import wraps 
RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__" client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" +os.environ.update({"RAY_CLIENT_MODE": "0"}) _client_hook_enabled = True diff --git a/python/ray/_private/runtime_env.py b/python/ray/_private/runtime_env.py index ed1aca8fc..a63ec9e95 100644 --- a/python/ray/_private/runtime_env.py +++ b/python/ray/_private/runtime_env.py @@ -1,6 +1,5 @@ import hashlib import logging -import inspect from filelock import FileLock from pathlib import Path @@ -9,10 +8,10 @@ from ray.job_config import JobConfig from enum import Enum from ray.experimental.internal_kv import (_internal_kv_put, _internal_kv_get, - _internal_kv_exists) + _internal_kv_exists, + _internal_kv_initialized) from typing import List, Tuple -from types import ModuleType from urllib.parse import urlparse import os import sys @@ -160,7 +159,7 @@ def _parse_uri(pkg_uri: str) -> Tuple[Protocol, str]: # TODO(yic): Fix this later to handle big directories in better way -def get_project_package_name(working_dir: str, modules: List[str]) -> str: +def get_project_package_name(working_dir: str, py_modules: List[str]) -> str: """Get the name of the package by working dir and modules. This function will generate the name of the package by the working @@ -180,7 +179,7 @@ def get_project_package_name(working_dir: str, modules: List[str]) -> str: e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip Args: working_dir (str): The working directory. - modules (list[module]): The python module. + py_modules (list[str]): The python module. Returns: Package name as a string. 
@@ -191,14 +190,12 @@ def get_project_package_name(working_dir: str, modules: List[str]) -> str: assert isinstance(working_dir, str) assert Path(working_dir).exists() hash_val = _xor_bytes(hash_val, _hash_modules(Path(working_dir))) - for module in modules or []: - assert inspect.ismodule(module) - hash_val = _xor_bytes(hash_val, - _hash_modules(Path(module.__file__).parent)) + for py_module in py_modules or []: + hash_val = _xor_bytes(hash_val, _hash_modules(Path(py_module).parent)) return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None -def create_project_package(working_dir: str, modules: List[ModuleType], +def create_project_package(working_dir: str, py_modules: List[str], output_path: str) -> None: """Create a pckage that will be used by workers. @@ -207,7 +204,8 @@ def create_project_package(working_dir: str, modules: List[ModuleType], Args: working_dir (str): The working directory. - modules (list[module]): The python modules to be included. + py_modules (list[str]): The list of path of python modules to be + included. output_path (str): The path of file to be created. """ pkg_file = Path(output_path) @@ -216,20 +214,15 @@ def create_project_package(working_dir: str, modules: List[ModuleType], # put all files in /path/working_dir into zip working_path = Path(working_dir) _zip_module(working_path, working_path, zip_handler) - for module in modules or []: - logger.info(module.__file__) - # we only take care of modules with path like this for now: - # /path/module_name/__init__.py - # module_path should be: /path/module_name - module_path = Path(module.__file__).parent - _zip_module(module_path, module_path.parent, zip_handler) + for py_module in py_modules or []: + _zip_module(Path(py_module), Path(py_module).parent, zip_handler) -def fetch_package(pkg_uri: str, pkg_file: Path) -> int: +def fetch_package(pkg_uri: str, pkg_file: Path = None) -> int: """Fetch a package from a given uri. 
This function is used to fetch a pacakge from the given uri to local - filesystem. + filesystem. If it exists, it'll just return 0. Args: pkg_uri (str): The uri of the package to download. @@ -239,9 +232,15 @@ def fetch_package(pkg_uri: str, pkg_file: Path) -> int: Returns: The number of bytes downloaded. """ + if pkg_file is None: + pkg_file = Path(_get_local_path(pkg_uri)) + if pkg_file.exists(): + return 0 (protocol, pkg_name) = _parse_uri(pkg_uri) if protocol in (Protocol.GCS, Protocol.PIN_GCS): code = _internal_kv_get(pkg_uri) + if code is None: + raise IOError("Fetch uri failed") code = code or b"" pkg_file.write_bytes(code) return len(code) @@ -284,6 +283,7 @@ def package_exists(pkg_uri: str) -> bool: Return: True for package existing and False for not. """ + assert _internal_kv_initialized() (protocol, pkg_name) = _parse_uri(pkg_uri) if protocol in (Protocol.GCS, Protocol.PIN_GCS): return _internal_kv_exists(pkg_uri) @@ -303,11 +303,11 @@ def rewrite_working_dir_uri(job_config: JobConfig) -> None: """ # For now, we only support local directory and packages working_dir = job_config.runtime_env.get("working_dir") - required_modules = job_config.runtime_env.get("local_modules") + py_modules = job_config.runtime_env.get("py_modules") - if (not job_config.runtime_env.get("working_dir_uri")) and ( - working_dir or required_modules): - pkg_name = get_project_package_name(working_dir, required_modules) + if (not job_config.runtime_env.get("working_dir_uri")) and (working_dir + or py_modules): + pkg_name = get_project_package_name(working_dir, py_modules) job_config.runtime_env[ "working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name @@ -323,18 +323,18 @@ def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None: Args: job_config (JobConfig): The job config of driver. 
""" + assert _internal_kv_initialized() pkg_uris = job_config.get_runtime_env_uris() for pkg_uri in pkg_uris: if not package_exists(pkg_uri): file_path = _get_local_path(pkg_uri) pkg_file = Path(file_path) working_dir = job_config.runtime_env.get("working_dir") - required_modules = job_config.runtime_env.get("local_modules") + py_modules = job_config.runtime_env.get("py_modules") logger.info(f"{pkg_uri} doesn't exist. Create new package with" - f" {working_dir} and {required_modules}") + f" {working_dir} and {py_modules}") if not pkg_file.exists(): - create_project_package(working_dir, required_modules, - file_path) + create_project_package(working_dir, py_modules, file_path) # Push the data to remote storage pkg_size = push_package(pkg_uri, pkg_file) logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes") @@ -349,12 +349,13 @@ def ensure_runtime_env_setup(pkg_uris: List[str]) -> None: Args: pkg_uri list(str): Package of the working dir for the runtime env. """ + + assert _internal_kv_initialized() for pkg_uri in pkg_uris: pkg_file = Path(_get_local_path(pkg_uri)) # For each node, the package will only be downloaded one time # Locking to avoid multiple process download concurrently - lock = FileLock(str(pkg_file) + ".lock") - with lock: + with FileLock(str(pkg_file) + ".lock"): # TODO(yic): checksum calculation is required if pkg_file.exists(): logger.debug( diff --git a/python/ray/experimental/packaging/load_package.py b/python/ray/experimental/packaging/load_package.py index fcdb8f371..0f21f62a1 100644 --- a/python/ray/experimental/packaging/load_package.py +++ b/python/ray/experimental/packaging/load_package.py @@ -69,14 +69,14 @@ def load_package(config_path: str) -> "_RuntimePackage": # Autofill working directory by uploading to GCS storage. 
if "working_dir" not in runtime_env: pkg_name = runtime_support.get_project_package_name( - working_dir=base_dir, modules=[]) + working_dir=base_dir, py_modules=[]) pkg_uri = runtime_support.Protocol.GCS.value + "://" + pkg_name def do_register_package(): if not runtime_support.package_exists(pkg_uri): tmp_path = os.path.join(_pkg_tmp(), "_tmp{}".format(pkg_name)) runtime_support.create_project_package( - working_dir=base_dir, modules=[], output_path=tmp_path) + working_dir=base_dir, py_modules=[], output_path=tmp_path) # TODO(ekl) does this get garbage collected correctly with the # current job id? runtime_support.push_package(pkg_uri, tmp_path) diff --git a/python/ray/job_config.py b/python/ray/job_config.py index 54dd91414..6bb8cef38 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -15,6 +15,7 @@ class JobConfig: `CLASSPATH` in Java and `PYTHONPATH` in Python. runtime_env (dict): A runtime environment dictionary (see ``runtime_env.py`` for detailed documentation). + client_job (bool): A boolean represent the source of the job. 
""" def __init__(self, @@ -22,7 +23,8 @@ class JobConfig: num_java_workers_per_process=1, jvm_options=None, code_search_path=None, - runtime_env=None): + runtime_env=None, + client_job=False): if worker_env is None: self.worker_env = dict() else: @@ -45,8 +47,15 @@ class JobConfig: f"The type of code search path is incorrect: " \ f"{type(code_search_path)}" self.runtime_env = runtime_env or dict() + self.client_job = client_job def serialize(self): + """Serialize the struct into protobuf string""" + job_config = self.get_proto_job_config() + return job_config.SerializeToString() + + def get_proto_job_config(self): + """Return the prototype structure of JobConfig""" job_config = ray.gcs_utils.JobConfig() for key in self.worker_env: job_config.worker_env[key] = self.worker_env[key] @@ -55,9 +64,10 @@ class JobConfig: job_config.jvm_options.extend(self.jvm_options) job_config.code_search_path.extend(self.code_search_path) job_config.runtime_env.CopyFrom(self._get_proto_runtime()) - return job_config.SerializeToString() + return job_config def get_runtime_env_uris(self): + """Get the uris of runtime environment""" if self.runtime_env.get("working_dir_uri"): return [self.runtime_env.get("working_dir_uri")] return [] diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index bd6d9b7a3..c63bbff4f 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -8,7 +8,7 @@ import sys import time import socket import math - +from typing import Dict from contextlib import redirect_stdout, redirect_stderr import yaml @@ -178,11 +178,12 @@ def kill_process_by_name(name, SIGKILL=False): p.terminate() -def run_string_as_driver(driver_script): +def run_string_as_driver(driver_script: str, env: Dict = None): """Run a driver as a separate process. Args: - driver_script: A string to run as a Python script. + driver_script (str): A string to run as a Python script. + env (dict): The environment variables for the driver. Returns: The script's output. 
@@ -191,7 +192,9 @@ def run_string_as_driver(driver_script): [sys.executable, "-"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) + stderr=subprocess.STDOUT, + env=env, + ) with proc: output = proc.communicate(driver_script.encode("ascii"))[0] if proc.returncode: @@ -489,7 +492,7 @@ def new_scheduler_enabled(): def client_test_enabled() -> bool: - return os.environ.get("RAY_CLIENT_MODE") == "1" + return os.environ.get("RAY_CLIENT_MODE") is not None def fetch_prometheus(prom_addresses): diff --git a/python/ray/tests/test_client_init.py b/python/ray/tests/test_client_init.py index 7ef2e744f..b76733b6b 100644 --- a/python/ray/tests/test_client_init.py +++ b/python/ray/tests/test_client_init.py @@ -54,8 +54,8 @@ def init_and_serve_lazy(): cluster.add_node(num_cpus=1, num_gpus=0) address = cluster.address - def connect(): - ray.init(address=address) + def connect(job_config=None): + ray.init(address=address, job_config=job_config) server_handle = ray_client_server.serve("localhost:50051", connect) yield server_handle diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py index b5575fca7..e9682c08c 100644 --- a/python/ray/tests/test_runtime_env.py +++ b/python/ray/tests/test_runtime_env.py @@ -5,9 +5,9 @@ import unittest import subprocess +import tempfile from unittest import mock from pathlib import Path - import ray from ray.test_utils import run_string_as_driver from ray._private.utils import get_conda_env_dir, get_conda_bin_executable @@ -19,14 +19,19 @@ import logging sys.path.insert(0, "{working_dir}") import test_module import ray +import ray.util +import os job_config = ray.job_config.JobConfig( runtime_env={runtime_env} ) -ray.init(address="{redis_address}", - job_config=job_config, - logging_level=logging.DEBUG) +if os.environ.get("USE_RAY_CLIENT"): + ray.util.connect("{address}", job_config=job_config) +else: + ray.init(address="{address}", + job_config=job_config, + logging_level=logging.DEBUG) 
@ray.remote def run_test(): @@ -40,15 +45,17 @@ class TestActor(object): {execute_statement} -ray.shutdown() +if os.environ.get("USE_RAY_CLIENT"): + ray.util.disconnect() +else: + ray.shutdown() from time import sleep -sleep(5) +sleep(10) """ @pytest.fixture(scope="session") def working_dir(): - import tempfile with tempfile.TemporaryDirectory() as tmp_dir: path = Path(tmp_dir) module_path = path / "test_module" @@ -68,50 +75,59 @@ from test_module.test import one yield tmp_dir +def start_client_server(cluster, client_mode): + from ray._private.runtime_env import PKG_DIR + if not client_mode: + return (cluster.address, None, PKG_DIR) + ray.worker._global_node._ray_params.ray_client_server_port = "10003" + ray.worker._global_node.start_ray_client_server() + return ("localhost:10003", {"USE_RAY_CLIENT": "1"}, PKG_DIR) + + @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") -def test_single_node(ray_start_cluster_head, working_dir): +@pytest.mark.parametrize("client_mode", [True, False]) +def test_single_node(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - redis_address = cluster.address + (address, env, PKG_DIR) = start_client_server(cluster, client_mode) runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))" script = driver_script.format(**locals()) - out = run_string_as_driver(script) + out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - from ray._private.runtime_env import PKG_DIR assert len(list(Path(PKG_DIR).iterdir())) == 1 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") -def test_two_node(two_node_cluster, working_dir): +@pytest.mark.parametrize("client_mode", [True, False]) +def test_two_node(two_node_cluster, working_dir, client_mode): cluster, _ = two_node_cluster - redis_address = cluster.address + (address, env, PKG_DIR) = start_client_server(cluster, client_mode) runtime_env 
= f"""{{ "working_dir": "{working_dir}" }}""" execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))" script = driver_script.format(**locals()) - out = run_string_as_driver(script) + out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - from ray._private.runtime_env import PKG_DIR assert len(list(Path(PKG_DIR).iterdir())) == 1 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") -def test_two_node_module(two_node_cluster, working_dir): +@pytest.mark.parametrize("client_mode", [True, False]) +def test_two_node_module(two_node_cluster, working_dir, client_mode): cluster, _ = two_node_cluster - redis_address = cluster.address - runtime_env = """{ "local_modules": [test_module] }""" + (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + runtime_env = """{ "py_modules": [test_module.__path__[0]] }""" execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))" script = driver_script.format(**locals()) - print(script) - out = run_string_as_driver(script) + out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - from ray._private.runtime_env import PKG_DIR assert len(list(Path(PKG_DIR).iterdir())) == 1 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") -def test_two_node_uri(two_node_cluster, working_dir): +@pytest.mark.parametrize("client_mode", [True, False]) +def test_two_node_uri(two_node_cluster, working_dir, client_mode): cluster, _ = two_node_cluster - redis_address = cluster.address + (address, env, PKG_DIR) = start_client_server(cluster, client_mode) import ray._private.runtime_env as runtime_env import tempfile with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file: @@ -122,41 +138,40 @@ def test_two_node_uri(two_node_cluster, working_dir): runtime_env = f"""{{ "working_dir_uri": "{pkg_uri}" }}""" execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))" script = driver_script.format(**locals()) - out = 
run_string_as_driver(script) + out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - from ray._private.runtime_env import PKG_DIR assert len(list(Path(PKG_DIR).iterdir())) == 1 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") -def test_regular_actors(ray_start_cluster_head, working_dir): +@pytest.mark.parametrize("client_mode", [True, False]) +def test_regular_actors(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - redis_address = cluster.address + (address, env, PKG_DIR) = start_client_server(cluster, client_mode) runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" execute_statement = """ test_actor = TestActor.options(name="test_actor").remote() print(sum(ray.get([test_actor.one.remote()] * 1000))) """ script = driver_script.format(**locals()) - out = run_string_as_driver(script) + out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - from ray._private.runtime_env import PKG_DIR assert len(list(Path(PKG_DIR).iterdir())) == 1 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") -def test_detached_actors(ray_start_cluster_head, working_dir): +@pytest.mark.parametrize("client_mode", [True, False]) +def test_detached_actors(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - redis_address = cluster.address + (address, env, PKG_DIR) = start_client_server(cluster, client_mode) runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" execute_statement = """ test_actor = TestActor.options(name="test_actor", lifetime="detached").remote() print(sum(ray.get([test_actor.one.remote()] * 1000))) """ script = driver_script.format(**locals()) - out = run_string_as_driver(script) + out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - from ray._private.runtime_env import PKG_DIR # It's a detached actors, so it should still be there assert len(list(Path(PKG_DIR).iterdir())) == 2 
pkg = list(Path(PKG_DIR).glob("*.zip"))[0] diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py index 988b96fab..8bd3602de 100644 --- a/python/ray/util/client/__init__.py +++ b/python/ray/util/client/__init__.py @@ -1,5 +1,5 @@ from typing import List, Tuple, Dict, Any - +from ray.job_config import JobConfig import os import sys import logging @@ -29,6 +29,7 @@ class RayAPIStub: def connect(self, conn_str: str, + job_config: JobConfig = None, secure: bool = False, metadata: List[Tuple[str, str]] = None, connection_retries: int = 3, @@ -38,6 +39,7 @@ class RayAPIStub: Args: conn_str: Connection string, in the form "[host]:port" + job_config: The job config of the server. secure: Whether to use a TLS secured gRPC channel metadata: gRPC metadata to send on connect connection_retries: number of connection attempts to make @@ -67,6 +69,7 @@ class RayAPIStub: metadata=metadata, connection_retries=connection_retries) self.api.worker = self.client_worker + self.client_worker._server_init(job_config) conn_info = self.client_worker.connection_info() self._check_versions(conn_info, ignore_version) return conn_info diff --git a/python/ray/util/client/api.py b/python/ray/util/client/api.py index 5b1ae881e..3c260de81 100644 --- a/python/ray/util/client/api.py +++ b/python/ray/util/client/api.py @@ -247,6 +247,10 @@ class ClientAPI: """Hook for internal_kv._internal_kv_initialized.""" return self.is_initialized() + def _internal_kv_exists(self, key: bytes) -> bool: + """Hook for internal_kv._internal_kv_exists.""" + return self.worker.internal_kv_exists(as_bytes(key)) + def _internal_kv_get(self, key: bytes) -> bytes: """Hook for internal_kv._internal_kv_get.""" return self.worker.internal_kv_get(as_bytes(key)) diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index 9de49c823..9fcda555f 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -117,6 +117,19 @@ class 
DataClient: req.req_id = req_id self.request_queue.put(req) + def Init(self, request: ray_client_pb2.InitRequest, + context=None) -> ray_client_pb2.InitResponse: + datareq = ray_client_pb2.DataRequest(init=request, ) + resp = self._blocking_send(datareq) + return resp.init + + def PrepRuntimeEnv(self, + request: ray_client_pb2.PrepRuntimeEnvRequest, + context=None) -> ray_client_pb2.PrepRuntimeEnvResponse: + datareq = ray_client_pb2.DataRequest(prep_runtime_env=request, ) + resp = self._blocking_send(datareq) + return resp.prep_runtime_env + def ConnectionInfo(self, context=None) -> ray_client_pb2.ConnectionInfoResponse: datareq = ray_client_pb2.DataRequest( diff --git a/python/ray/util/client/ray_client_helpers.py b/python/ray/util/client/ray_client_helpers.py index a7f16c246..424de51b8 100644 --- a/python/ray/util/client/ray_client_helpers.py +++ b/python/ray/util/client/ray_client_helpers.py @@ -29,7 +29,7 @@ def ray_start_client_server_pair(): def ray_start_cluster_client_server_pair(address): ray._inside_client_test = True - def ray_connect_handler(): + def ray_connect_handler(job_config=None): real_ray.init(address=address) server = ray_client_server.serve( diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index ec24e0608..e4b5f90dc 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -3,7 +3,7 @@ import logging import grpc import sys -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from threading import Lock import ray.core.generated.ray_client_pb2 as ray_client_pb2 @@ -20,12 +20,10 @@ logger = logging.getLogger(__name__) class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): - def __init__(self, basic_service: "RayletServicer", - ray_connect_handler: Callable): + def __init__(self, basic_service: "RayletServicer"): self.basic_service = basic_service self.clients_lock = Lock() self.num_clients = 0 # 
guarded by self.clients_lock - self.ray_connect_handler = ray_connect_handler def Datapath(self, request_iterator, context): metadata = {k: v for k, v in context.invocation_metadata()} @@ -36,55 +34,47 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): return logger.debug(f"New data connection from client {client_id}: ") try: - with self.clients_lock: - with disable_client_hook(): - # It's important to keep the ray initialization call - # within this locked context or else Ray could hang. - if self.num_clients == 0 and not ray.is_initialized(): - self.ray_connect_handler() - threshold = int(CLIENT_SERVER_MAX_THREADS / 2) - if self.num_clients >= threshold: - context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) - logger.warning( - f"[Data Servicer]: Num clients {self.num_clients} " - f"has reached the threshold {threshold}. " - f"Rejecting client: {metadata['client_id']}. ") - if log_once("client_threshold"): - logger.warning( - "You can configure the client connection " - "threshold by setting the " - "RAY_CLIENT_SERVER_MAX_THREADS env var " - f"(currently set to {CLIENT_SERVER_MAX_THREADS}).") - return - - self.num_clients += 1 - logger.debug(f"Accepted data connection from {client_id}. 
" - f"Total clients: {self.num_clients}") - accepted_connection = True for req in request_iterator: resp = None req_type = req.WhichOneof("type") - if req_type == "get": - get_resp = self.basic_service._get_object( - req.get, client_id) - resp = ray_client_pb2.DataResponse(get=get_resp) - elif req_type == "put": - put_resp = self.basic_service._put_object( - req.put, client_id) - resp = ray_client_pb2.DataResponse(put=put_resp) - elif req_type == "release": - released = [] - for rel_id in req.release.ids: - rel = self.basic_service.release(client_id, rel_id) - released.append(rel) - resp = ray_client_pb2.DataResponse( - release=ray_client_pb2.ReleaseResponse(ok=released)) - elif req_type == "connection_info": - resp = ray_client_pb2.DataResponse( - connection_info=self._build_connection_response()) + if req_type == "init": + resp = self._init(req.init, client_id) + if resp is None: + context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) + return + logger.debug(f"Accepted data connection from {client_id}. 
" + f"Total clients: {self.num_clients}") + accepted_connection = True else: - raise Exception(f"Unreachable code: Request type " - f"{req_type} not handled in Datapath") + assert accepted_connection + if req_type == "get": + get_resp = self.basic_service._get_object( + req.get, client_id) + resp = ray_client_pb2.DataResponse(get=get_resp) + elif req_type == "put": + put_resp = self.basic_service._put_object( + req.put, client_id) + resp = ray_client_pb2.DataResponse(put=put_resp) + elif req_type == "release": + released = [] + for rel_id in req.release.ids: + rel = self.basic_service.release(client_id, rel_id) + released.append(rel) + resp = ray_client_pb2.DataResponse( + release=ray_client_pb2.ReleaseResponse( + ok=released)) + elif req_type == "connection_info": + resp = ray_client_pb2.DataResponse( + connection_info=self._build_connection_response()) + elif req_type == "prep_runtime_env": + with self.clients_lock: + resp_prep = self.basic_service.PrepRuntimeEnv( + req.prep_runtime_env) + resp = ray_client_pb2.DataResponse( + prep_runtime_env=resp_prep) + else: + raise Exception(f"Unreachable code: Request type " + f"{req_type} not handled in Datapath") resp.req_id = req.req_id yield resp except grpc.RpcError as e: @@ -106,6 +96,25 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): logger.debug("Shutting down ray.") ray.shutdown() + def _init(self, req_init, client_id): + with self.clients_lock: + threshold = int(CLIENT_SERVER_MAX_THREADS / 2) + if self.num_clients >= threshold: + logger.warning( + f"[Data Servicer]: Num clients {self.num_clients} " + f"has reached the threshold {threshold}. " + f"Rejecting client: {client_id}. 
") + if log_once("client_threshold"): + logger.warning( + "You can configure the client connection " + "threshold by setting the " + "RAY_CLIENT_SERVER_MAX_THREADS env var " + f"(currently set to {CLIENT_SERVER_MAX_THREADS}).") + return None + resp_init = self.basic_service.Init(req_init) + self.num_clients += 1 + return ray_client_pb2.DataResponse(init=resp_init, ) + def _build_connection_response(self): with self.clients_lock: cur_num_clients = self.num_clients diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index c9f9f26d5..09ef15e6b 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -4,14 +4,16 @@ import grpc import base64 from collections import defaultdict from dataclasses import dataclass - +import sys import threading from typing import Any +from typing import List from typing import Dict from typing import Set from typing import Optional - +from typing import Callable from ray import cloudpickle +from ray.job_config import JobConfig import ray import ray.state import ray.core.generated.ray_client_pb2 as ray_client_pb2 @@ -33,7 +35,12 @@ logger = logging.getLogger(__name__) class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): - def __init__(self): + def __init__(self, ray_connect_handler: Callable): + """Construct a raylet service + + Args: + ray_connect_handler (Callable): Function to connect to ray cluster + """ # Stores client_id -> (ref_id -> ObjectRef) self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict( dict) @@ -46,6 +53,42 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): self.registered_actor_classes = {} self.named_actors = set() self.state_lock = threading.Lock() + self.ray_connect_handler = ray_connect_handler + + def Init(self, request, context=None) -> ray_client_pb2.InitResponse: + import pickle + if request.job_config: + job_config = pickle.loads(request.job_config) + job_config.client_job = True + 
else: + job_config = None + current_job_config = None + with disable_client_hook(): + if ray.is_initialized(): + worker = ray.worker.global_worker + current_job_config = worker.core_worker.get_job_config() + else: + self.ray_connect_handler(job_config) + if job_config is None: + return ray_client_pb2.InitResponse() + job_config = job_config.get_proto_job_config() + # If the server has been initialized, we need to compare whether the + # runtime env is compatible. + if current_job_config and set(job_config.runtime_env.uris) != set( + current_job_config.runtime_env.uris): + raise grpc.RpcError( + "Runtime environment doesn't match " + f"request one {job_config.runtime_env.uris} " + f"current one {current_job_config.runtime_env.uris}") + return ray_client_pb2.InitResponse() + + def PrepRuntimeEnv(self, request, + context=None) -> ray_client_pb2.PrepRuntimeEnvResponse: + job_config = ray.worker.global_worker.core_worker.get_job_config() + missing_uris = self._prepare_runtime_env(job_config.runtime_env) + if len(missing_uris) != 0: + raise grpc.RpcError(f"Missing uris: {missing_uris}") + return ray_client_pb2.PrepRuntimeEnvResponse() def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse: with disable_client_hook(): @@ -69,6 +112,13 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): request.prefix) return ray_client_pb2.KVListResponse(keys=keys) + def KVExists(self, request, + context=None) -> ray_client_pb2.KVExistsResponse: + with disable_client_hook(): + exists = ray.experimental.internal_kv._internal_kv_exists( + request.key) + return ray_client_pb2.KVExistsResponse(exists=exists) + def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: resp = ray_client_pb2.ClusterInfoResponse() @@ -371,6 +421,21 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): kwargout[k] = convert_from_arg(kwarg_map[k], self) return argout, kwargout + def _prepare_runtime_env(self, job_runtime_env) -> List[str]: + 
"""Download runtime environment to local node""" + missing_uris = [] + uris = job_runtime_env.uris + from ray._private import runtime_env + with disable_client_hook(): + for uri in uris: + try: + runtime_env.fetch_package(uri) + print("Adding!: ", runtime_env._get_local_path(uri)) + sys.path.insert(0, str(runtime_env._get_local_path(uri))) + except IOError: + missing_uris.append(uri) + return missing_uris + def lookup_or_register_func( self, id: bytes, client_id: str, options: Optional[Dict]) -> ray.remote_function.RemoteFunction: @@ -453,10 +518,10 @@ class ClientServerHandle: def serve(connection_str, ray_connect_handler=None): - def default_connect_handler(): + def default_connect_handler(job_config: JobConfig = None): with disable_client_hook(): if not ray.is_initialized(): - return ray.init() + return ray.init(job_config=job_config) ray_connect_handler = ray_connect_handler or default_connect_handler server = grpc.server( @@ -465,9 +530,8 @@ def serve(connection_str, ray_connect_handler=None): ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE), ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE), ]) - task_servicer = RayletServicer() - data_servicer = DataServicer( - task_servicer, ray_connect_handler=ray_connect_handler) + task_servicer = RayletServicer(ray_connect_handler) + data_servicer = DataServicer(task_servicer) logs_servicer = LogstreamServicer() ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) @@ -491,13 +555,13 @@ def init_and_serve(connection_str, *args, **kwargs): # Disable client mode inside the worker's environment info = ray.init(*args, **kwargs) - def ray_connect_handler(): + def ray_connect_handler(job_config=None): # Ray client will disconnect from ray when # num_clients == 0. 
if ray.is_initialized(): return info else: - return ray.init(*args, **kwargs) + return ray.init(job_config=job_config, *args, **kwargs) server_handle = serve( connection_str, ray_connect_handler=ray_connect_handler) @@ -511,14 +575,17 @@ def shutdown_with_server(server, _exiting_interpreter=False): def create_ray_handler(redis_address, redis_password): - def ray_connect_handler(): + def ray_connect_handler(job_config: JobConfig = None): if redis_address: if redis_password: - ray.init(address=redis_address, _redis_password=redis_password) + ray.init( + address=redis_address, + _redis_password=redis_password, + job_config=job_config) else: - ray.init(address=redis_address) + ray.init(address=redis_address, job_config=job_config) else: - ray.init() + ray.init(job_config=job_config) return ray_connect_handler diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 99bf059db..515f4c80d 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING import grpc +from ray.job_config import JobConfig import ray.cloudpickle as cloudpickle # Use cloudpickle's version of pickle for UnpicklingError from ray.cloudpickle.compat import pickle @@ -145,6 +146,7 @@ class Worker: self.log_client = LogstreamClient(self.channel, self.metadata) self.log_client.set_logstream_level(logging.INFO) + self.closed = False def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity): @@ -394,6 +396,11 @@ class Worker: resp = self.server.KVGet(req, metadata=self.metadata) return resp.value + def internal_kv_exists(self, key: bytes) -> bool: + req = ray_client_pb2.KVExistsRequest(key=key) + resp = self.server.KVExists(req, metadata=self.metadata) + return resp.exists + def internal_kv_put(self, key: bytes, value: bytes, overwrite: bool) -> bool: req = ray_client_pb2.KVPutRequest( @@ -431,6 +438,30 @@ class Worker: def is_connected(self) -> bool: return self._conn_state == 
grpc.ChannelConnectivity.READY + def _server_init(self, job_config: JobConfig): + """Initialize the server""" + try: + if job_config is None: + init_req = ray_client_pb2.InitRequest() + self.data_client.Init(init_req) + return + + import ray._private.runtime_env as runtime_env + import tempfile + with tempfile.TemporaryDirectory() as tmp_dir: + if runtime_env.PKG_DIR is None: + runtime_env.PKG_DIR = tmp_dir + # Generate the uri for runtime env + runtime_env.rewrite_working_dir_uri(job_config) + init_req = ray_client_pb2.InitRequest( + job_config=pickle.dumps(job_config)) + self.data_client.Init(init_req) + runtime_env.upload_runtime_env_package_if_needed(job_config) + prep_req = ray_client_pb2.PrepRuntimeEnvRequest() + self.data_client.PrepRuntimeEnv(prep_req) + except grpc.RpcError as e: + raise decode_exception(e.details()) + def _convert_actor(self, actor: "ActorClass") -> str: """Register a ClientActorClass for the ActorClass and return a UUID""" key = uuid.uuid4().hex diff --git a/python/ray/util/client_connect.py b/python/ray/util/client_connect.py index 3a878c684..542f539bb 100644 --- a/python/ray/util/client_connect.py +++ b/python/ray/util/client_connect.py @@ -1,5 +1,5 @@ from ray.util.client import ray - +from ray.job_config import JobConfig from ray._private.client_mode_hook import _enable_client_hook from ray._private.client_mode_hook import _explicitly_enable_client_mode @@ -10,6 +10,7 @@ def connect(conn_str: str, secure: bool = False, metadata: List[Tuple[str, str]] = None, connection_retries: int = 3, + job_config: JobConfig = None, *, ignore_version: bool = False) -> Dict[str, Any]: if ray.is_connected(): @@ -26,6 +27,7 @@ def connect(conn_str: str, # the correct metadata return ray.connect( conn_str, + job_config=job_config, secure=secure, metadata=metadata, connection_retries=3, diff --git a/python/ray/worker.py b/python/ray/worker.py index 220b22a5d..6a6de2b60 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1219,15 +1219,16 @@ 
def connect(node, node.node_manager_port, node.raylet_ip_address, (mode == LOCAL_MODE), driver_name, log_stdout_file_path, log_stderr_file_path, serialized_job_config, node.metrics_agent_port) - # Notify raylet that the core worker is ready. - worker.core_worker.notify_raylet() # Create an object for interfacing with the global state. # Note, global state should be intialized after `CoreWorker`, because it # will use glog, which is intialized in `CoreWorker`. ray.state.state._initialize_global_state( node.redis_address, redis_password=node.redis_password) - if mode == SCRIPT_MODE: + # If it's a driver and it's not coming from ray client, we'll prepare the + # environment here. If it's ray client, the environment will be prepared + # at the server side. + if mode == SCRIPT_MODE and not job_config.client_job: runtime_env.upload_runtime_env_package_if_needed(job_config) elif mode == WORKER_MODE: # TODO(ekl) get rid of the env var hack and get runtime env from the @@ -1237,6 +1238,9 @@ def connect(node, worker.core_worker.get_job_config().runtime_env.uris runtime_env.ensure_runtime_env_setup(job_config) + # Notify raylet that the core worker is ready. 
+ worker.core_worker.notify_raylet() + if driver_object_store_memory is not None: worker.core_worker.set_object_store_client_options( f"ray_driver_{os.getpid()}", driver_object_store_memory) diff --git a/src/ray/protobuf/BUILD b/src/ray/protobuf/BUILD index 63d6a96c0..309607248 100644 --- a/src/ray/protobuf/BUILD +++ b/src/ray/protobuf/BUILD @@ -162,11 +162,10 @@ cc_proto_library( proto_library( name = "ray_client_proto", srcs = ["ray_client.proto"], - deps = [], + deps = [":common_proto"], ) python_grpc_compile( name = "ray_client_py_proto", - deps = [":ray_client_proto"] + deps = [":ray_client_proto"], ) - diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index 9df790f88..e6863b33f 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -196,6 +196,14 @@ message TerminateResponse { bool ok = 1; } +message KVExistsRequest { + bytes key = 1; +} + +message KVExistsResponse { + bool exists = 1; +} + message KVGetRequest { bytes key = 1; } @@ -229,7 +237,25 @@ message KVListResponse { repeated bytes keys = 1; } +message InitRequest { + // job_config of ray.init + bytes job_config = 1; +} + +message InitResponse { +} + +message PrepRuntimeEnvRequest { +} + +message PrepRuntimeEnvResponse { +} + service RayletDriver { + rpc Init(InitRequest) returns (InitResponse) { + } + rpc PrepRuntimeEnv(PrepRuntimeEnvRequest) returns (PrepRuntimeEnvResponse) { + } rpc GetObject(GetRequest) returns (GetResponse) { } rpc PutObject(PutRequest) returns (PutResponse) { @@ -250,6 +276,8 @@ service RayletDriver { } rpc KVList(KVListRequest) returns (KVListResponse) { } + rpc KVExists(KVExistsRequest) returns (KVExistsResponse) { + } } message ReleaseRequest { @@ -288,6 +316,8 @@ message DataRequest { PutRequest put = 3; ReleaseRequest release = 4; ConnectionInfoRequest connection_info = 5; + InitRequest init = 6; + PrepRuntimeEnvRequest prep_runtime_env = 7; } } @@ -299,6 +329,8 @@ message DataResponse { PutResponse put = 
3; ReleaseResponse release = 4; ConnectionInfoResponse connection_info = 5; + InitResponse init = 6; + PrepRuntimeEnvResponse prep_runtime_env = 7; } }