[runtime_env] Move worker process startup logic to context (#18341)

This commit is contained in:
Edward Oakes 2021-09-08 17:08:27 -05:00 committed by GitHub
parent dd6abed6ce
commit f0555f88d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 128 additions and 137 deletions

View file

@ -77,7 +77,9 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
# Use a separate logger for each job.
per_job_logger = self.get_or_create_logger(request.job_id)
context = RuntimeEnvContext(self._runtime_env_dir)
context = RuntimeEnvContext(
env_vars=runtime_env.get("env_vars"),
resources_dir=self._runtime_env_dir)
setup_conda_or_pip(runtime_env, context, logger=per_job_logger)
setup_working_dir(runtime_env, context, logger=per_job_logger)
return context

View file

@ -13,7 +13,8 @@ from pathlib import Path
import ray
from ray._private.runtime_env import RuntimeEnvContext
from ray._private.runtime_env.conda_utils import get_or_create_conda_env
from ray._private.runtime_env.conda_utils import (get_conda_activate_commands,
get_or_create_conda_env)
from ray._private.utils import try_to_create_directory
from ray._private.utils import (get_wheel_filename, get_master_wheel_url,
get_release_wheel_url)
@ -75,7 +76,7 @@ def setup_conda_or_pip(runtime_env: dict,
return
logger.debug(f"Setting up conda or pip for runtime_env: {runtime_env}")
conda_dict = get_conda_dict(runtime_env, context.session_dir)
conda_dict = get_conda_dict(runtime_env, context.resources_dir)
if isinstance(runtime_env.get("conda"), str):
conda_env_name = runtime_env["conda"]
else:
@ -96,9 +97,8 @@ def setup_conda_or_pip(runtime_env: dict,
# lock for all conda installs.
# See https://github.com/ray-project/ray/issues/17086
file_lock_name = "ray-conda-install.lock"
with FileLock(os.path.join(context.session_dir, file_lock_name)):
conda_dir = os.path.join(context.session_dir, "runtime_resources",
"conda")
with FileLock(os.path.join(context.resources_dir, file_lock_name)):
conda_dir = os.path.join(context.resources_dir, "conda")
try_to_create_directory(conda_dir)
conda_yaml_path = os.path.join(conda_dir, "environment.yml")
with open(conda_yaml_path, "w") as file:
@ -113,7 +113,8 @@ def setup_conda_or_pip(runtime_env: dict,
conda_path = os.path.join(conda_dir, conda_env_name)
_inject_ray_to_conda_site(conda_path, logger)
context.conda_env_name = conda_env_name
context.py_executable = "python"
context.command_prefix += get_conda_activate_commands(conda_env_name)
logger.info(f"Finished setting up runtime environment at {conda_env_name}")

View file

@ -1,16 +1,24 @@
import json
import logging
import os
import sys
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class RuntimeEnvContext:
"""A context used to describe the created runtime env."""
def __init__(self,
session_dir: str,
conda_env_name: str = None,
working_dir: str = None):
self.conda_env_name: str = conda_env_name
self.session_dir: str = session_dir
self.working_dir: str = working_dir
command_prefix: List[str] = None,
env_vars: Dict[str, str] = None,
py_executable: Optional[str] = None,
resources_dir: Optional[str] = None):
self.command_prefix = command_prefix or []
self.env_vars = env_vars or {}
self.py_executable = py_executable or sys.executable
self.resources_dir: str = resources_dir
def serialize(self) -> str:
    """Return this context's instance attributes encoded as a JSON string."""
    attributes = vars(self)
    return json.dumps(attributes)
@ -18,3 +26,11 @@ class RuntimeEnvContext:
@staticmethod
def deserialize(json_string):
    """Build a RuntimeEnvContext from a JSON object string.

    The decoded object's keys are passed to the constructor as keyword
    arguments.
    """
    constructor_kwargs = json.loads(json_string)
    return RuntimeEnvContext(**constructor_kwargs)
def exec_worker(self, passthrough_args: List[str]):
    """Replace the current process with the worker process for this env.

    Applies ``self.env_vars`` to ``os.environ``, then execs ``bash -c`` on a
    command line made of every entry in ``self.command_prefix`` followed by
    ``exec <py_executable> <args...>``, all chained with ``&&``.

    Args:
        passthrough_args: Arguments appended after the Python executable.
            Each one is shell-quoted so that it reaches the worker as a
            single, unmodified argv entry.

    On success this call does not return, because ``os.execvp`` replaces
    the current process image.
    """
    import shlex  # Local import: only needed to build the shell command.

    os.environ.update(self.env_vars)
    # The command string is interpreted by bash, so every argument must be
    # quoted; otherwise whitespace or shell metacharacters inside an
    # argument would be word-split or expanded before the worker sees it.
    quoted_args = [shlex.quote(arg) for arg in passthrough_args]
    exec_command = " ".join([f"exec {self.py_executable}"] + quoted_args)
    # command_prefix entries are shell snippets (e.g. `cd <dir>`), so they
    # are deliberately joined as-is rather than quoted.
    command_str = " && ".join(self.command_prefix + [exec_command])
    logger.info(f"Exec'ing worker with command: {command_str}")
    os.execvp("bash", ["bash", "-c", command_str])

View file

@ -420,12 +420,25 @@ def setup_working_dir(runtime_env: dict,
# TODO(edoakes): we should be able to remove this by refactoring the
# working_dir setup code into a class instead of using global vars.
global _logger, PKG_DIR
prev_logger = _logger
prev_pkg_dir = PKG_DIR
_logger = logger
PKG_DIR = context.session_dir
if logger:
prev_logger = _logger
_logger = logger
context.working_dir = ensure_runtime_env_setup(runtime_env["uris"])
assert context.resources_dir is not None
prev_pkg_dir = PKG_DIR
PKG_DIR = context.resources_dir
working_dir = ensure_runtime_env_setup(runtime_env["uris"])
context.command_prefix += [f"cd {working_dir}"]
# Insert the working_dir as the first entry in PYTHONPATH. This is
# compatible with users providing their own PYTHONPATH in env_vars.
python_path = working_dir
if "PYTHONPATH" in context.env_vars:
python_path += os.pathsep + context.env_vars["PYTHONPATH"]
context.env_vars["PYTHONPATH"] = python_path
PKG_DIR = prev_pkg_dir
_logger = prev_logger
if logger:
_logger = prev_logger

View file

@ -1459,7 +1459,6 @@ def start_raylet(redis_address,
sys.executable,
setup_worker_path,
f"--worker-setup-hook={worker_setup_hook}",
f"--session-dir={session_dir}",
worker_path,
f"--node-ip-address={node_ip_address}",
"--node-manager-port=RAY_NODE_MANAGER_PORT_PLACEHOLDER",
@ -1882,15 +1881,16 @@ def start_monitor(redis_address,
return process_info
def start_ray_client_server(redis_address,
ray_client_server_port,
stdout_file=None,
stderr_file=None,
redis_password=None,
fate_share=None,
server_type: str = "proxy",
serialized_runtime_env: Optional[str] = None,
session_dir: Optional[str] = None):
def start_ray_client_server(
redis_address,
ray_client_server_port,
stdout_file=None,
stderr_file=None,
redis_password=None,
fate_share=None,
server_type: str = "proxy",
serialized_runtime_env: Optional[str] = None,
serialized_runtime_env_context: Optional[str] = None):
"""Run the server process of the Ray client.
Args:
@ -1903,6 +1903,8 @@ def start_ray_client_server(redis_address,
server_type (str): Whether to start the proxy version of Ray Client.
serialized_runtime_env (str|None): If specified, the serialized
runtime_env to start the client server in.
serialized_runtime_env_context (str|None): If specified, the serialized
runtime_env_context to start the client server in.
Returns:
ProcessInfo for the process that was started.
@ -1919,17 +1921,18 @@ def start_ray_client_server(redis_address,
conda_shim_flag, # These two args are to use the shim process.
"-m",
"ray.util.client.server",
"--from-ray-client=True",
"--redis-address=" + str(redis_address),
"--port=" + str(ray_client_server_port),
"--mode=" + server_type
]
if redis_password:
command.append("--redis-password=" + redis_password)
if serialized_runtime_env:
command.append("--serialized-runtime-env=" + serialized_runtime_env)
if session_dir:
command.append(f"--session-dir={session_dir}")
if serialized_runtime_env_context:
command.append("--serialized-runtime-env-context=" +
serialized_runtime_env_context)
process_info = start_ray_process(
command,
ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER,

View file

@ -22,7 +22,6 @@ import ray.ray_constants as ray_constants
import ray._private.services
import ray._private.utils
from ray._private.resource_spec import ResourceSpec
from ray._private.runtime_env import working_dir as working_dir_pkg
from ray._private.utils import (try_to_create_directory, try_to_symlink,
open_log)
@ -324,10 +323,9 @@ class Node:
old_logs_dir = os.path.join(self._logs_dir, "old")
try_to_create_directory(old_logs_dir)
# Create a directory to be used for runtime environment.
self._resource_dir = os.path.join(self._session_dir,
"runtime_resources")
try_to_create_directory(self._resource_dir)
working_dir_pkg.PKG_DIR = self._resource_dir
self._runtime_env_dir = os.path.join(self._session_dir,
"runtime_resources")
try_to_create_directory(self._runtime_env_dir)
def get_resource_spec(self):
"""Resolve and return the current resource spec for the node."""
@ -812,7 +810,7 @@ class Node:
self._ray_params.worker_setup_hook,
self._temp_dir,
self._session_dir,
self._resource_dir,
self._runtime_env_dir,
self._logs_dir,
self.get_resource_spec(),
plasma_directory,

View file

@ -123,12 +123,18 @@ from test_module.test import one
def start_client_server(cluster, client_mode):
from ray._private.runtime_env.working_dir import PKG_DIR
if not client_mode:
return (cluster.address, {}, PKG_DIR)
ray.worker._global_node._ray_params.ray_client_server_port = "10003"
ray.worker._global_node.start_ray_client_server()
return ("localhost:10003", {"USE_RAY_CLIENT": "1"}, PKG_DIR)
env = {}
if client_mode:
ray.worker._global_node._ray_params.ray_client_server_port = "10003"
ray.worker._global_node.start_ray_client_server()
address = "localhost:10003"
env["USE_RAY_CLIENT"] = "1"
else:
address = cluster.address
runtime_env_dir = ray.worker._global_node.get_runtime_env_dir_path()
return address, env, runtime_env_dir
@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
@ -207,7 +213,7 @@ The following test cases are related with runtime env. It following these steps
@pytest.mark.parametrize("client_mode", [True, False])
def test_empty_working_dir(ray_start_cluster_head, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
env["EXIT_AFTER_INIT"] = "1"
with tempfile.TemporaryDirectory() as working_dir:
runtime_env = f"""{{
@ -225,7 +231,7 @@ def test_empty_working_dir(ray_start_cluster_head, client_mode):
@pytest.mark.parametrize("client_mode", [True, False])
def test_invalid_working_dir(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
env["EXIT_AFTER_INIT"] = "1"
runtime_env = "{ 'working_dir': 10 }"
@ -261,7 +267,7 @@ def test_invalid_working_dir(ray_start_cluster_head, working_dir, client_mode):
@pytest.mark.parametrize("client_mode", [True, False])
def test_single_node(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
# Setup runtime env here
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
@ -270,7 +276,7 @@ def test_single_node(ray_start_cluster_head, working_dir, client_mode):
out = run_string_as_driver(script, env)
print(out)
assert out.strip().split()[-1] == "1000"
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
assert len(kv._internal_kv_list("gcs://")) == 0
@ -278,7 +284,7 @@ def test_single_node(ray_start_cluster_head, working_dir, client_mode):
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node(two_node_cluster, working_dir, client_mode):
cluster, _ = two_node_cluster
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
# Testing runtime env with working_dir
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
@ -286,7 +292,7 @@ def test_two_node(two_node_cluster, working_dir, client_mode):
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
assert len(kv._internal_kv_list("gcs://")) == 0
@ -294,7 +300,7 @@ def test_two_node(two_node_cluster, working_dir, client_mode):
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node_module(two_node_cluster, working_dir, client_mode):
cluster, _ = two_node_cluster
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
# test runtime_env with py_modules
runtime_env = """{ "py_modules": [test_module.__path__[0]] }"""
# Execute the following cmd in driver with runtime_env
@ -302,7 +308,7 @@ def test_two_node_module(two_node_cluster, working_dir, client_mode):
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
@ -311,7 +317,7 @@ def test_two_node_local_file(two_node_cluster, working_dir, client_mode):
with open(os.path.join(working_dir, "test_file"), "w") as f:
f.write("1")
cluster, _ = two_node_cluster
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
# test runtime_env with working_dir
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
@ -322,7 +328,7 @@ print(sum([int(v) for v in vals]))
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
assert len(kv._internal_kv_list("gcs://")) == 0
@ -330,7 +336,7 @@ print(sum([int(v) for v in vals]))
@pytest.mark.parametrize("client_mode", [True, False])
def test_exclusion(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
working_path = Path(working_dir)
create_file(working_path / "tmp_dir" / "test_1")
@ -386,7 +392,7 @@ def test_exclusion(ray_start_cluster_head, working_dir, client_mode):
@pytest.mark.parametrize("client_mode", [True, False])
def test_exclusion_2(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
working_path = Path(working_dir)
def create_file(p):
@ -450,7 +456,7 @@ cache/
@pytest.mark.parametrize("client_mode", [True, False])
def test_runtime_env_getter(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
execute_statement = """
@ -465,7 +471,7 @@ print(ray.get_runtime_context().runtime_env["working_dir"])
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node_uri(two_node_cluster, working_dir, client_mode):
cluster, _ = two_node_cluster
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file:
pkg_name = working_dir_pkg.get_project_package_name(
working_dir, [], [])
@ -479,7 +485,7 @@ def test_two_node_uri(two_node_cluster, working_dir, client_mode):
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
# pinned uri will not be deleted
print(list(kv._internal_kv_list("")))
assert len(kv._internal_kv_list("pingcs://")) == 1
@ -489,7 +495,7 @@ def test_two_node_uri(two_node_cluster, working_dir, client_mode):
@pytest.mark.parametrize("client_mode", [True, False])
def test_regular_actors(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
execute_statement = """
@ -499,7 +505,7 @@ print(sum(ray.get([test_actor.one.remote()] * 1000)))
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == "1000"
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
assert len(kv._internal_kv_list("gcs://")) == 0
@ -507,7 +513,7 @@ print(sum(ray.get([test_actor.one.remote()] * 1000)))
@pytest.mark.parametrize("client_mode", [True, False])
def test_detached_actors(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
address, env, runtime_env_dir = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
execute_statement = """
@ -519,15 +525,14 @@ print(sum(ray.get([test_actor.one.remote()] * 1000)))
assert out.strip().split()[-1] == "1000"
# It's a detached actor, so it should still be there
assert len(kv._internal_kv_list("gcs://")) == 1
assert len(list(Path(PKG_DIR).iterdir())) == 2
pkg_dir = [f for f in Path(PKG_DIR).glob("*") if f.is_dir()][0]
import sys
assert len(list(Path(runtime_env_dir).iterdir())) == 2
pkg_dir = [f for f in Path(runtime_env_dir).glob("*") if f.is_dir()][0]
sys.path.insert(0, str(pkg_dir))
test_actor = ray.get_actor("test_actor")
assert sum(ray.get([test_actor.one.remote()] * 1000)) == 1000
ray.kill(test_actor)
time.sleep(5)
assert len(list(Path(PKG_DIR).iterdir())) == 1
assert len(list(Path(runtime_env_dir).iterdir())) == 1
assert len(kv._internal_kv_list("gcs://")) == 0
@ -536,7 +541,7 @@ def test_jobconfig_compatible_1(ray_start_cluster_head, working_dir):
# start job_config=None
# start job_config=something
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, True)
address, env, runtime_env_dir = start_client_server(cluster, True)
runtime_env = None
# To make the first one hanging there
execute_statement = """
@ -562,7 +567,7 @@ def test_jobconfig_compatible_2(ray_start_cluster_head, working_dir):
# start job_config=something
# start job_config=None
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, True)
address, env, runtime_env_dir = start_client_server(cluster, True)
runtime_env = """{ "py_modules": [test_module.__path__[0]] }"""
# To make the first one hanging there
execute_statement = """
@ -587,7 +592,7 @@ def test_jobconfig_compatible_3(ray_start_cluster_head, working_dir):
# start job_config=something
# start job_config=something else
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, True)
address, env, runtime_env_dir = start_client_server(cluster, True)
runtime_env = """{ "py_modules": [test_module.__path__[0]] }"""
# To make the first one hanging there
execute_statement = """
@ -623,7 +628,7 @@ def one():
cluster = Cluster()
cluster.add_node(num_cpus=1)
ray.init(address=cluster.address)
(address, env, PKG_DIR) = start_client_server(cluster, True)
address, env, runtime_env_dir = start_client_server(cluster, True)
script = f"""
import ray
import ray.util

View file

@ -5,7 +5,6 @@ import grpc
import logging
from itertools import chain
import json
import os
import socket
import sys
from threading import Lock, Thread, RLock
@ -23,6 +22,7 @@ from ray.util.client.common import (ClientServerHandle,
CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
from ray._private.client_mode_hook import disable_client_hook
from ray._private.parameter import RayParams
from ray._private.runtime_env import RuntimeEnvContext
import ray._private.runtime_env.working_dir as working_dir_pkg
from ray._private.services import ProcessInfo, start_ray_client_server
from ray._private.utils import detect_fate_sharing_support
@ -215,26 +215,16 @@ class ProxyManager():
output, error = self.node.get_log_file_handles(
f"ray_client_server_{specific_server.port}", unique=True)
serialized_runtime_env = job_config.get_serialized_runtime_env()
runtime_env = json.loads(serialized_runtime_env)
# Set up the working_dir for the server.
# TODO(edoakes): this should be unified with the worker setup code
# by going through the runtime_env agent.
uris = job_config.get_runtime_env_uris() if job_config else []
if uris:
# Download and set up the working_dir locally.
working_dir = working_dir_pkg.ensure_runtime_env_setup(uris)
# Set PYTHONPATH in the environment variables so the working_dir
# is included in the module search path.
runtime_env = job_config.runtime_env
env_vars = runtime_env.get("env_vars", None) or {}
python_path = working_dir
if "PYTHONPATH" in env_vars:
python_path += (os.pathsep + runtime_env["PYTHONPATH"])
env_vars["PYTHONPATH"] = python_path
runtime_env["env_vars"] = env_vars
job_config.set_runtime_env(runtime_env)
serialized_runtime_env = job_config.get_serialized_runtime_env()
context = RuntimeEnvContext(
env_vars=runtime_env.get("env_vars"),
resources_dir=self.node.get_runtime_env_dir_path())
working_dir_pkg.setup_working_dir(runtime_env, context)
proc = start_ray_client_server(
self.redis_address,
@ -244,7 +234,7 @@ class ProxyManager():
fate_share=self.fate_share,
server_type="specific-server",
serialized_runtime_env=serialized_runtime_env,
session_dir=self.node.get_session_dir_path(),
serialized_runtime_env_context=context.serialize(),
redis_password=self._redis_password)
# Wait for the process being run transitions from the shim process

View file

@ -1393,9 +1393,10 @@ def connect(node,
worker.gcs_client = worker.core_worker.get_gcs_client()
# If it's a driver and it's not coming from ray client, we'll prepare the
# environment here. If it's ray client, the environmen will be prepared
# environment here. If it's ray client, the environment will be prepared
# at the server side.
if mode == SCRIPT_MODE and not job_config.client_job:
working_dir_pkg.PKG_DIR = worker.node.get_runtime_env_dir_path()
working_dir_pkg.upload_runtime_env_package_if_needed(job_config)
# Notify raylet that the core worker is ready.

View file

@ -1,4 +1,3 @@
import os
import sys
import argparse
import json
@ -6,7 +5,6 @@ import logging
from ray._private.runtime_env import RuntimeEnvContext
from ray._private.runtime_env.conda import setup_conda_or_pip
from ray._private.runtime_env.conda_utils import get_conda_activate_commands
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
@ -15,15 +13,12 @@ parser.add_argument(
"--serialized-runtime-env",
type=str,
help="the serialized parsed runtime env dict")
parser.add_argument(
"--serialized-runtime-env-context",
type=str,
help="the serialized runtime env context")
# The worker is not set up yet, so we can't get session_dir from the worker.
parser.add_argument(
"--session-dir", type=str, help="the directory for the current session")
"--from-ray-client", type=bool, required=False, default=False)
def setup_worker(input_args):
@ -31,54 +26,21 @@ def setup_worker(input_args):
# minus the python executable, e.g. default_worker.py --node-ip-address=...
args, remaining_args = parser.parse_known_args(args=input_args)
commands = []
py_executable: str = sys.executable
runtime_env: dict = json.loads(args.serialized_runtime_env or "{}")
runtime_env_context: RuntimeEnvContext = None
if args.serialized_runtime_env_context:
runtime_env_context = RuntimeEnvContext.deserialize(
args.serialized_runtime_env_context)
else:
runtime_env_context = RuntimeEnvContext(
env_vars=runtime_env.get("env_vars"))
# Ray client server sets up the runtime env by itself instead of relying on the agent.
if runtime_env.get("conda") or runtime_env.get("pip"):
if not args.serialized_runtime_env_context:
runtime_env_context = RuntimeEnvContext(args.session_dir)
if args.from_ray_client:
if runtime_env.get("conda") or runtime_env.get("pip"):
setup_conda_or_pip(runtime_env, runtime_env_context, logger=logger)
if runtime_env_context and runtime_env_context.working_dir is not None:
commands += [f"cd {runtime_env_context.working_dir}"]
# Insert the working_dir as the first entry in PYTHONPATH. This is
# compatible with users providing their own PYTHONPATH in env_vars.
env_vars = runtime_env.get("env_vars", None) or {}
python_path = runtime_env_context.working_dir
if "PYTHONPATH" in env_vars:
python_path += os.pathsep + runtime_env["PYTHONPATH"]
env_vars["PYTHONPATH"] = python_path
runtime_env["env_vars"] = env_vars
# Add a conda activate command prefix if using a conda env.
if runtime_env_context and runtime_env_context.conda_env_name is not None:
py_executable = "python"
conda_activate_commands = get_conda_activate_commands(
runtime_env_context.conda_env_name)
if (conda_activate_commands):
commands += conda_activate_commands
elif runtime_env.get("conda"):
logger.warning(
"Conda env name is not found in context, "
"but conda exists in runtime env. The runtime env %s, "
"the context %s.", args.serialized_runtime_env,
args.serialized_runtime_env_context)
commands += [" ".join([f"exec {py_executable}"] + remaining_args)]
command_str = " && ".join(commands)
# update env vars
if runtime_env.get("env_vars"):
env_vars = runtime_env["env_vars"]
os.environ.update(env_vars)
os.execvp("bash", ["bash", "-c", command_str])
runtime_env_context.exec_worker(remaining_args)
if __name__ == "__main__":

View file

@ -331,10 +331,10 @@ Process WorkerPool::StartWorkerProcess(
// The "shim process" setup worker is not needed, so do not run it.
// Check that the arg really is the path to the setup worker before erasing it, to
// prevent breaking tests that mock out the worker command args.
if (worker_command_args.size() >= 4 &&
if (worker_command_args.size() >= 3 &&
worker_command_args[1].find(kSetupWorkerFilename) != std::string::npos) {
worker_command_args.erase(worker_command_args.begin() + 1,
worker_command_args.begin() + 4);
worker_command_args.begin() + 3);
}
}