mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[core] Minimal support for runtime env (#14270)
This commit is contained in:
parent
ba6cebe30f
commit
ed8935406b
12 changed files with 483 additions and 23 deletions
301
python/ray/_private/runtime_env.py
Normal file
@ -0,0 +1,301 @@
import hashlib
import logging
import inspect

from filelock import FileLock
from pathlib import Path
from zipfile import ZipFile
from ray.job_config import JobConfig
from enum import Enum
from ray.experimental import internal_kv
from ray.core.generated.common_pb2 import RuntimeEnv
from typing import List, Tuple
from types import ModuleType
from urllib.parse import urlparse
import os
import sys

# This variable must be set before this module is used.
PKG_DIR = None

logger = logging.getLogger(__name__)

FILE_SIZE_WARNING = 10 * 1024 * 1024  # 10MB
FILE_SIZE_LIMIT = 50 * 1024 * 1024  # 50MB

class Protocol(Enum):
    """An enum for supported backend storage."""

    # Allow docstrings on enum members.
    def __new__(cls, value, doc=None):
        self = object.__new__(cls)
        self._value_ = value
        if doc is not None:
            self.__doc__ = doc
        return self

    GCS = "gcs", "For packages created and managed by the system."
    PIN_GCS = "pingcs", "For packages created and managed by the users."

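# Example (illustrative, not part of the committed file): each member
# carries its own docstring via the __new__ hook above.
#
#   >>> Protocol.GCS.value
#   'gcs'
#   >>> Protocol.PIN_GCS.__doc__
#   'For packages created and managed by the users.'
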
def _xor_bytes(left: bytes, right: bytes) -> bytes:
    if left and right:
        return bytes(a ^ b for (a, b) in zip(left, right))
    return left or right

def _zip_module(path: Path, relative_path: Path, zip_handler: ZipFile) -> None:
    """Go through all files under `path` and add them to the zip file."""
    for from_file_name in path.glob("**/*"):
        file_size = from_file_name.stat().st_size
        if file_size >= FILE_SIZE_LIMIT:
            raise RuntimeError(f"File {from_file_name} is too big "
                               f"({file_size} bytes), which is currently "
                               "not allowed.")
        if file_size >= FILE_SIZE_WARNING:
            logger.warning(
                f"File {from_file_name} is large ({file_size} bytes). "
                "Consider excluding this file from the working directory.")
        to_file_name = from_file_name.relative_to(relative_path)
        zip_handler.write(from_file_name, to_file_name)

def _hash_modules(path: Path) -> bytes:
    """Helper function to create a hash of a directory.

    It goes through all the files in the directory and XORs the
    hash(file_name, file_content) values together into one hash.
    """
    hash_val = None
    BUF_SIZE = 4096 * 1024
    for from_file_name in path.glob("**/*"):
        md5 = hashlib.md5()
        md5.update(str(from_file_name).encode())
        if not Path(from_file_name).is_dir():
            with open(from_file_name, mode="rb") as f:
                # Read the file in chunks to bound memory usage.
                while True:
                    data = f.read(BUF_SIZE)
                    if not data:
                        break
                    md5.update(data)
        hash_val = _xor_bytes(hash_val, md5.digest())
    return hash_val

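# Note (illustrative): because XOR is commutative and associative, the
# directory hash above does not depend on the order in which files are
# visited, e.g.:
#
#   >>> d1 = hashlib.md5(b"a.py").digest()
#   >>> d2 = hashlib.md5(b"b.py").digest()
#   >>> _xor_bytes(d1, d2) == _xor_bytes(d2, d1)
#   True
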
def _get_local_path(pkg_uri: str) -> str:
    assert PKG_DIR, "Please set PKG_DIR in the module first."
    (_, pkg_name) = _parse_uri(pkg_uri)
    return os.path.join(PKG_DIR, pkg_name)

def _parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
    uri = urlparse(pkg_uri)
    protocol = Protocol(uri.scheme)
    return (protocol, uri.netloc)

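# Example (illustrative, hypothetical hash): a package URI such as
# "gcs://_ray_pkg_<hash>.zip" splits into the protocol and the bare name.
#
#   >>> _parse_uri("gcs://_ray_pkg_0123abcd.zip")
#   (<Protocol.GCS: 'gcs'>, '_ray_pkg_0123abcd.zip')
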
# TODO(yic): Fix this later to handle big directories in a better way.
def get_project_package_name(working_dir: str,
                             modules: List[ModuleType]) -> str:
    """Get the name of the package by working dir and modules.

    This function generates the package name from the working directory
    and the given modules. It goes through all the files in working_dir
    and the modules and hashes their contents to get the hash value of
    the package. The final package name is _ray_pkg_<HASH_VAL>.zip,
    e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip.
    Right now, only the modules given are included; their dependencies
    are not included automatically.

    Examples:

    .. code-block:: python

        >>> import any_module
        >>> get_project_package_name("/working_dir", [any_module])
        '_ray_pkg_af2734982a741.zip'

    Args:
        working_dir (str): The working directory.
        modules (list[module]): The python modules.

    Returns:
        Package name as a string, or None if neither a working_dir nor
        modules are given.
    """
    RAY_PKG_PREFIX = "_ray_pkg_"
    hash_val = None
    if working_dir:
        assert isinstance(working_dir, str)
        assert Path(working_dir).exists()
        hash_val = _xor_bytes(hash_val, _hash_modules(Path(working_dir)))
    for module in modules or []:
        assert inspect.ismodule(module)
        hash_val = _xor_bytes(hash_val,
                              _hash_modules(Path(module.__file__).parent))
    return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None

def create_project_package(working_dir: str, modules: List[ModuleType],
                           output_path: str) -> None:
    """Create a package that will be used by workers.

    This function creates a package file based on the working directory
    and local python modules.

    Args:
        working_dir (str): The working directory.
        modules (list[module]): The python modules to be included.
        output_path (str): The path of the file to be created.
    """
    pkg_file = Path(output_path)
    with ZipFile(pkg_file, "w") as zip_handler:
        if working_dir:
            # Put all files under /path/working_dir into the zip,
            # relative to the working directory itself.
            working_path = Path(working_dir)
            _zip_module(working_path, working_path, zip_handler)
        for module in modules or []:
            logger.info(module.__file__)
            # We only handle modules laid out like this for now:
            #   /path/module_name/__init__.py
            # so module_path is /path/module_name and the archive keeps
            # the top-level "module_name/" prefix.
            module_path = Path(module.__file__).parent
            _zip_module(module_path, module_path.parent, zip_handler)

def fetch_package(pkg_uri: str, pkg_file: Path) -> int:
    """Fetch a package from a given uri.

    This function is used to fetch a package from the given uri to the
    local filesystem.

    Args:
        pkg_uri (str): The uri of the package to download.
        pkg_file (pathlib.Path): The path in the local filesystem to
            download the package to.

    Returns:
        The number of bytes downloaded.
    """
    (protocol, pkg_name) = _parse_uri(pkg_uri)
    if protocol in (Protocol.GCS, Protocol.PIN_GCS):
        code = internal_kv._internal_kv_get(pkg_uri)
        code = code or b""
        pkg_file.write_bytes(code)
        return len(code)
    else:
        raise NotImplementedError(f"Protocol {protocol} is not supported")

def _store_package_in_gcs(gcs_key: str, data: bytes) -> int:
    internal_kv._internal_kv_put(gcs_key, data)
    return len(data)

def push_package(pkg_uri: str, pkg_path: str) -> int:
    """Push a package to a uri.

    This function pushes a local file to a remote uri. Right now, only
    GCS is supported.

    Args:
        pkg_uri (str): The uri of the package to upload to.
        pkg_path (str): Path of the local file.

    Returns:
        The number of bytes uploaded.
    """
    (protocol, pkg_name) = _parse_uri(pkg_uri)
    data = Path(pkg_path).read_bytes()
    if protocol in (Protocol.GCS, Protocol.PIN_GCS):
        # Return the byte count so callers can log the upload size.
        return _store_package_in_gcs(pkg_uri, data)
    else:
        raise NotImplementedError(f"Protocol {protocol} is not supported")

def package_exists(pkg_uri: str) -> bool:
    """Check whether the package with the given uri exists or not.

    Args:
        pkg_uri (str): The uri of the package.

    Returns:
        True if the package exists, False otherwise.
    """
    (protocol, pkg_name) = _parse_uri(pkg_uri)
    if protocol in (Protocol.GCS, Protocol.PIN_GCS):
        return internal_kv._internal_kv_exists(pkg_uri)
    else:
        raise NotImplementedError(f"Protocol {protocol} is not supported")

def rewrite_working_dir_uri(job_config: JobConfig) -> None:
    """Rewrite the working dir uri field in job_config.

    This function updates the runtime field in job_config. The runtime
    field is generated based on the hash of the required files and
    modules.

    Args:
        job_config (JobConfig): The job config.
    """
    # For now, we only support local directories and modules.
    working_dir = job_config.runtime_env.get("working_dir")
    required_modules = job_config.runtime_env.get("local_modules")

    if (not job_config.runtime_env.get("working_dir_uri")) and (
            working_dir or required_modules):
        pkg_name = get_project_package_name(working_dir, required_modules)
        job_config.runtime_env[
            "working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name

def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
    """Upload the runtime env package if it isn't there yet.

    This checks whether the runtime environment package exists in the
    cluster. If it doesn't, a package is created based on the working
    directory and the modules defined in the job config, and then
    uploaded to the cluster.

    Args:
        job_config (JobConfig): The job config of the driver.
    """
    pkg_uri = job_config.get_package_uri()
    if not pkg_uri:
        return
    if not package_exists(pkg_uri):
        file_path = _get_local_path(pkg_uri)
        pkg_file = Path(file_path)
        working_dir = job_config.runtime_env.get("working_dir")
        required_modules = job_config.runtime_env.get("local_modules")
        logger.info(f"{pkg_uri} doesn't exist. Create new package with"
                    f" {working_dir} and {required_modules}")
        if not pkg_file.exists():
            create_project_package(working_dir, required_modules, file_path)
        # Push the data to remote storage.
        pkg_size = push_package(pkg_uri, pkg_file)
        logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes")

def ensure_runtime_env_setup(runtime_env: RuntimeEnv) -> None:
    """Make sure all required packages are downloaded locally.

    The packages required to run the job are downloaded into the local
    file system if they aren't there yet.

    Args:
        runtime_env (RuntimeEnv): Runtime environment of the job.
    """
    pkg_uri = runtime_env.working_dir_uri
    if not pkg_uri:
        return
    pkg_file = Path(_get_local_path(pkg_uri))
    # On each node the package is downloaded only once; the file lock
    # keeps multiple processes from downloading it concurrently.
    lock = FileLock(str(pkg_file) + ".lock")
    with lock:
        # TODO(yic): checksum calculation is required
        if pkg_file.exists():
            logger.info(f"{pkg_uri} already exists locally, skipping "
                        "the download")
        else:
            pkg_size = fetch_package(pkg_uri, pkg_file)
            logger.info(f"Downloaded {pkg_size} bytes into {pkg_file}")
    sys.path.insert(0, str(pkg_file))
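Taken together, the driver-side flow is: derive a content-addressed package name, zip the working directory and modules, push the archive to GCS, and have each worker fetch it and prepend it to sys.path. A minimal sketch of the upload half, mirroring test_two_node_uri later in this commit (it assumes an initialized Ray session with PKG_DIR already set by node startup; the project path is hypothetical):

import ray._private.runtime_env as runtime_env

working_dir = "/path/to/project"  # hypothetical
pkg_name = runtime_env.get_project_package_name(working_dir, [])
pkg_uri = runtime_env.Protocol.GCS.value + "://" + pkg_name
local_pkg = runtime_env._get_local_path(pkg_uri)  # requires PKG_DIR to be set
runtime_env.create_project_package(working_dir, [], local_pkg)
runtime_env.push_package(pkg_uri, local_pkg)
assert runtime_env.package_exists(pkg_uri)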
@ -871,7 +871,6 @@ cdef class CoreWorker:
        options.terminate_asyncio_thread = terminate_asyncio_thread
        options.serialized_job_config = serialized_job_config
        options.metrics_agent_port = metrics_agent_port

        CCoreWorkerProcess.Initialize(options)

    def __dealloc__(self):
@ -17,6 +17,12 @@ def _internal_kv_get(key: Union[str, bytes]) -> bytes:
    return ray.worker.global_worker.redis_client.hget(key, "value")


@client_mode_hook
def _internal_kv_exists(key: Union[str, bytes]) -> bool:
    """Check whether the key exists or not."""
    return ray.worker.global_worker.redis_client.hexists(key, "value")


@client_mode_hook
def _internal_kv_put(key: Union[str, bytes],
                     value: Union[str, bytes],
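The package store is simply the GCS key-value table, so the storage backend reduces to the three primitives above. A hedged sketch of how they compose (requires an initialized Ray worker; the key and value are hypothetical):

from ray.experimental import internal_kv

internal_kv._internal_kv_put("pkg_key", b"pkg_bytes")
assert internal_kv._internal_kv_exists("pkg_key")
assert internal_kv._internal_kv_get("pkg_key") == b"pkg_bytes"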
@ -264,8 +264,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
        (c_bool() nogil) kill_main
        CCoreWorkerOptions()
        (void() nogil) terminate_asyncio_thread
-       int metrics_agent_port
        c_string serialized_job_config
+       int metrics_agent_port

    cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
        @staticmethod
@ -13,28 +13,31 @@ class JobConfig:
        code_search_path (list): A list of directories or jar files that
            specify the search path for user code. This will be used as
            `CLASSPATH` in Java and `PYTHONPATH` in Python.
+       runtime_env (dict): A runtime environment dictionary. There are
+           three important fields.
+           - `working_dir (str)`: A path to a local directory that will
+             be zipped up and unpackaged in the working directory of the
+             task/actor.
+           - `working_dir_uri (str)`: Same as `working_dir` but a URI
+             referencing a stored archive instead of a local path.
+             Takes precedence over `working_dir`.
+           - `local_modules (list[module])`: A list of local Python
+             modules that will be zipped up and unpacked in a directory
+             prepended to the sys.path of tasks/actors.
    """

-   def __init__(
-       self,
-       worker_env=None,
-       num_java_workers_per_process=1,
-       jvm_options=None,
-       code_search_path=None,
-   ):
-       if worker_env is None:
-           self.worker_env = dict()
-       else:
-           self.worker_env = worker_env
+   def __init__(self,
+                worker_env=None,
+                num_java_workers_per_process=1,
+                jvm_options=None,
+                code_search_path=None,
+                runtime_env=None):
+       self.worker_env = worker_env or dict()
        self.num_java_workers_per_process = num_java_workers_per_process
-       if jvm_options is None:
-           self.jvm_options = []
-       else:
-           self.jvm_options = jvm_options
-       if code_search_path is None:
-           self.code_search_path = []
-       else:
-           self.code_search_path = code_search_path
+       self.jvm_options = jvm_options or []
+       self.code_search_path = code_search_path or []
+       self.runtime_env = runtime_env or dict()

    def serialize(self):
        job_config = ray.gcs_utils.JobConfig()
@ -44,4 +47,15 @@ class JobConfig:
                self.num_java_workers_per_process)
        job_config.jvm_options.extend(self.jvm_options)
        job_config.code_search_path.extend(self.code_search_path)
        job_config.runtime_env.CopyFrom(self._get_proto_runtime())
        return job_config.SerializeToString()

    def get_package_uri(self):
        return self.runtime_env.get("working_dir_uri")

    def _get_proto_runtime(self):
        from ray.core.generated.common_pb2 import RuntimeEnv
        runtime_env = RuntimeEnv()
        if self.get_package_uri():
            runtime_env.working_dir_uri = self.get_package_uri()
        return runtime_env
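In practice a driver opts in through the new runtime_env argument; rewrite_working_dir_uri fills in working_dir_uri during ray.init, and serialize() embeds it in the RuntimeEnv proto. A minimal sketch (the directory path is hypothetical):

import ray

job_config = ray.job_config.JobConfig(
    runtime_env={"working_dir": "/path/to/project"})
ray.init(job_config=job_config)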
@ -286,6 +286,12 @@ class Node:
        try_to_create_directory(self._logs_dir)
        old_logs_dir = os.path.join(self._logs_dir, "old")
        try_to_create_directory(old_logs_dir)
        # Create a directory to be used for the runtime environment.
        self._runtime_env_dir = os.path.join(self._session_dir,
                                             "runtime_resources")
        try_to_create_directory(self._runtime_env_dir)
        import ray._private.runtime_env as runtime_env
        runtime_env.PKG_DIR = self._runtime_env_dir

    def get_resource_spec(self):
        """Resolve and return the current resource spec for the node."""

@ -434,6 +440,10 @@ class Node:
        """Get the path of the temporary directory."""
        return self._temp_dir

    def get_runtime_env_dir_path(self):
        """Get the path of the runtime env directory."""
        return self._runtime_env_dir

    def get_session_dir_path(self):
        """Get the path of the session directory."""
        return self._session_dir
@ -62,6 +62,7 @@ py_test_module_list(
        "test_reference_counting.py",
        "test_reference_counting_2.py",
        "test_resource_demand_scheduler.py",
        "test_runtime_env.py",
        "test_scheduling.py",
        "test_serialization.py",
        "test_stress.py",
112
python/ray/tests/test_runtime_env.py
Normal file
@ -0,0 +1,112 @@
import pytest
import sys
import unittest
from ray.test_utils import run_string_as_driver

driver_script = """
import sys
sys.path.insert(0, "{working_dir}")
import test_module
import ray

job_config = ray.job_config.JobConfig(
    runtime_env={runtime_env}
)

ray.init(address="{redis_address}", job_config=job_config)

@ray.remote
def run_test():
    return test_module.one()

print(sum(ray.get([run_test.remote()] * 1000)))

ray.shutdown()"""


@pytest.fixture
def working_dir():
    import tempfile
    from pathlib import Path
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir)
        module_path = path / "test_module"
        module_path.mkdir(parents=True)
        init_file = module_path / "__init__.py"
        test_file = module_path / "test.py"
        with test_file.open(mode="w") as f:
            f.write("""
def one():
    return 1
""")
        with init_file.open(mode="w") as f:
            f.write("""
from test_module.test import one
""")
        yield tmp_dir


@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_single_node(ray_start_cluster_head, working_dir):
    cluster = ray_start_cluster_head
    redis_address = cluster.address
    runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
    script = driver_script.format(
        redis_address=redis_address,
        working_dir=working_dir,
        runtime_env=runtime_env)
    out = run_string_as_driver(script)
    assert out.strip().split()[-1] == "1000"


@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_two_node(two_node_cluster, working_dir):
    cluster, _ = two_node_cluster
    redis_address = cluster.address
    runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
    script = driver_script.format(
        redis_address=redis_address,
        working_dir=working_dir,
        runtime_env=runtime_env)
    out = run_string_as_driver(script)
    assert out.strip().split()[-1] == "1000"


@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_two_node_module(two_node_cluster, working_dir):
    cluster, _ = two_node_cluster
    redis_address = cluster.address
    runtime_env = """{ "local_modules": [test_module] }"""
    script = driver_script.format(
        redis_address=redis_address,
        working_dir=working_dir,
        runtime_env=runtime_env)
    print(script)
    out = run_string_as_driver(script)
    assert out.strip().split()[-1] == "1000"


@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_two_node_uri(two_node_cluster, working_dir):
    cluster, _ = two_node_cluster
    redis_address = cluster.address
    import ray._private.runtime_env as runtime_env
    import tempfile
    with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file:
        pkg_name = runtime_env.get_project_package_name(working_dir, [])
        pkg_uri = runtime_env.Protocol.PIN_GCS.value + "://" + pkg_name
        runtime_env.create_project_package(working_dir, [], tmp_file.name)
        runtime_env.push_package(pkg_uri, tmp_file.name)
        # Rebinds the name: from here on `runtime_env` is the JSON snippet
        # injected into the driver script, not the module.
        runtime_env = f"""{{ "working_dir_uri": "{pkg_uri}" }}"""
    script = driver_script.format(
        redis_address=redis_address,
        working_dir=working_dir,
        runtime_env=runtime_env)
    out = run_string_as_driver(script)
    assert out.strip().split()[-1] == "1000"


if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-sv", __file__]))
@ -11,8 +11,8 @@ import tempfile
import threading
import time
import uuid
-from inspect import signature

+from inspect import signature
import numpy as np
import psutil
import ray
@ -28,6 +28,7 @@ import ray.ray_constants as ray_constants
import ray.remote_function
import ray.serialization as serialization
import ray._private.services as services
import ray._private.runtime_env as runtime_env
import ray
import setproctitle
import ray.signature
@ -646,7 +647,6 @@ def init(
    if "RAY_ADDRESS" in os.environ:
        if address is None or address == "auto":
            address = os.environ["RAY_ADDRESS"]

    # Convert hostnames to numerical IP address.
    if _node_ip_address is not None:
        node_ip_address = services.address_to_ip(_node_ip_address)
@ -765,6 +765,11 @@ def init(
        spawn_reaper=False,
        connect_only=True)

    if driver_mode == SCRIPT_MODE and job_config:
        # Rewrite the URI. Note that the package isn't uploaded to the
        # URI until later, in connect().
        runtime_env.rewrite_working_dir_uri(job_config)

    connect(
        _global_node,
        mode=driver_mode,
|
|||
)
|
||||
if job_config is None:
|
||||
job_config = ray.job_config.JobConfig()
|
||||
|
||||
serialized_job_config = job_config.serialize()
|
||||
worker.core_worker = ray._raylet.CoreWorker(
|
||||
mode, node.plasma_store_socket_name, node.raylet_socket_name, job_id,
|
||||
|
@ -1231,6 +1237,11 @@ def connect(node,
|
|||
# will use glog, which is intialized in `CoreWorker`.
|
||||
ray.state.state._initialize_global_state(
|
||||
node.redis_address, redis_password=node.redis_password)
|
||||
if mode == SCRIPT_MODE:
|
||||
runtime_env.upload_runtime_env_package_if_needed(job_config)
|
||||
elif mode == WORKER_MODE:
|
||||
runtime_env.ensure_runtime_env_setup(
|
||||
worker.core_worker.get_job_config().runtime_env)
|
||||
|
||||
if driver_object_store_memory is not None:
|
||||
worker.core_worker.set_object_store_client_options(
|
||||
|
|
|
@ -133,6 +133,10 @@ message RayException {
  string formatted_exception_string = 3;
}

message RuntimeEnv {
  string working_dir_uri = 1;
}

/// The task specification encapsulates all immutable information about the
/// task. These fields are determined at submission time; in contrast, the
/// `TaskExecutionSpec` may change at execution time.
@ -268,6 +268,8 @@ message JobConfig {
  // code. This will be used as `CLASSPATH` in Java, and `PYTHONPATH` in
  // Python.
  repeated string code_search_path = 4;
  // Runtime environment to run the code.
  RuntimeEnv runtime_env = 5;
}

message JobTableData {
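The generated Python class mirrors the RuntimeEnv message above and is what JobConfig._get_proto_runtime constructs; a minimal sketch (the URI is hypothetical):

from ray.core.generated.common_pb2 import RuntimeEnv

runtime_env = RuntimeEnv()
runtime_env.working_dir_uri = "gcs://_ray_pkg_0123abcd.zip"
assert runtime_env.working_dir_uri.startswith("gcs://")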