mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Ray Client] [runtime env] Skip env hook in Ray client server (#26688)
Previously, using an env_hook with Ray Client would only execute the env_hook on the server side (a Ray cluster machine). An env_hook defined on the client side would never be executed. But the main problem is with the server-side env_hook. Consider the simple example where the env_hook rewrites the `working_dir` or `py_modules` with a local directory. Currently, when using Ray Client, the `working_dir` and `py_modules` are uploaded to the GCS before `ray.init()` is called on the server. This is a fundamental constraint because the server-side driver script needs to be able to import modules from the `working_dir` or `py_modules`. After the upload, these fields are overwritten with the URIs for the uploaded packages. After this happens, on the server side Ray expects the `working_dir` and `py_modules` fields to only contain GCS URIs. So overwriting `working_dir` to be a local directory after this occurs doesn't make sense (and Ray will rightfully throw a RuntimeEnv validation error here.) If a cluster is set up with such an env hook, it will only work when `ray.init()` is called by the user on a cluster machine; i.e. it will only work in non-Ray Client cases. If a user ever wants to use Ray Client with this cluster, it will be broken with no way to disable the env hook. To remedy this, this PR disables the execution of the env_hook when using Ray Client. We can consider adding support in the future for env_hooks to be executed on the client side when using Ray Client.
This commit is contained in:
parent
3a48a79fd7
commit
1aad5d2136
4 changed files with 43 additions and 13 deletions
|
@ -1180,7 +1180,8 @@ def init(
|
|||
"_tracing_startup_hook", None
|
||||
)
|
||||
_node_name: str = kwargs.pop("_node_name", None)
|
||||
|
||||
# Fix for https://github.com/ray-project/ray/issues/26729
|
||||
_skip_env_hook: bool = kwargs.pop("_skip_env_hook", False)
|
||||
if not logging_format:
|
||||
logging_format = ray_constants.LOGGER_FORMAT
|
||||
|
||||
|
@ -1264,7 +1265,7 @@ def init(
|
|||
job_config_json = json.loads(os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR))
|
||||
job_config = ray.job_config.JobConfig.from_json(job_config_json)
|
||||
|
||||
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ:
|
||||
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
|
||||
runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
|
||||
job_config.runtime_env
|
||||
)
|
||||
|
@ -1273,7 +1274,7 @@ def init(
|
|||
# RAY_JOB_CONFIG_JSON_ENV_VAR is only set at ray job manager level and has
|
||||
# higher priority in case user also provided runtime_env for ray.init()
|
||||
else:
|
||||
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ:
|
||||
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
|
||||
runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
|
||||
runtime_env
|
||||
)
|
||||
|
|
|
@ -49,8 +49,8 @@ def init_and_serve_lazy():
|
|||
cluster.wait_for_nodes(1)
|
||||
address = cluster.address
|
||||
|
||||
def connect(job_config=None):
|
||||
ray.init(address=address, job_config=job_config)
|
||||
def connect(job_config=None, **ray_init_kwargs):
|
||||
ray.init(address=address, job_config=job_config, **ray_init_kwargs)
|
||||
|
||||
server_handle = ray_client_server.serve("localhost:50051", connect)
|
||||
yield server_handle
|
||||
|
|
|
@ -44,11 +44,16 @@ def _hook(env):
|
|||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
def test_runtime_env_hook():
|
||||
script = """
|
||||
@pytest.mark.parametrize("skip_hook", [True, False])
|
||||
def test_runtime_env_hook(skip_hook):
|
||||
ray_init_snippet = "ray.init(_skip_env_hook=True)" if skip_hook else ""
|
||||
|
||||
script = f"""
|
||||
import ray
|
||||
import os
|
||||
|
||||
{ray_init_snippet}
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return os.environ.get("HOOK_KEY")
|
||||
|
@ -61,9 +66,28 @@ print(ray.get(f.remote()))
|
|||
)
|
||||
out_str = proc.stdout.read().decode("ascii") + proc.stderr.read().decode("ascii")
|
||||
print(out_str)
|
||||
if skip_hook:
|
||||
assert "HOOK_VALUE" not in out_str
|
||||
else:
|
||||
assert "HOOK_VALUE" in out_str
|
||||
|
||||
|
||||
def test_env_hook_skipped_for_ray_client(start_cluster, monkeypatch):
|
||||
monkeypatch.setenv("RAY_RUNTIME_ENV_HOOK", "ray.tests.test_output._hook")
|
||||
cluster, address = start_cluster
|
||||
ray.init(address)
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return os.environ.get("HOOK_KEY")
|
||||
|
||||
using_ray_client = address.startswith("ray://")
|
||||
if using_ray_client:
|
||||
assert ray.get(f.remote()) is None
|
||||
else:
|
||||
assert ray.get(f.remote()) == "HOOK_VALUE"
|
||||
|
||||
|
||||
def test_autoscaler_infeasible():
|
||||
script = """
|
||||
import ray
|
||||
|
|
|
@ -77,7 +77,12 @@ class _ClientContext:
|
|||
logging_level = ray_constants.LOGGER_LEVEL
|
||||
logging_format = ray_constants.LOGGER_FORMAT
|
||||
|
||||
if ray_init_kwargs is not None:
|
||||
if ray_init_kwargs is None:
|
||||
ray_init_kwargs = {}
|
||||
|
||||
# NOTE(architkulkarni): env_hook is not supported with Ray Client.
|
||||
ray_init_kwargs["_skip_env_hook"] = True
|
||||
|
||||
if ray_init_kwargs.get("logging_level") is not None:
|
||||
logging_level = ray_init_kwargs["logging_level"]
|
||||
if ray_init_kwargs.get("logging_format") is not None:
|
||||
|
|
Loading…
Add table
Reference in a new issue