diff --git a/dashboard/modules/runtime_env/runtime_env_agent.py b/dashboard/modules/runtime_env/runtime_env_agent.py index 76514801b..b7e2a5072 100644 --- a/dashboard/modules/runtime_env/runtime_env_agent.py +++ b/dashboard/modules/runtime_env/runtime_env_agent.py @@ -77,7 +77,9 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule, # Use a separate logger for each job. per_job_logger = self.get_or_create_logger(request.job_id) - context = RuntimeEnvContext(self._runtime_env_dir) + context = RuntimeEnvContext( + env_vars=runtime_env.get("env_vars"), + resources_dir=self._runtime_env_dir) setup_conda_or_pip(runtime_env, context, logger=per_job_logger) setup_working_dir(runtime_env, context, logger=per_job_logger) return context diff --git a/python/ray/_private/runtime_env/conda.py b/python/ray/_private/runtime_env/conda.py index 454f145f7..154db3f9a 100644 --- a/python/ray/_private/runtime_env/conda.py +++ b/python/ray/_private/runtime_env/conda.py @@ -13,7 +13,8 @@ from pathlib import Path import ray from ray._private.runtime_env import RuntimeEnvContext -from ray._private.runtime_env.conda_utils import get_or_create_conda_env +from ray._private.runtime_env.conda_utils import (get_conda_activate_commands, + get_or_create_conda_env) from ray._private.utils import try_to_create_directory from ray._private.utils import (get_wheel_filename, get_master_wheel_url, get_release_wheel_url) @@ -75,7 +76,7 @@ def setup_conda_or_pip(runtime_env: dict, return logger.debug(f"Setting up conda or pip for runtime_env: {runtime_env}") - conda_dict = get_conda_dict(runtime_env, context.session_dir) + conda_dict = get_conda_dict(runtime_env, context.resources_dir) if isinstance(runtime_env.get("conda"), str): conda_env_name = runtime_env["conda"] else: @@ -96,9 +97,8 @@ def setup_conda_or_pip(runtime_env: dict, # lock for all conda installs. # See https://github.com/ray-project/ray/issues/17086 file_lock_name = "ray-conda-install.lock" - with FileLock(os.path.join(context.session_dir, file_lock_name)): - conda_dir = os.path.join(context.session_dir, "runtime_resources", - "conda") + with FileLock(os.path.join(context.resources_dir, file_lock_name)): + conda_dir = os.path.join(context.resources_dir, "conda") try_to_create_directory(conda_dir) conda_yaml_path = os.path.join(conda_dir, "environment.yml") with open(conda_yaml_path, "w") as file: @@ -113,7 +113,8 @@ def setup_conda_or_pip(runtime_env: dict, conda_path = os.path.join(conda_dir, conda_env_name) _inject_ray_to_conda_site(conda_path, logger) - context.conda_env_name = conda_env_name + context.py_executable = "python" + context.command_prefix += get_conda_activate_commands(conda_env_name) logger.info(f"Finished setting up runtime environment at {conda_env_name}") diff --git a/python/ray/_private/runtime_env/context.py b/python/ray/_private/runtime_env/context.py index 9f301c930..9fc8111d5 100644 --- a/python/ray/_private/runtime_env/context.py +++ b/python/ray/_private/runtime_env/context.py @@ -1,16 +1,24 @@ import json +import logging +import os +import sys +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) class RuntimeEnvContext: """A context used to describe the created runtime env.""" def __init__(self, - session_dir: str, - conda_env_name: str = None, - working_dir: str = None): - self.conda_env_name: str = conda_env_name - self.session_dir: str = session_dir - self.working_dir: str = working_dir + command_prefix: List[str] = None, + env_vars: Dict[str, str] = None, + py_executable: Optional[str] = None, + resources_dir: Optional[str] = None): + self.command_prefix = command_prefix or [] + self.env_vars = env_vars or {} + self.py_executable = py_executable or sys.executable + self.resources_dir: str = resources_dir def serialize(self) -> str: return json.dumps(self.__dict__) @@ -18,3 +26,11 @@ class RuntimeEnvContext: @staticmethod def deserialize(json_string): return RuntimeEnvContext(**json.loads(json_string)) + + def exec_worker(self, passthrough_args: List[str]): + os.environ.update(self.env_vars) + exec_command = " ".join([f"exec {self.py_executable}"] + + passthrough_args) + command_str = " && ".join(self.command_prefix + [exec_command]) + logger.info(f"Exec'ing worker with command: {command_str}") + os.execvp("bash", ["bash", "-c", command_str]) diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index e48cbb135..5b31accde 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -420,12 +420,25 @@ def setup_working_dir(runtime_env: dict, # TODO(edoakes): we should be able to remove this by refactoring the # working_dir setup code into a class instead of using global vars. global _logger, PKG_DIR - prev_logger = _logger - prev_pkg_dir = PKG_DIR - _logger = logger - PKG_DIR = context.session_dir + if logger: + prev_logger = _logger + _logger = logger - context.working_dir = ensure_runtime_env_setup(runtime_env["uris"]) + assert context.resources_dir is not None + prev_pkg_dir = PKG_DIR + PKG_DIR = context.resources_dir + + working_dir = ensure_runtime_env_setup(runtime_env["uris"]) + context.command_prefix += [f"cd {working_dir}"] + + # Insert the working_dir as the first entry in PYTHONPATH. This is + # compatible with users providing their own PYTHONPATH in env_vars. + python_path = working_dir + if "PYTHONPATH" in context.env_vars: + python_path += os.pathsep + context.env_vars["PYTHONPATH"] + context.env_vars["PYTHONPATH"] = python_path PKG_DIR = prev_pkg_dir - _logger = prev_logger + + if logger: + _logger = prev_logger diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 94d5434d9..24d64fcb3 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1459,7 +1459,6 @@ def start_raylet(redis_address, sys.executable, setup_worker_path, f"--worker-setup-hook={worker_setup_hook}", - f"--session-dir={session_dir}", worker_path, f"--node-ip-address={node_ip_address}", "--node-manager-port=RAY_NODE_MANAGER_PORT_PLACEHOLDER", @@ -1882,15 +1881,16 @@ def start_monitor(redis_address, return process_info -def start_ray_client_server(redis_address, - ray_client_server_port, - stdout_file=None, - stderr_file=None, - redis_password=None, - fate_share=None, - server_type: str = "proxy", - serialized_runtime_env: Optional[str] = None, - session_dir: Optional[str] = None): +def start_ray_client_server( + redis_address, + ray_client_server_port, + stdout_file=None, + stderr_file=None, + redis_password=None, + fate_share=None, + server_type: str = "proxy", + serialized_runtime_env: Optional[str] = None, + serialized_runtime_env_context: Optional[str] = None): """Run the server process of the Ray client. Args: @@ -1903,6 +1903,8 @@ def start_ray_client_server(redis_address, server_type (str): Whether to start the proxy version of Ray Client. serialized_runtime_env (str|None): If specified, the serialized runtime_env to start the client server in. + serialized_runtime_env_context (str|None): If specified, the serialized + runtime_env_context to start the client server in. Returns: ProcessInfo for the process that was started. @@ -1919,17 +1921,18 @@ def start_ray_client_server(redis_address, conda_shim_flag, # These two args are to use the shim process. "-m", "ray.util.client.server", + "--from-ray-client=True", "--redis-address=" + str(redis_address), "--port=" + str(ray_client_server_port), "--mode=" + server_type ] if redis_password: command.append("--redis-password=" + redis_password) - if serialized_runtime_env: command.append("--serialized-runtime-env=" + serialized_runtime_env) - if session_dir: - command.append(f"--session-dir={session_dir}") + if serialized_runtime_env_context: + command.append("--serialized-runtime-env-context=" + + serialized_runtime_env_context) process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, diff --git a/python/ray/node.py b/python/ray/node.py index e061e62b9..49da08caa 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -22,7 +22,6 @@ import ray.ray_constants as ray_constants import ray._private.services import ray._private.utils from ray._private.resource_spec import ResourceSpec -from ray._private.runtime_env import working_dir as working_dir_pkg from ray._private.utils import (try_to_create_directory, try_to_symlink, open_log) @@ -324,10 +323,9 @@ class Node: old_logs_dir = os.path.join(self._logs_dir, "old") try_to_create_directory(old_logs_dir) # Create a directory to be used for runtime environment. - self._resource_dir = os.path.join(self._session_dir, - "runtime_resources") - try_to_create_directory(self._resource_dir) - working_dir_pkg.PKG_DIR = self._resource_dir + self._runtime_env_dir = os.path.join(self._session_dir, + "runtime_resources") + try_to_create_directory(self._runtime_env_dir) def get_resource_spec(self): """Resolve and return the current resource spec for the node.""" @@ -812,7 +810,7 @@ class Node: self._ray_params.worker_setup_hook, self._temp_dir, self._session_dir, - self._resource_dir, + self._runtime_env_dir, self._logs_dir, self.get_resource_spec(), plasma_directory, diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py index d84ed2ee8..2e7536ff3 100644 --- a/python/ray/tests/test_runtime_env.py +++ b/python/ray/tests/test_runtime_env.py @@ -123,12 +123,18 @@ from test_module.test import one def start_client_server(cluster, client_mode): - from ray._private.runtime_env.working_dir import PKG_DIR - if not client_mode: - return (cluster.address, {}, PKG_DIR) - ray.worker._global_node._ray_params.ray_client_server_port = "10003" - ray.worker._global_node.start_ray_client_server() - return ("localhost:10003", {"USE_RAY_CLIENT": "1"}, PKG_DIR) + env = {} + if client_mode: + ray.worker._global_node._ray_params.ray_client_server_port = "10003" + ray.worker._global_node.start_ray_client_server() + address = "localhost:10003" + env["USE_RAY_CLIENT"] = "1" + else: + address = cluster.address + + runtime_env_dir = ray.worker._global_node.get_runtime_env_dir_path() + + return address, env, runtime_env_dir @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") @@ -207,7 +213,7 @@ The following test cases are related with runtime env. It following these steps @pytest.mark.parametrize("client_mode", [True, False]) def test_empty_working_dir(ray_start_cluster_head, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) env["EXIT_AFTER_INIT"] = "1" with tempfile.TemporaryDirectory() as working_dir: runtime_env = f"""{{ @@ -225,7 +231,7 @@ def test_empty_working_dir(ray_start_cluster_head, client_mode): @pytest.mark.parametrize("client_mode", [True, False]) def test_invalid_working_dir(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) env["EXIT_AFTER_INIT"] = "1" runtime_env = "{ 'working_dir': 10 }" @@ -261,7 +267,7 @@ def test_invalid_working_dir(ray_start_cluster_head, working_dir, client_mode): @pytest.mark.parametrize("client_mode", [True, False]) def test_single_node(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) # Setup runtime env here runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" # Execute the following cmd in driver with runtime_env @@ -270,7 +276,7 @@ def test_single_node(ray_start_cluster_head, working_dir, client_mode): out = run_string_as_driver(script, env) print(out) assert out.strip().split()[-1] == "1000" - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 assert len(kv._internal_kv_list("gcs://")) == 0 @@ -278,7 +284,7 @@ def test_single_node(ray_start_cluster_head, working_dir, client_mode): @pytest.mark.parametrize("client_mode", [True, False]) def test_two_node(two_node_cluster, working_dir, client_mode): cluster, _ = two_node_cluster - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) # Testing runtime env with working_dir runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" # Execute the following cmd in driver with runtime_env @@ -286,7 +292,7 @@ def test_two_node(two_node_cluster, working_dir, client_mode): script = driver_script.format(**locals()) out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 assert len(kv._internal_kv_list("gcs://")) == 0 @@ -294,7 +300,7 @@ def test_two_node(two_node_cluster, working_dir, client_mode): @pytest.mark.parametrize("client_mode", [True, False]) def test_two_node_module(two_node_cluster, working_dir, client_mode): cluster, _ = two_node_cluster - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) # test runtime_env iwth py_modules runtime_env = """{ "py_modules": [test_module.__path__[0]] }""" # Execute the following cmd in driver with runtime_env @@ -302,7 +308,7 @@ def test_two_node_module(two_node_cluster, working_dir, client_mode): script = driver_script.format(**locals()) out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") @@ -311,7 +317,7 @@ def test_two_node_local_file(two_node_cluster, working_dir, client_mode): with open(os.path.join(working_dir, "test_file"), "w") as f: f.write("1") cluster, _ = two_node_cluster - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) # test runtime_env iwth working_dir runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" # Execute the following cmd in driver with runtime_env @@ -322,7 +328,7 @@ print(sum([int(v) for v in vals])) script = driver_script.format(**locals()) out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 assert len(kv._internal_kv_list("gcs://")) == 0 @@ -330,7 +336,7 @@ print(sum([int(v) for v in vals])) @pytest.mark.parametrize("client_mode", [True, False]) def test_exclusion(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) working_path = Path(working_dir) create_file(working_path / "tmp_dir" / "test_1") @@ -386,7 +392,7 @@ def test_exclusion(ray_start_cluster_head, working_dir, client_mode): @pytest.mark.parametrize("client_mode", [True, False]) def test_exclusion_2(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) working_path = Path(working_dir) def create_file(p): @@ -450,7 +456,7 @@ cache/ @pytest.mark.parametrize("client_mode", [True, False]) def test_runtime_env_getter(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" # Execute the following cmd in driver with runtime_env execute_statement = """ @@ -465,7 +471,7 @@ print(ray.get_runtime_context().runtime_env["working_dir"]) @pytest.mark.parametrize("client_mode", [True, False]) def test_two_node_uri(two_node_cluster, working_dir, client_mode): cluster, _ = two_node_cluster - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file: pkg_name = working_dir_pkg.get_project_package_name( working_dir, [], []) @@ -479,7 +485,7 @@ def test_two_node_uri(two_node_cluster, working_dir, client_mode): script = driver_script.format(**locals()) out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 # pinned uri will not be deleted print(list(kv._internal_kv_list(""))) assert len(kv._internal_kv_list("pingcs://")) == 1 @@ -489,7 +495,7 @@ def test_two_node_uri(two_node_cluster, working_dir, client_mode): @pytest.mark.parametrize("client_mode", [True, False]) def test_regular_actors(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" # Execute the following cmd in driver with runtime_env execute_statement = """ @@ -499,7 +505,7 @@ print(sum(ray.get([test_actor.one.remote()] * 1000))) script = driver_script.format(**locals()) out = run_string_as_driver(script, env) assert out.strip().split()[-1] == "1000" - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 assert len(kv._internal_kv_list("gcs://")) == 0 @@ -507,7 +513,7 @@ print(sum(ray.get([test_actor.one.remote()] * 1000))) @pytest.mark.parametrize("client_mode", [True, False]) def test_detached_actors(ray_start_cluster_head, working_dir, client_mode): cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, client_mode) + address, env, runtime_env_dir = start_client_server(cluster, client_mode) runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" # Execute the following cmd in driver with runtime_env execute_statement = """ @@ -519,15 +525,14 @@ print(sum(ray.get([test_actor.one.remote()] * 1000))) assert out.strip().split()[-1] == "1000" # It's a detached actors, so it should still be there assert len(kv._internal_kv_list("gcs://")) == 1 - assert len(list(Path(PKG_DIR).iterdir())) == 2 - pkg_dir = [f for f in Path(PKG_DIR).glob("*") if f.is_dir()][0] - import sys + assert len(list(Path(runtime_env_dir).iterdir())) == 2 + pkg_dir = [f for f in Path(runtime_env_dir).glob("*") if f.is_dir()][0] sys.path.insert(0, str(pkg_dir)) test_actor = ray.get_actor("test_actor") assert sum(ray.get([test_actor.one.remote()] * 1000)) == 1000 ray.kill(test_actor) time.sleep(5) - assert len(list(Path(PKG_DIR).iterdir())) == 1 + assert len(list(Path(runtime_env_dir).iterdir())) == 1 assert len(kv._internal_kv_list("gcs://")) == 0 @@ -536,7 +541,7 @@ def test_jobconfig_compatible_1(ray_start_cluster_head, working_dir): # start job_config=None # start job_config=something cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, True) + address, env, runtime_env_dir = start_client_server(cluster, True) runtime_env = None # To make the first one hanging there execute_statement = """ @@ -562,7 +567,7 @@ def test_jobconfig_compatible_2(ray_start_cluster_head, working_dir): # start job_config=something # start job_config=None cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, True) + address, env, runtime_env_dir = start_client_server(cluster, True) runtime_env = """{ "py_modules": [test_module.__path__[0]] }""" # To make the first one hanging there execute_statement = """ @@ -587,7 +592,7 @@ def test_jobconfig_compatible_3(ray_start_cluster_head, working_dir): # start job_config=something # start job_config=something else cluster = ray_start_cluster_head - (address, env, PKG_DIR) = start_client_server(cluster, True) + address, env, runtime_env_dir = start_client_server(cluster, True) runtime_env = """{ "py_modules": [test_module.__path__[0]] }""" # To make the first one hanging ther execute_statement = """ @@ -623,7 +628,7 @@ def one(): cluster = Cluster() cluster.add_node(num_cpus=1) ray.init(address=cluster.address) - (address, env, PKG_DIR) = start_client_server(cluster, True) + address, env, runtime_env_dir = start_client_server(cluster, True) script = f""" import ray import ray.util diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 67d48d8f2..6494c7eef 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -5,7 +5,6 @@ import grpc import logging from itertools import chain import json -import os import socket import sys from threading import Lock, Thread, RLock @@ -23,6 +22,7 @@ from ray.util.client.common import (ClientServerHandle, CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS) from ray._private.client_mode_hook import disable_client_hook from ray._private.parameter import RayParams +from ray._private.runtime_env import RuntimeEnvContext import ray._private.runtime_env.working_dir as working_dir_pkg from ray._private.services import ProcessInfo, start_ray_client_server from ray._private.utils import detect_fate_sharing_support @@ -215,26 +215,16 @@ class ProxyManager(): output, error = self.node.get_log_file_handles( f"ray_client_server_{specific_server.port}", unique=True) + serialized_runtime_env = job_config.get_serialized_runtime_env() + runtime_env = json.loads(serialized_runtime_env) + # Set up the working_dir for the server. # TODO(edoakes): this should go be unified with the worker setup code # by going through the runtime_env agent. - uris = job_config.get_runtime_env_uris() if job_config else [] - if uris: - # Download and set up the working_dir locally. - working_dir = working_dir_pkg.ensure_runtime_env_setup(uris) - - # Set PYTHONPATH in the environment variables so the working_dir - # is included in the module search path. - runtime_env = job_config.runtime_env - env_vars = runtime_env.get("env_vars", None) or {} - python_path = working_dir - if "PYTHONPATH" in env_vars: - python_path += (os.pathsep + runtime_env["PYTHONPATH"]) - env_vars["PYTHONPATH"] = python_path - runtime_env["env_vars"] = env_vars - job_config.set_runtime_env(runtime_env) - - serialized_runtime_env = job_config.get_serialized_runtime_env() + context = RuntimeEnvContext( + env_vars=runtime_env.get("env_vars"), + resources_dir=self.node.get_runtime_env_dir_path()) + working_dir_pkg.setup_working_dir(runtime_env, context) proc = start_ray_client_server( self.redis_address, @@ -244,7 +234,7 @@ class ProxyManager(): fate_share=self.fate_share, server_type="specific-server", serialized_runtime_env=serialized_runtime_env, - session_dir=self.node.get_session_dir_path(), + serialized_runtime_env_context=context.serialize(), redis_password=self._redis_password) # Wait for the process being run transitions from the shim process diff --git a/python/ray/worker.py b/python/ray/worker.py index da91da904..9ee9a550c 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1393,9 +1393,10 @@ def connect(node, worker.gcs_client = worker.core_worker.get_gcs_client() # If it's a driver and it's not coming from ray client, we'll prepare the - # environment here. If it's ray client, the environmen will be prepared + # environment here. If it's ray client, the environment will be prepared # at the server side. if mode == SCRIPT_MODE and not job_config.client_job: + working_dir_pkg.PKG_DIR = worker.node.get_runtime_env_dir_path() working_dir_pkg.upload_runtime_env_package_if_needed(job_config) # Notify raylet that the core worker is ready. diff --git a/python/ray/workers/setup_runtime_env.py b/python/ray/workers/setup_runtime_env.py index e823439a9..f9e48a496 100644 --- a/python/ray/workers/setup_runtime_env.py +++ b/python/ray/workers/setup_runtime_env.py @@ -1,4 +1,3 @@ -import os import sys import argparse import json @@ -6,7 +5,6 @@ import logging from ray._private.runtime_env import RuntimeEnvContext from ray._private.runtime_env.conda import setup_conda_or_pip -from ray._private.runtime_env.conda_utils import get_conda_activate_commands logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() @@ -15,15 +13,12 @@ parser.add_argument( "--serialized-runtime-env", type=str, help="the serialized parsed runtime env dict") - parser.add_argument( "--serialized-runtime-env-context", type=str, help="the serialized runtime env context") - -# The worker is not set up yet, so we can't get session_dir from the worker. parser.add_argument( - "--session-dir", type=str, help="the directory for the current session") + "--from-ray-client", type=bool, required=False, default=False) def setup_worker(input_args): @@ -31,54 +26,21 @@ def setup_worker(input_args): # minus the python executable, e.g. default_worker.py --node-ip-address=... args, remaining_args = parser.parse_known_args(args=input_args) - commands = [] - py_executable: str = sys.executable runtime_env: dict = json.loads(args.serialized_runtime_env or "{}") runtime_env_context: RuntimeEnvContext = None if args.serialized_runtime_env_context: runtime_env_context = RuntimeEnvContext.deserialize( args.serialized_runtime_env_context) + else: + runtime_env_context = RuntimeEnvContext( + env_vars=runtime_env.get("env_vars")) # Ray client server setups runtime env by itself instead of agent. - if runtime_env.get("conda") or runtime_env.get("pip"): - if not args.serialized_runtime_env_context: - runtime_env_context = RuntimeEnvContext(args.session_dir) + if args.from_ray_client: + if runtime_env.get("conda") or runtime_env.get("pip"): setup_conda_or_pip(runtime_env, runtime_env_context, logger=logger) - if runtime_env_context and runtime_env_context.working_dir is not None: - commands += [f"cd {runtime_env_context.working_dir}"] - - # Insert the working_dir as the first entry in PYTHONPATH. This is - # compatible with users providing their own PYTHONPATH in env_vars. - env_vars = runtime_env.get("env_vars", None) or {} - python_path = runtime_env_context.working_dir - if "PYTHONPATH" in env_vars: - python_path += os.pathsep + runtime_env["PYTHONPATH"] - env_vars["PYTHONPATH"] = python_path - runtime_env["env_vars"] = env_vars - - # Add a conda activate command prefix if using a conda env. - if runtime_env_context and runtime_env_context.conda_env_name is not None: - py_executable = "python" - conda_activate_commands = get_conda_activate_commands( - runtime_env_context.conda_env_name) - if (conda_activate_commands): - commands += conda_activate_commands - elif runtime_env.get("conda"): - logger.warning( - "Conda env name is not found in context, " - "but conda exists in runtime env. The runtime env %s, " - "the context %s.", args.serialized_runtime_env, - args.serialized_runtime_env_context) - - commands += [" ".join([f"exec {py_executable}"] + remaining_args)] - command_str = " && ".join(commands) - - # update env vars - if runtime_env.get("env_vars"): - env_vars = runtime_env["env_vars"] - os.environ.update(env_vars) - os.execvp("bash", ["bash", "-c", command_str]) + runtime_env_context.exec_worker(remaining_args) if __name__ == "__main__": diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index a90a9efae..864966822 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -331,10 +331,10 @@ Process WorkerPool::StartWorkerProcess( // The "shim process" setup worker is not needed, so do not run it. // Check that the arg really is the path to the setup worker before erasing it, to // prevent breaking tests that mock out the worker command args. - if (worker_command_args.size() >= 4 && + if (worker_command_args.size() >= 3 && worker_command_args[1].find(kSetupWorkerFilename) != std::string::npos) { worker_command_args.erase(worker_command_args.begin() + 1, - worker_command_args.begin() + 4); + worker_command_args.begin() + 3); } }