mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00
[core] Integration runtime_env with ray client (#14881)
* server side ready * client size * py * fix * up * format * add files * add pyx * up * up * up * add keys * format * update * format * add unittests * add files * up * up * fix * up * fix thread issue * format * fix * update proto * Fix * format * fix * more * fix conflict * fix * fix order * format * add * up * compiling fix * lint * fix * format * fix some * some fix * fix comment * test cases * add test * comments * fix name * format * fix * revert gcs-kv * fix comments * fix failure * fix test * format * fix timeout * fix * fix * fix * format * format * fix flaky test Co-authored-by: Yi Cheng <singye888@gmail.com>
This commit is contained in: parent 91cf272c2e, commit 4480132229
18 changed files with 340 additions and 146 deletions
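In short, this commit teaches the Ray client to carry a runtime environment through `ray.util.connect`. A minimal usage sketch, assuming a client server is already listening on localhost:10001 and that ./project is a hypothetical directory holding the driver's code:

    import ray.util
    from ray.job_config import JobConfig

    # "py_modules" now takes module *paths*; the old "local_modules" key
    # (which took imported module objects) is gone after this commit.
    job_config = JobConfig(runtime_env={"working_dir": "./project"})

    # The job config is pickled into an InitRequest, replayed on the
    # server, and the working_dir package is uploaded if it is missing.
    ray.util.connect("localhost:10001", job_config=job_config)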
@@ -6,6 +6,7 @@ from functools import wraps
 RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__"
 
 client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
+os.environ.update({"RAY_CLIENT_MODE": "0"})
 
 _client_hook_enabled = True
 
@@ -1,6 +1,5 @@
 import hashlib
 import logging
-import inspect
 
 from filelock import FileLock
 from pathlib import Path
@@ -9,10 +8,10 @@ from ray.job_config import JobConfig
 from enum import Enum
 
 from ray.experimental.internal_kv import (_internal_kv_put, _internal_kv_get,
-                                          _internal_kv_exists)
+                                          _internal_kv_exists,
+                                          _internal_kv_initialized)
 
 from typing import List, Tuple
-from types import ModuleType
 from urllib.parse import urlparse
 import os
 import sys
@@ -160,7 +159,7 @@ def _parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
 
 
 # TODO(yic): Fix this later to handle big directories in better way
-def get_project_package_name(working_dir: str, modules: List[str]) -> str:
+def get_project_package_name(working_dir: str, py_modules: List[str]) -> str:
     """Get the name of the package by working dir and modules.
 
     This function will generate the name of the package by the working
@@ -180,7 +179,7 @@ def get_project_package_name(working_dir: str, modules: List[str]) -> str:
     e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip
     Args:
         working_dir (str): The working directory.
-        modules (list[module]): The python module.
+        py_modules (list[str]): The python module paths.
 
     Returns:
         Package name as a string.
@@ -191,14 +190,12 @@ def get_project_package_name(working_dir: str, modules: List[str]) -> str:
         assert isinstance(working_dir, str)
         assert Path(working_dir).exists()
         hash_val = _xor_bytes(hash_val, _hash_modules(Path(working_dir)))
-    for module in modules or []:
-        assert inspect.ismodule(module)
-        hash_val = _xor_bytes(hash_val,
-                              _hash_modules(Path(module.__file__).parent))
+    for py_module in py_modules or []:
+        hash_val = _xor_bytes(hash_val, _hash_modules(Path(py_module).parent))
     return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None
 
 
-def create_project_package(working_dir: str, modules: List[ModuleType],
+def create_project_package(working_dir: str, py_modules: List[str],
                            output_path: str) -> None:
     """Create a package that will be used by workers.
 
@@ -207,7 +204,8 @@ def create_project_package(working_dir: str, modules: List[ModuleType],
 
     Args:
         working_dir (str): The working directory.
-        modules (list[module]): The python modules to be included.
+        py_modules (list[str]): The list of paths of python modules to be
+            included.
         output_path (str): The path of file to be created.
     """
     pkg_file = Path(output_path)
@@ -216,20 +214,15 @@ def create_project_package(working_dir: str, modules: List[ModuleType],
         # put all files in /path/working_dir into zip
         working_path = Path(working_dir)
         _zip_module(working_path, working_path, zip_handler)
-        for module in modules or []:
-            logger.info(module.__file__)
-            # we only take care of modules with path like this for now:
-            #    /path/module_name/__init__.py
-            # module_path should be: /path/module_name
-            module_path = Path(module.__file__).parent
-            _zip_module(module_path, module_path.parent, zip_handler)
+        for py_module in py_modules or []:
+            _zip_module(Path(py_module), Path(py_module).parent, zip_handler)
 
 
-def fetch_package(pkg_uri: str, pkg_file: Path) -> int:
+def fetch_package(pkg_uri: str, pkg_file: Path = None) -> int:
     """Fetch a package from a given uri.
 
     This function is used to fetch a package from the given uri to local
-    filesystem.
+    filesystem. If it exists, it'll just return 0.
 
     Args:
         pkg_uri (str): The uri of the package to download.
@@ -239,9 +232,15 @@ def fetch_package(pkg_uri: str, pkg_file: Path) -> int:
     Returns:
         The number of bytes downloaded.
     """
+    if pkg_file is None:
+        pkg_file = Path(_get_local_path(pkg_uri))
+    if pkg_file.exists():
+        return 0
     (protocol, pkg_name) = _parse_uri(pkg_uri)
     if protocol in (Protocol.GCS, Protocol.PIN_GCS):
         code = _internal_kv_get(pkg_uri)
+        if code is None:
+            raise IOError("Fetch uri failed")
         code = code or b""
         pkg_file.write_bytes(code)
         return len(code)
@@ -284,6 +283,7 @@ def package_exists(pkg_uri: str) -> bool:
     Return:
         True for package existing and False for not.
     """
+    assert _internal_kv_initialized()
     (protocol, pkg_name) = _parse_uri(pkg_uri)
     if protocol in (Protocol.GCS, Protocol.PIN_GCS):
         return _internal_kv_exists(pkg_uri)
@@ -303,11 +303,11 @@ def rewrite_working_dir_uri(job_config: JobConfig) -> None:
     """
     # For now, we only support local directory and packages
     working_dir = job_config.runtime_env.get("working_dir")
-    required_modules = job_config.runtime_env.get("local_modules")
+    py_modules = job_config.runtime_env.get("py_modules")
 
-    if (not job_config.runtime_env.get("working_dir_uri")) and (
-            working_dir or required_modules):
-        pkg_name = get_project_package_name(working_dir, required_modules)
+    if (not job_config.runtime_env.get("working_dir_uri")) and (working_dir
+                                                                or py_modules):
+        pkg_name = get_project_package_name(working_dir, py_modules)
         job_config.runtime_env[
             "working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
@@ -323,18 +323,18 @@ def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
     Args:
         job_config (JobConfig): The job config of driver.
     """
+    assert _internal_kv_initialized()
     pkg_uris = job_config.get_runtime_env_uris()
     for pkg_uri in pkg_uris:
         if not package_exists(pkg_uri):
             file_path = _get_local_path(pkg_uri)
             pkg_file = Path(file_path)
             working_dir = job_config.runtime_env.get("working_dir")
-            required_modules = job_config.runtime_env.get("local_modules")
+            py_modules = job_config.runtime_env.get("py_modules")
             logger.info(f"{pkg_uri} doesn't exist. Create new package with"
-                        f" {working_dir} and {required_modules}")
+                        f" {working_dir} and {py_modules}")
             if not pkg_file.exists():
-                create_project_package(working_dir, required_modules,
-                                       file_path)
+                create_project_package(working_dir, py_modules, file_path)
             # Push the data to remote storage
             pkg_size = push_package(pkg_uri, pkg_file)
             logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes")
@@ -349,12 +349,13 @@ def ensure_runtime_env_setup(pkg_uris: List[str]) -> None:
    Args:
        pkg_uris (list[str]): Packages of the working dir for the runtime env.
    """
+    assert _internal_kv_initialized()
    for pkg_uri in pkg_uris:
        pkg_file = Path(_get_local_path(pkg_uri))
        # For each node, the package will only be downloaded one time
        # Locking to avoid multiple processes downloading concurrently
-        lock = FileLock(str(pkg_file) + ".lock")
-        with lock:
+        with FileLock(str(pkg_file) + ".lock"):
            # TODO(yic): checksum calculation is required
            if pkg_file.exists():
                logger.debug(
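To make the renamed helpers concrete, here is a sketch of the naming and fetch flow they implement, with hypothetical paths and assuming the internal KV store is initialized:

    import ray._private.runtime_env as runtime_env

    # py_modules are plain directory paths now, so nothing has to be
    # imported just to hash it.
    pkg_name = runtime_env.get_project_package_name(
        working_dir="./proj", py_modules=["/path/to/my_mod"])
    pkg_uri = runtime_env.Protocol.GCS.value + "://" + pkg_name

    # With the new default argument the local cache path is derived from
    # the uri, and an already-present package short-circuits to 0 bytes.
    num_bytes = runtime_env.fetch_package(pkg_uri)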
@@ -69,14 +69,14 @@ def load_package(config_path: str) -> "_RuntimePackage":
     # Autofill working directory by uploading to GCS storage.
     if "working_dir" not in runtime_env:
         pkg_name = runtime_support.get_project_package_name(
-            working_dir=base_dir, modules=[])
+            working_dir=base_dir, py_modules=[])
         pkg_uri = runtime_support.Protocol.GCS.value + "://" + pkg_name
 
         def do_register_package():
             if not runtime_support.package_exists(pkg_uri):
                 tmp_path = os.path.join(_pkg_tmp(), "_tmp{}".format(pkg_name))
                 runtime_support.create_project_package(
-                    working_dir=base_dir, modules=[], output_path=tmp_path)
+                    working_dir=base_dir, py_modules=[], output_path=tmp_path)
                 # TODO(ekl) does this get garbage collected correctly with the
                 # current job id?
                 runtime_support.push_package(pkg_uri, tmp_path)
@@ -15,6 +15,7 @@ class JobConfig:
         `CLASSPATH` in Java and `PYTHONPATH` in Python.
     runtime_env (dict): A runtime environment dictionary (see
         ``runtime_env.py`` for detailed documentation).
+    client_job (bool): A boolean representing the source of the job.
     """
 
     def __init__(self,
@@ -22,7 +23,8 @@ class JobConfig:
                  num_java_workers_per_process=1,
                  jvm_options=None,
                  code_search_path=None,
-                 runtime_env=None):
+                 runtime_env=None,
+                 client_job=False):
         if worker_env is None:
             self.worker_env = dict()
         else:
@@ -45,8 +47,15 @@ class JobConfig:
                 f"The type of code search path is incorrect: " \
                 f"{type(code_search_path)}"
         self.runtime_env = runtime_env or dict()
+        self.client_job = client_job
 
     def serialize(self):
         """Serialize the struct into protobuf string"""
+        job_config = self.get_proto_job_config()
+        return job_config.SerializeToString()
+
+    def get_proto_job_config(self):
+        """Return the protobuf structure of JobConfig"""
         job_config = ray.gcs_utils.JobConfig()
         for key in self.worker_env:
             job_config.worker_env[key] = self.worker_env[key]
@@ -55,9 +64,10 @@ class JobConfig:
         job_config.jvm_options.extend(self.jvm_options)
         job_config.code_search_path.extend(self.code_search_path)
         job_config.runtime_env.CopyFrom(self._get_proto_runtime())
-        return job_config.SerializeToString()
+        return job_config
 
     def get_runtime_env_uris(self):
         """Get the uris of runtime environment"""
+        if self.runtime_env.get("working_dir_uri"):
+            return [self.runtime_env.get("working_dir_uri")]
+        return []
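A small sketch of how the reworked JobConfig methods fit together (the uri value is illustrative):

    from ray.job_config import JobConfig

    job_config = JobConfig(
        runtime_env={"working_dir_uri": "gcs://_ray_pkg_abc.zip"},
        client_job=True)

    # get_proto_job_config() now returns the proto message itself, and
    # serialize() wraps it so it still returns bytes.
    proto = job_config.get_proto_job_config()
    blob = job_config.serialize()

    # Only working_dir_uri is reported for now.
    assert job_config.get_runtime_env_uris() == ["gcs://_ray_pkg_abc.zip"]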
@@ -8,7 +8,7 @@ import sys
 import time
 import socket
 import math
 
+from typing import Dict
 from contextlib import redirect_stdout, redirect_stderr
 import yaml
@@ -178,11 +178,12 @@ def kill_process_by_name(name, SIGKILL=False):
             p.terminate()
 
 
-def run_string_as_driver(driver_script):
+def run_string_as_driver(driver_script: str, env: Dict = None):
     """Run a driver as a separate process.
 
     Args:
-        driver_script: A string to run as a Python script.
+        driver_script (str): A string to run as a Python script.
+        env (dict): The environment variables for the driver.
 
     Returns:
         The script's output.
@@ -191,7 +192,9 @@ def run_string_as_driver(driver_script):
         [sys.executable, "-"],
         stdin=subprocess.PIPE,
         stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT)
+        stderr=subprocess.STDOUT,
+        env=env,
+    )
     with proc:
         output = proc.communicate(driver_script.encode("ascii"))[0]
         if proc.returncode:
@@ -489,7 +492,7 @@ def new_scheduler_enabled():
 
 
 def client_test_enabled() -> bool:
-    return os.environ.get("RAY_CLIENT_MODE") == "1"
+    return os.environ.get("RAY_CLIENT_MODE") is not None
 
 
 def fetch_prometheus(prom_addresses):
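The new env parameter is handed straight to subprocess.Popen, which replaces the child's entire environment, so callers are expected to extend a copy of os.environ rather than pass a bare dict; for example (a sketch):

    import os
    from ray.test_utils import run_string_as_driver

    out = run_string_as_driver(
        "import os; print(os.environ['USE_RAY_CLIENT'])",
        env=dict(os.environ, USE_RAY_CLIENT="1"))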
@@ -54,8 +54,8 @@ def init_and_serve_lazy():
     cluster.add_node(num_cpus=1, num_gpus=0)
     address = cluster.address
 
-    def connect():
-        ray.init(address=address)
+    def connect(job_config=None):
+        ray.init(address=address, job_config=job_config)
 
     server_handle = ray_client_server.serve("localhost:50051", connect)
     yield server_handle
@@ -5,9 +5,9 @@ import unittest
 
 import subprocess
 
+import tempfile
 from unittest import mock
+from pathlib import Path
 
 import ray
 from ray.test_utils import run_string_as_driver
 from ray._private.utils import get_conda_env_dir, get_conda_bin_executable
@@ -19,14 +19,19 @@ import logging
 sys.path.insert(0, "{working_dir}")
 import test_module
 import ray
+import ray.util
+import os
 
 job_config = ray.job_config.JobConfig(
     runtime_env={runtime_env}
 )
 
-ray.init(address="{redis_address}",
-         job_config=job_config,
-         logging_level=logging.DEBUG)
+if os.environ.get("USE_RAY_CLIENT"):
+    ray.util.connect("{address}", job_config=job_config)
+else:
+    ray.init(address="{address}",
+             job_config=job_config,
+             logging_level=logging.DEBUG)
 
 @ray.remote
 def run_test():
@@ -40,15 +45,17 @@ class TestActor(object):
 
 {execute_statement}
 
-ray.shutdown()
+if os.environ.get("USE_RAY_CLIENT"):
+    ray.util.disconnect()
+else:
+    ray.shutdown()
 from time import sleep
-sleep(5)
+sleep(10)
 """
 
 
 @pytest.fixture(scope="session")
 def working_dir():
-    import tempfile
     with tempfile.TemporaryDirectory() as tmp_dir:
         path = Path(tmp_dir)
         module_path = path / "test_module"
@@ -68,50 +75,59 @@ from test_module.test import one
     yield tmp_dir
 
 
+def start_client_server(cluster, client_mode):
+    from ray._private.runtime_env import PKG_DIR
+    if not client_mode:
+        return (cluster.address, None, PKG_DIR)
+    ray.worker._global_node._ray_params.ray_client_server_port = "10003"
+    ray.worker._global_node.start_ray_client_server()
+    return ("localhost:10003", {"USE_RAY_CLIENT": "1"}, PKG_DIR)
+
+
 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
-def test_single_node(ray_start_cluster_head, working_dir):
+@pytest.mark.parametrize("client_mode", [True, False])
+def test_single_node(ray_start_cluster_head, working_dir, client_mode):
     cluster = ray_start_cluster_head
-    redis_address = cluster.address
+    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
     runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
     execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
     script = driver_script.format(**locals())
-    out = run_string_as_driver(script)
+    out = run_string_as_driver(script, env)
     assert out.strip().split()[-1] == "1000"
-    from ray._private.runtime_env import PKG_DIR
     assert len(list(Path(PKG_DIR).iterdir())) == 1
 
 
 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
-def test_two_node(two_node_cluster, working_dir):
+@pytest.mark.parametrize("client_mode", [True, False])
+def test_two_node(two_node_cluster, working_dir, client_mode):
     cluster, _ = two_node_cluster
-    redis_address = cluster.address
+    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
     runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
     execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
     script = driver_script.format(**locals())
-    out = run_string_as_driver(script)
+    out = run_string_as_driver(script, env)
     assert out.strip().split()[-1] == "1000"
-    from ray._private.runtime_env import PKG_DIR
     assert len(list(Path(PKG_DIR).iterdir())) == 1
 
 
 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
-def test_two_node_module(two_node_cluster, working_dir):
+@pytest.mark.parametrize("client_mode", [True, False])
+def test_two_node_module(two_node_cluster, working_dir, client_mode):
     cluster, _ = two_node_cluster
-    redis_address = cluster.address
-    runtime_env = """{ "local_modules": [test_module] }"""
+    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
+    runtime_env = """{ "py_modules": [test_module.__path__[0]] }"""
     execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
     script = driver_script.format(**locals())
     print(script)
-    out = run_string_as_driver(script)
+    out = run_string_as_driver(script, env)
     assert out.strip().split()[-1] == "1000"
-    from ray._private.runtime_env import PKG_DIR
     assert len(list(Path(PKG_DIR).iterdir())) == 1
 
 
 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
-def test_two_node_uri(two_node_cluster, working_dir):
+@pytest.mark.parametrize("client_mode", [True, False])
+def test_two_node_uri(two_node_cluster, working_dir, client_mode):
     cluster, _ = two_node_cluster
-    redis_address = cluster.address
+    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
     import ray._private.runtime_env as runtime_env
     import tempfile
     with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file:
@@ -122,41 +138,40 @@ def test_two_node_uri(two_node_cluster, working_dir):
     runtime_env = f"""{{ "working_dir_uri": "{pkg_uri}" }}"""
     execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
     script = driver_script.format(**locals())
-    out = run_string_as_driver(script)
+    out = run_string_as_driver(script, env)
     assert out.strip().split()[-1] == "1000"
-    from ray._private.runtime_env import PKG_DIR
     assert len(list(Path(PKG_DIR).iterdir())) == 1
 
 
 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
-def test_regular_actors(ray_start_cluster_head, working_dir):
+@pytest.mark.parametrize("client_mode", [True, False])
+def test_regular_actors(ray_start_cluster_head, working_dir, client_mode):
     cluster = ray_start_cluster_head
-    redis_address = cluster.address
+    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
     runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
     execute_statement = """
test_actor = TestActor.options(name="test_actor").remote()
print(sum(ray.get([test_actor.one.remote()] * 1000)))
"""
     script = driver_script.format(**locals())
-    out = run_string_as_driver(script)
+    out = run_string_as_driver(script, env)
     assert out.strip().split()[-1] == "1000"
-    from ray._private.runtime_env import PKG_DIR
     assert len(list(Path(PKG_DIR).iterdir())) == 1
 
 
 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
-def test_detached_actors(ray_start_cluster_head, working_dir):
+@pytest.mark.parametrize("client_mode", [True, False])
+def test_detached_actors(ray_start_cluster_head, working_dir, client_mode):
     cluster = ray_start_cluster_head
-    redis_address = cluster.address
+    (address, env, PKG_DIR) = start_client_server(cluster, client_mode)
     runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
     execute_statement = """
test_actor = TestActor.options(name="test_actor", lifetime="detached").remote()
print(sum(ray.get([test_actor.one.remote()] * 1000)))
"""
     script = driver_script.format(**locals())
-    out = run_string_as_driver(script)
+    out = run_string_as_driver(script, env)
     assert out.strip().split()[-1] == "1000"
-    from ray._private.runtime_env import PKG_DIR
+    # It's a detached actor, so it should still be there
     assert len(list(Path(PKG_DIR).iterdir())) == 2
     pkg = list(Path(PKG_DIR).glob("*.zip"))[0]
@@ -1,5 +1,5 @@
 from typing import List, Tuple, Dict, Any
-
+from ray.job_config import JobConfig
 import os
 import sys
 import logging
@@ -29,6 +29,7 @@ class RayAPIStub:
 
     def connect(self,
                 conn_str: str,
+                job_config: JobConfig = None,
                 secure: bool = False,
                 metadata: List[Tuple[str, str]] = None,
                 connection_retries: int = 3,
@@ -38,6 +39,7 @@ class RayAPIStub:
 
         Args:
             conn_str: Connection string, in the form "[host]:port"
+            job_config: The job config of the server.
            secure: Whether to use a TLS secured gRPC channel
            metadata: gRPC metadata to send on connect
            connection_retries: number of connection attempts to make
@@ -67,6 +69,7 @@ class RayAPIStub:
             metadata=metadata,
             connection_retries=connection_retries)
         self.api.worker = self.client_worker
+        self.client_worker._server_init(job_config)
         conn_info = self.client_worker.connection_info()
         self._check_versions(conn_info, ignore_version)
         return conn_info
@@ -247,6 +247,10 @@ class ClientAPI:
         """Hook for internal_kv._internal_kv_initialized."""
         return self.is_initialized()
 
+    def _internal_kv_exists(self, key: bytes) -> bool:
+        """Hook for internal_kv._internal_kv_exists."""
+        return self.worker.internal_kv_exists(as_bytes(key))
+
     def _internal_kv_get(self, key: bytes) -> bytes:
         """Hook for internal_kv._internal_kv_get."""
         return self.worker.internal_kv_get(as_bytes(key))
@@ -117,6 +117,19 @@ class DataClient:
         req.req_id = req_id
         self.request_queue.put(req)
 
+    def Init(self, request: ray_client_pb2.InitRequest,
+             context=None) -> ray_client_pb2.InitResponse:
+        datareq = ray_client_pb2.DataRequest(init=request, )
+        resp = self._blocking_send(datareq)
+        return resp.init
+
+    def PrepRuntimeEnv(self,
+                       request: ray_client_pb2.PrepRuntimeEnvRequest,
+                       context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
+        datareq = ray_client_pb2.DataRequest(prep_runtime_env=request, )
+        resp = self._blocking_send(datareq)
+        return resp.prep_runtime_env
+
     def ConnectionInfo(self,
                        context=None) -> ray_client_pb2.ConnectionInfoResponse:
         datareq = ray_client_pb2.DataRequest(
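Both new calls ride the existing bidirectional data stream rather than opening separate RPCs. The ordering contract added in this commit is that init must be the first request on the stream; a sketch of the client-side sequence (data_client and job_config are assumed to exist, mirroring Worker._server_init later in this diff):

    import pickle
    import ray.core.generated.ray_client_pb2 as ray_client_pb2

    # The first request on the data stream must be Init; the server
    # rejects other request types until the connection is accepted.
    data_client.Init(
        ray_client_pb2.InitRequest(job_config=pickle.dumps(job_config)))

    # Once the package is uploaded, ask the server to set up the env.
    data_client.PrepRuntimeEnv(ray_client_pb2.PrepRuntimeEnvRequest())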
@@ -29,7 +29,7 @@ def ray_start_client_server_pair():
 def ray_start_cluster_client_server_pair(address):
     ray._inside_client_test = True
 
-    def ray_connect_handler():
+    def ray_connect_handler(job_config=None):
         real_ray.init(address=address)
 
     server = ray_client_server.serve(
@@ -3,7 +3,7 @@ import logging
 import grpc
 import sys
 
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING
 from threading import Lock
 
 import ray.core.generated.ray_client_pb2 as ray_client_pb2
@@ -20,12 +20,10 @@ logger = logging.getLogger(__name__)
 
 
 class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
-    def __init__(self, basic_service: "RayletServicer",
-                 ray_connect_handler: Callable):
+    def __init__(self, basic_service: "RayletServicer"):
         self.basic_service = basic_service
         self.clients_lock = Lock()
         self.num_clients = 0  # guarded by self.clients_lock
-        self.ray_connect_handler = ray_connect_handler
 
     def Datapath(self, request_iterator, context):
         metadata = {k: v for k, v in context.invocation_metadata()}
@@ -36,55 +34,47 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
             return
         logger.debug(f"New data connection from client {client_id}: ")
         try:
-            with self.clients_lock:
-                with disable_client_hook():
-                    # It's important to keep the ray initialization call
-                    # within this locked context or else Ray could hang.
-                    if self.num_clients == 0 and not ray.is_initialized():
-                        self.ray_connect_handler()
-                threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
-                if self.num_clients >= threshold:
-                    context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
-                    logger.warning(
-                        f"[Data Servicer]: Num clients {self.num_clients} "
-                        f"has reached the threshold {threshold}. "
-                        f"Rejecting client: {metadata['client_id']}. ")
-                    if log_once("client_threshold"):
-                        logger.warning(
-                            "You can configure the client connection "
-                            "threshold by setting the "
-                            "RAY_CLIENT_SERVER_MAX_THREADS env var "
-                            f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
-                    return
-
-                self.num_clients += 1
-                logger.debug(f"Accepted data connection from {client_id}. "
-                             f"Total clients: {self.num_clients}")
-            accepted_connection = True
             for req in request_iterator:
                 resp = None
                 req_type = req.WhichOneof("type")
-                if req_type == "get":
-                    get_resp = self.basic_service._get_object(
-                        req.get, client_id)
-                    resp = ray_client_pb2.DataResponse(get=get_resp)
-                elif req_type == "put":
-                    put_resp = self.basic_service._put_object(
-                        req.put, client_id)
-                    resp = ray_client_pb2.DataResponse(put=put_resp)
-                elif req_type == "release":
-                    released = []
-                    for rel_id in req.release.ids:
-                        rel = self.basic_service.release(client_id, rel_id)
-                        released.append(rel)
-                    resp = ray_client_pb2.DataResponse(
-                        release=ray_client_pb2.ReleaseResponse(ok=released))
-                elif req_type == "connection_info":
-                    resp = ray_client_pb2.DataResponse(
-                        connection_info=self._build_connection_response())
+                if req_type == "init":
+                    resp = self._init(req.init, client_id)
+                    if resp is None:
+                        context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
+                        return
+                    logger.debug(f"Accepted data connection from {client_id}. "
+                                 f"Total clients: {self.num_clients}")
+                    accepted_connection = True
                 else:
-                    raise Exception(f"Unreachable code: Request type "
-                                    f"{req_type} not handled in Datapath")
+                    assert accepted_connection
+                    if req_type == "get":
+                        get_resp = self.basic_service._get_object(
+                            req.get, client_id)
+                        resp = ray_client_pb2.DataResponse(get=get_resp)
+                    elif req_type == "put":
+                        put_resp = self.basic_service._put_object(
+                            req.put, client_id)
+                        resp = ray_client_pb2.DataResponse(put=put_resp)
+                    elif req_type == "release":
+                        released = []
+                        for rel_id in req.release.ids:
+                            rel = self.basic_service.release(client_id, rel_id)
+                            released.append(rel)
+                        resp = ray_client_pb2.DataResponse(
+                            release=ray_client_pb2.ReleaseResponse(
+                                ok=released))
+                    elif req_type == "connection_info":
+                        resp = ray_client_pb2.DataResponse(
+                            connection_info=self._build_connection_response())
+                    elif req_type == "prep_runtime_env":
+                        with self.clients_lock:
+                            resp_prep = self.basic_service.PrepRuntimeEnv(
+                                req.prep_runtime_env)
+                        resp = ray_client_pb2.DataResponse(
+                            prep_runtime_env=resp_prep)
+                    else:
+                        raise Exception(f"Unreachable code: Request type "
+                                        f"{req_type} not handled in Datapath")
                 resp.req_id = req.req_id
                 yield resp
         except grpc.RpcError as e:
@@ -106,6 +96,25 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
                 logger.debug("Shutting down ray.")
                 ray.shutdown()
 
+    def _init(self, req_init, client_id):
+        with self.clients_lock:
+            threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
+            if self.num_clients >= threshold:
+                logger.warning(
+                    f"[Data Servicer]: Num clients {self.num_clients} "
+                    f"has reached the threshold {threshold}. "
+                    f"Rejecting client: {client_id}. ")
+                if log_once("client_threshold"):
+                    logger.warning(
+                        "You can configure the client connection "
+                        "threshold by setting the "
+                        "RAY_CLIENT_SERVER_MAX_THREADS env var "
+                        f"(currently set to {CLIENT_SERVER_MAX_THREADS}).")
+                return None
+            resp_init = self.basic_service.Init(req_init)
+            self.num_clients += 1
+        return ray_client_pb2.DataResponse(init=resp_init, )
+
     def _build_connection_response(self):
         with self.clients_lock:
             cur_num_clients = self.num_clients
@@ -4,14 +4,16 @@ import grpc
 import base64
 from collections import defaultdict
 from dataclasses import dataclass
 
 import sys
 import threading
 from typing import Any
 from typing import List
 from typing import Dict
 from typing import Set
 from typing import Optional
 
+from typing import Callable
 from ray import cloudpickle
+from ray.job_config import JobConfig
 import ray
 import ray.state
 import ray.core.generated.ray_client_pb2 as ray_client_pb2
@@ -33,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
-    def __init__(self):
+    def __init__(self, ray_connect_handler: Callable):
+        """Construct a raylet service
+
+        Args:
+           ray_connect_handler (Callable): Function to connect to ray cluster
+        """
         # Stores client_id -> (ref_id -> ObjectRef)
         self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
             dict)
@@ -46,6 +53,42 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
         self.registered_actor_classes = {}
         self.named_actors = set()
         self.state_lock = threading.Lock()
+        self.ray_connect_handler = ray_connect_handler
+
+    def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
+        import pickle
+        if request.job_config:
+            job_config = pickle.loads(request.job_config)
+            job_config.client_job = True
+        else:
+            job_config = None
+        current_job_config = None
+        with disable_client_hook():
+            if ray.is_initialized():
+                worker = ray.worker.global_worker
+                current_job_config = worker.core_worker.get_job_config()
+            else:
+                self.ray_connect_handler(job_config)
+        if job_config is None:
+            return ray_client_pb2.InitResponse()
+        job_config = job_config.get_proto_job_config()
+        # If the server has been initialized, we need to compare whether the
+        # runtime env is compatible.
+        if current_job_config and set(job_config.runtime_env.uris) != set(
+                current_job_config.runtime_env.uris):
+            raise grpc.RpcError(
+                "Runtime environment doesn't match "
+                f"request one {job_config.runtime_env.uris} "
+                f"current one {current_job_config.runtime_env.uris}")
+        return ray_client_pb2.InitResponse()
+
+    def PrepRuntimeEnv(self, request,
+                       context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
+        job_config = ray.worker.global_worker.core_worker.get_job_config()
+        missing_uris = self._prepare_runtime_env(job_config.runtime_env)
+        if len(missing_uris) != 0:
+            raise grpc.RpcError(f"Missing uris: {missing_uris}")
+        return ray_client_pb2.PrepRuntimeEnvResponse()
+
     def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
         with disable_client_hook():
@@ -69,6 +112,13 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
                                                               request.prefix)
         return ray_client_pb2.KVListResponse(keys=keys)
 
+    def KVExists(self, request,
+                 context=None) -> ray_client_pb2.KVExistsResponse:
+        with disable_client_hook():
+            exists = ray.experimental.internal_kv._internal_kv_exists(
+                request.key)
+        return ray_client_pb2.KVExistsResponse(exists=exists)
+
     def ClusterInfo(self, request,
                     context=None) -> ray_client_pb2.ClusterInfoResponse:
         resp = ray_client_pb2.ClusterInfoResponse()
@@ -371,6 +421,21 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
                 kwargout[k] = convert_from_arg(kwarg_map[k], self)
         return argout, kwargout
 
+    def _prepare_runtime_env(self, job_runtime_env) -> List[str]:
+        """Download runtime environment to local node"""
+        missing_uris = []
+        uris = job_runtime_env.uris
+        from ray._private import runtime_env
+        with disable_client_hook():
+            for uri in uris:
+                try:
+                    runtime_env.fetch_package(uri)
+                    print("Adding!: ", runtime_env._get_local_path(uri))
+                    sys.path.insert(0, str(runtime_env._get_local_path(uri)))
+                except IOError:
+                    missing_uris.append(uri)
+        return missing_uris
+
     def lookup_or_register_func(
             self, id: bytes, client_id: str,
             options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
@@ -453,10 +518,10 @@ class ClientServerHandle:
 
 
 def serve(connection_str, ray_connect_handler=None):
-    def default_connect_handler():
+    def default_connect_handler(job_config: JobConfig = None):
        with disable_client_hook():
            if not ray.is_initialized():
-                return ray.init()
+                return ray.init(job_config=job_config)
 
     ray_connect_handler = ray_connect_handler or default_connect_handler
     server = grpc.server(
@@ -465,9 +530,8 @@ def serve(connection_str, ray_connect_handler=None):
         ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_SIZE),
         ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_SIZE),
     ])
-    task_servicer = RayletServicer()
-    data_servicer = DataServicer(
-        task_servicer, ray_connect_handler=ray_connect_handler)
+    task_servicer = RayletServicer(ray_connect_handler)
+    data_servicer = DataServicer(task_servicer)
     logs_servicer = LogstreamServicer()
     ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
         task_servicer, server)
@@ -491,13 +555,13 @@ def init_and_serve(connection_str, *args, **kwargs):
     # Disable client mode inside the worker's environment
     info = ray.init(*args, **kwargs)
 
-    def ray_connect_handler():
+    def ray_connect_handler(job_config=None):
         # Ray client will disconnect from ray when
         # num_clients == 0.
         if ray.is_initialized():
             return info
         else:
-            return ray.init(*args, **kwargs)
+            return ray.init(job_config=job_config, *args, **kwargs)
 
     server_handle = serve(
         connection_str, ray_connect_handler=ray_connect_handler)
@@ -511,14 +575,17 @@ def shutdown_with_server(server, _exiting_interpreter=False):
 
 
 def create_ray_handler(redis_address, redis_password):
-    def ray_connect_handler():
+    def ray_connect_handler(job_config: JobConfig = None):
         if redis_address:
             if redis_password:
-                ray.init(address=redis_address, _redis_password=redis_password)
+                ray.init(
+                    address=redis_address,
+                    _redis_password=redis_password,
+                    job_config=job_config)
             else:
-                ray.init(address=redis_address)
+                ray.init(address=redis_address, job_config=job_config)
         else:
-            ray.init()
+            ray.init(job_config=job_config)
 
     return ray_connect_handler
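Since ray.init is now deferred to the first Init request, any custom handler passed to serve() has to accept the client's job config; a sketch using the serve() defined above (the address value is illustrative):

    import ray

    def my_connect_handler(job_config=None):
        # Invoked on the first client Init, with the client's JobConfig.
        return ray.init(address="auto", job_config=job_config)

    server_handle = serve(
        "localhost:50051", ray_connect_handler=my_connect_handler)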
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
 
 import grpc
 
+from ray.job_config import JobConfig
 import ray.cloudpickle as cloudpickle
 # Use cloudpickle's version of pickle for UnpicklingError
 from ray.cloudpickle.compat import pickle
@@ -145,6 +146,7 @@ class Worker:
 
         self.log_client = LogstreamClient(self.channel, self.metadata)
         self.log_client.set_logstream_level(logging.INFO)
+
         self.closed = False
 
     def _on_channel_state_change(self, conn_state: grpc.ChannelConnectivity):
@@ -394,6 +396,11 @@ class Worker:
         resp = self.server.KVGet(req, metadata=self.metadata)
         return resp.value
 
+    def internal_kv_exists(self, key: bytes) -> bytes:
+        req = ray_client_pb2.KVGetRequest(key=key)
+        resp = self.server.KVGet(req, metadata=self.metadata)
+        return resp.value
+
     def internal_kv_put(self, key: bytes, value: bytes,
                         overwrite: bool) -> bool:
         req = ray_client_pb2.KVPutRequest(
@@ -431,6 +438,30 @@ class Worker:
     def is_connected(self) -> bool:
         return self._conn_state == grpc.ChannelConnectivity.READY
 
+    def _server_init(self, job_config: JobConfig):
+        """Initialize the server"""
+        try:
+            if job_config is None:
+                init_req = ray_client_pb2.InitRequest()
+                self.data_client.Init(init_req)
+                return
+
+            import ray._private.runtime_env as runtime_env
+            import tempfile
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                if runtime_env.PKG_DIR is None:
+                    runtime_env.PKG_DIR = tmp_dir
+                # Generate the uri for runtime env
+                runtime_env.rewrite_working_dir_uri(job_config)
+                init_req = ray_client_pb2.InitRequest(
+                    job_config=pickle.dumps(job_config))
+                self.data_client.Init(init_req)
+                runtime_env.upload_runtime_env_package_if_needed(job_config)
+                prep_req = ray_client_pb2.PrepRuntimeEnvRequest()
+                self.data_client.PrepRuntimeEnv(prep_req)
+        except grpc.RpcError as e:
+            raise decode_exception(e.details())
+
     def _convert_actor(self, actor: "ActorClass") -> str:
         """Register a ClientActorClass for the ActorClass and return a UUID"""
         key = uuid.uuid4().hex
@@ -1,5 +1,5 @@
 from ray.util.client import ray
-
+from ray.job_config import JobConfig
 from ray._private.client_mode_hook import _enable_client_hook
 from ray._private.client_mode_hook import _explicitly_enable_client_mode
@@ -10,6 +10,7 @@ def connect(conn_str: str,
             secure: bool = False,
             metadata: List[Tuple[str, str]] = None,
             connection_retries: int = 3,
+            job_config: JobConfig = None,
             *,
             ignore_version: bool = False) -> Dict[str, Any]:
     if ray.is_connected():
@@ -26,6 +27,7 @@ def connect(conn_str: str,
     # the correct metadata
     return ray.connect(
         conn_str,
+        job_config=job_config,
         secure=secure,
         metadata=metadata,
         connection_retries=3,
@@ -1219,15 +1219,16 @@ def connect(node,
         node.node_manager_port, node.raylet_ip_address, (mode == LOCAL_MODE),
         driver_name, log_stdout_file_path, log_stderr_file_path,
         serialized_job_config, node.metrics_agent_port)
-    # Notify raylet that the core worker is ready.
-    worker.core_worker.notify_raylet()
 
     # Create an object for interfacing with the global state.
     # Note, global state should be initialized after `CoreWorker`, because it
     # will use glog, which is initialized in `CoreWorker`.
     ray.state.state._initialize_global_state(
         node.redis_address, redis_password=node.redis_password)
-    if mode == SCRIPT_MODE:
+    # If it's a driver and it's not coming from ray client, we'll prepare the
+    # environment here. If it's ray client, the environment will be prepared
+    # at the server side.
+    if mode == SCRIPT_MODE and not job_config.client_job:
         runtime_env.upload_runtime_env_package_if_needed(job_config)
     elif mode == WORKER_MODE:
         # TODO(ekl) get rid of the env var hack and get runtime env from the
@@ -1237,6 +1238,9 @@ def connect(node,
             worker.core_worker.get_job_config().runtime_env.uris
         runtime_env.ensure_runtime_env_setup(job_config)
 
+    # Notify raylet that the core worker is ready.
+    worker.core_worker.notify_raylet()
+
     if driver_object_store_memory is not None:
         worker.core_worker.set_object_store_client_options(
             f"ray_driver_{os.getpid()}", driver_object_store_memory)
@@ -162,11 +162,10 @@ cc_proto_library(
 proto_library(
     name = "ray_client_proto",
     srcs = ["ray_client.proto"],
-    deps = [],
+    deps = [":common_proto"],
 )
 
 python_grpc_compile(
     name = "ray_client_py_proto",
-    deps = [":ray_client_proto"]
+    deps = [":ray_client_proto"],
 )
@@ -196,6 +196,14 @@ message TerminateResponse {
   bool ok = 1;
 }
 
+message KVExistsRequest {
+  bytes key = 1;
+}
+
+message KVExistsResponse {
+  bool exists = 1;
+}
+
 message KVGetRequest {
   bytes key = 1;
 }
@@ -229,7 +237,25 @@ message KVListResponse {
   repeated bytes keys = 1;
 }
 
+message InitRequest {
+  // job_config of ray.init
+  bytes job_config = 1;
+}
+
+message InitResponse {
+}
+
+message PrepRuntimeEnvRequest {
+}
+
+message PrepRuntimeEnvResponse {
+}
+
 service RayletDriver {
+  rpc Init(InitRequest) returns (InitResponse) {
+  }
+  rpc PrepRuntimeEnv(PrepRuntimeEnvRequest) returns (PrepRuntimeEnvResponse) {
+  }
   rpc GetObject(GetRequest) returns (GetResponse) {
   }
   rpc PutObject(PutRequest) returns (PutResponse) {
@@ -250,6 +276,8 @@ service RayletDriver {
   }
   rpc KVList(KVListRequest) returns (KVListResponse) {
   }
+  rpc KVExists(KVExistsRequest) returns (KVExistsResponse) {
+  }
 }
 
 message ReleaseRequest {
@@ -288,6 +316,8 @@ message DataRequest {
     PutRequest put = 3;
     ReleaseRequest release = 4;
     ConnectionInfoRequest connection_info = 5;
+    InitRequest init = 6;
+    PrepRuntimeEnvRequest prep_runtime_env = 7;
   }
 }
@@ -299,6 +329,8 @@ message DataResponse {
     PutResponse put = 3;
     ReleaseResponse release = 4;
     ConnectionInfoResponse connection_info = 5;
+    InitResponse init = 6;
+    PrepRuntimeEnvResponse prep_runtime_env = 7;
   }
 }
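Once the protos are regenerated, the new oneof arms can be set like any other field; a quick sketch of what the server's WhichOneof dispatch sees:

    import ray.core.generated.ray_client_pb2 as ray_client_pb2

    req = ray_client_pb2.DataRequest(
        init=ray_client_pb2.InitRequest(job_config=b""))
    assert req.WhichOneof("type") == "init"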