[Ray Client] [runtime env] Skip env hook in Ray client server (#26688)

Previously, using an env_hook with Ray Client would only execute the env_hook on the server side (a Ray cluster machine).  An env_hook defined on the client side would never be executed.  But the main problem is with the server-side env_hook.

Consider the simple example where the env_hook rewrites the `working_dir` or `py_modules` with a local directory.

Currently, when using Ray Client, the `working_dir` and `py_modules` are uploaded to the GCS before `ray.init()` is called on the server.   This is a fundamental constraint because the server-side driver script needs to be able to import modules from the `working_dir` or `py_modules`.  After the upload, these fields are overwritten with the URIs for the uploaded packages.  

After this happens, on the server side Ray expects the `working_dir` and `py_modules` fields to only contain GCS URIs.  So overwriting `working_dir` to be a local directory after this occurs doesn't make sense (and Ray will rightfully throw a RuntimeEnv validation error here.)

If a cluster is set up with such an env hook, it will only work when `ray.init()` is called by the user on a cluster machine; i.e. it will only work in non-Ray Client cases.  If a user ever wants to use Ray Client with this cluster, it will be broken with no way to disable the env hook.  To remedy this, this PR disables the execution of the env_hook when using Ray Client.

We can consider adding support in the future for env_hooks to be executed on the client side when using Ray Client.
This commit is contained in:
Archit Kulkarni 2022-07-21 08:10:11 -07:00 committed by GitHub
parent 3a48a79fd7
commit 1aad5d2136
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 13 deletions

View file

@ -1180,7 +1180,8 @@ def init(
"_tracing_startup_hook", None
)
_node_name: str = kwargs.pop("_node_name", None)
# Fix for https://github.com/ray-project/ray/issues/26729
_skip_env_hook: bool = kwargs.pop("_skip_env_hook", False)
if not logging_format:
logging_format = ray_constants.LOGGER_FORMAT
@ -1264,7 +1265,7 @@ def init(
job_config_json = json.loads(os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR))
job_config = ray.job_config.JobConfig.from_json(job_config_json)
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ:
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
job_config.runtime_env
)
@ -1273,7 +1274,7 @@ def init(
# RAY_JOB_CONFIG_JSON_ENV_VAR is only set at ray job manager level and has
# higher priority in case user also provided runtime_env for ray.init()
else:
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ:
if ray_constants.RAY_RUNTIME_ENV_HOOK in os.environ and not _skip_env_hook:
runtime_env = _load_class(os.environ[ray_constants.RAY_RUNTIME_ENV_HOOK])(
runtime_env
)

View file

@ -49,8 +49,8 @@ def init_and_serve_lazy():
cluster.wait_for_nodes(1)
address = cluster.address
def connect(job_config=None):
ray.init(address=address, job_config=job_config)
def connect(job_config=None, **ray_init_kwargs):
ray.init(address=address, job_config=job_config, **ray_init_kwargs)
server_handle = ray_client_server.serve("localhost:50051", connect)
yield server_handle

View file

@ -44,11 +44,16 @@ def _hook(env):
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_runtime_env_hook():
script = """
@pytest.mark.parametrize("skip_hook", [True, False])
def test_runtime_env_hook(skip_hook):
ray_init_snippet = "ray.init(_skip_env_hook=True)" if skip_hook else ""
script = f"""
import ray
import os
{ray_init_snippet}
@ray.remote
def f():
return os.environ.get("HOOK_KEY")
@ -61,9 +66,28 @@ print(ray.get(f.remote()))
)
out_str = proc.stdout.read().decode("ascii") + proc.stderr.read().decode("ascii")
print(out_str)
if skip_hook:
assert "HOOK_VALUE" not in out_str
else:
assert "HOOK_VALUE" in out_str
def test_env_hook_skipped_for_ray_client(start_cluster, monkeypatch):
monkeypatch.setenv("RAY_RUNTIME_ENV_HOOK", "ray.tests.test_output._hook")
cluster, address = start_cluster
ray.init(address)
@ray.remote
def f():
return os.environ.get("HOOK_KEY")
using_ray_client = address.startswith("ray://")
if using_ray_client:
assert ray.get(f.remote()) is None
else:
assert ray.get(f.remote()) == "HOOK_VALUE"
def test_autoscaler_infeasible():
script = """
import ray

View file

@ -77,7 +77,12 @@ class _ClientContext:
logging_level = ray_constants.LOGGER_LEVEL
logging_format = ray_constants.LOGGER_FORMAT
if ray_init_kwargs is not None:
if ray_init_kwargs is None:
ray_init_kwargs = {}
# NOTE(architkulkarni): env_hook is not supported with Ray Client.
ray_init_kwargs["_skip_env_hook"] = True
if ray_init_kwargs.get("logging_level") is not None:
logging_level = ray_init_kwargs["logging_level"]
if ray_init_kwargs.get("logging_format") is not None: