From ed8935406b0cd4f62ec9da536118fb46eadeaad7 Mon Sep 17 00:00:00 2001 From: Yi Cheng <74173148+iycheng@users.noreply.github.com> Date: Tue, 9 Mar 2021 11:53:58 -0800 Subject: [PATCH] [core] Minimal support for runtime env (#14270) --- python/ray/_private/runtime_env.py | 301 +++++++++++++++++++++++++ python/ray/_raylet.pyx | 1 - python/ray/experimental/internal_kv.py | 6 + python/ray/includes/libcoreworker.pxd | 2 +- python/ray/job_config.py | 52 +++-- python/ray/node.py | 10 + python/ray/tests/BUILD | 1 + python/ray/tests/test_runtime_env.py | 112 +++++++++ python/ray/utils.py | 2 +- python/ray/worker.py | 13 +- src/ray/protobuf/common.proto | 4 + src/ray/protobuf/gcs.proto | 2 + 12 files changed, 483 insertions(+), 23 deletions(-) create mode 100644 python/ray/_private/runtime_env.py create mode 100644 python/ray/tests/test_runtime_env.py diff --git a/python/ray/_private/runtime_env.py b/python/ray/_private/runtime_env.py new file mode 100644 index 000000000..1c10d2859 --- /dev/null +++ b/python/ray/_private/runtime_env.py @@ -0,0 +1,301 @@ +import hashlib +import logging +import inspect + +from filelock import FileLock +from pathlib import Path +from zipfile import ZipFile +from ray.job_config import JobConfig +from enum import Enum +from ray.experimental import internal_kv +from ray.core.generated.common_pb2 import RuntimeEnv +from typing import List, Tuple +from types import ModuleType +from urllib.parse import urlparse +import os +import sys + +# We need to setup this variable before +# using this module +PKG_DIR = None + +logger = logging.getLogger(__name__) + +FILE_SIZE_WARNING = 10 * 1024 * 1024 # 10MB +FILE_SIZE_LIMIT = 50 * 1024 * 1024 # 50MB + + +class Protocol(Enum): + """A enum for supported backend storage.""" + + # For docstring + def __new__(cls, value, doc=None): + self = object.__new__(cls) + self._value_ = value + if doc is not None: + self.__doc__ = doc + return self + + GCS = "gcs", "For packages created and managed by the system." + PIN_GCS = "pingcs", "For packages created and managed by the users." + + +def _xor_bytes(left: bytes, right: bytes) -> bytes: + if left and right: + return bytes(a ^ b for (a, b) in zip(left, right)) + return left or right + + +def _zip_module(path: Path, relative_path: Path, zip_handler: ZipFile) -> None: + """Go through all files and zip them into a zip file""" + for from_file_name in path.glob("**/*"): + file_size = from_file_name.stat().st_size + if file_size >= FILE_SIZE_LIMIT: + raise RuntimeError(f"File {from_file_name} is too big, " + "which currently is not allowd ") + if file_size >= FILE_SIZE_WARNING: + logger.warning( + f"File {from_file_name} is too big ({file_size} bytes). " + "Consider exclude this file in working directory.") + to_file_name = from_file_name.relative_to(relative_path) + zip_handler.write(from_file_name, to_file_name) + + +def _hash_modules(path: Path) -> bytes: + """Helper function to create hash of a directory. + + It'll go through all the files in the directory and xor + hash(file_name, file_content) to create a hash value. + """ + hash_val = None + BUF_SIZE = 4096 * 1024 + for from_file_name in path.glob("**/*"): + md5 = hashlib.md5() + md5.update(str(from_file_name).encode()) + if not Path(from_file_name).is_dir(): + with open(from_file_name, mode="rb") as f: + data = f.read(BUF_SIZE) + if not data: + break + md5.update(data) + hash_val = _xor_bytes(hash_val, md5.digest()) + return hash_val + + +def _get_local_path(pkg_uri: str) -> str: + assert PKG_DIR, "Please set PKG_DIR in the module first." + (_, pkg_name) = _parse_uri(pkg_uri) + return os.path.join(PKG_DIR, pkg_name) + + +def _parse_uri(pkg_uri: str) -> Tuple[Protocol, str]: + uri = urlparse(pkg_uri) + protocol = Protocol(uri.scheme) + return (protocol, uri.netloc) + + +# TODO(yic): Fix this later to handle big directories in better way +def get_project_package_name(working_dir: str, modules: List[str]) -> str: + """Get the name of the package by working dir and modules. + + This function will generate the name of the package by the working + directory and modules. It'll go through all the files in working_dir + and modules and hash the contents of these files to get the hash value + of this package. The final package name is: _ray_pkg_.zip + Right now, only the modules given will be included. The dependencies + are not included automatically. + + Examples: + + .. code-block:: python + >>> import any_module + >>> get_project_package_name("/working_dir", [any_module]) + .... _ray_pkg_af2734982a741.zip + + e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip + Args: + working_dir (str): The working directory. + modules (list[module]): The python module. + + Returns: + Package name as a string. + """ + RAY_PKG_PREFIX = "_ray_pkg_" + hash_val = None + if working_dir: + assert isinstance(working_dir, str) + assert Path(working_dir).exists() + hash_val = _xor_bytes(hash_val, _hash_modules(Path(working_dir))) + for module in modules or []: + assert inspect.ismodule(module) + hash_val = _xor_bytes(hash_val, + _hash_modules(Path(module.__file__).parent)) + return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None + + +def create_project_package(working_dir: str, modules: List[ModuleType], + output_path: str) -> None: + """Create a pckage that will be used by workers. + + This function is used to create a package file based on working directory + and python local modules. + + Args: + working_dir (str): The working directory. + modules (list[module]): The python modules to be included. + output_path (str): The path of file to be created. + """ + pkg_file = Path(output_path) + with ZipFile(pkg_file, "w") as zip_handler: + if working_dir: + # put all files in /path/working_dir into zip + working_path = Path(working_dir) + _zip_module(working_path, working_path, zip_handler) + for module in modules or []: + logger.info(module.__file__) + # we only take care of modules with path like this for now: + # /path/module_name/__init__.py + # module_path should be: /path/module_name + module_path = Path(module.__file__).parent + _zip_module(module_path, module_path.parent, zip_handler) + + +def fetch_package(pkg_uri: str, pkg_file: Path) -> int: + """Fetch a package from a given uri. + + This function is used to fetch a pacakge from the given uri to local + filesystem. + + Args: + pkg_uri (str): The uri of the package to download. + pkg_file (pathlib.Path): The path in local filesystem to download the + package. + + Returns: + The number of bytes downloaded. + """ + (protocol, pkg_name) = _parse_uri(pkg_uri) + if protocol in (Protocol.GCS, Protocol.PIN_GCS): + code = internal_kv._internal_kv_get(pkg_uri) + code = code or b"" + pkg_file.write_bytes(code) + return len(code) + else: + raise NotImplementedError(f"Protocol {protocol} is not supported") + + +def _store_package_in_gcs(gcs_key: str, data: bytes) -> int: + internal_kv._internal_kv_put(gcs_key, data) + return len(data) + + +def push_package(pkg_uri: str, pkg_path: str) -> None: + """Push a package to uri. + + This function is to push a local file to remote uri. Right now, only GCS + is supported. + + Args: + pkg_uri (str): The uri of the package to upload to. + pkg_path (str): Path of the local file. + + Returns: + The number of bytes uploaded. + """ + (protocol, pkg_name) = _parse_uri(pkg_uri) + data = Path(pkg_path).read_bytes() + if protocol in (Protocol.GCS, Protocol.PIN_GCS): + _store_package_in_gcs(pkg_uri, data) + else: + raise NotImplementedError(f"Protocol {protocol} is not supported") + + +def package_exists(pkg_uri: str) -> bool: + """Check whether the package with given uri exists or not. + + Args: + pkg_uri (str): The uri of the package + + Return: + True for package existing and False for not. + """ + (protocol, pkg_name) = _parse_uri(pkg_uri) + if protocol in (Protocol.GCS, Protocol.PIN_GCS): + return internal_kv._internal_kv_exists(pkg_uri) + else: + raise NotImplementedError(f"Protocol {protocol} is not supported") + + +def rewrite_working_dir_uri(job_config: JobConfig) -> None: + """Rewrite the working dir uri field in job_config. + + This function is used to update the runtime field in job_config. The + runtime field will be generated based on the hash of required files and + modules. + + Args: + job_config (JobConfig): The job config. + """ + # For now, we only support local directory and packages + working_dir = job_config.runtime_env.get("working_dir") + required_modules = job_config.runtime_env.get("local_modules") + + if (not job_config.runtime_env.get("working_dir_uri")) and ( + working_dir or required_modules): + pkg_name = get_project_package_name(working_dir, required_modules) + job_config.runtime_env[ + "working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name + + +def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None: + """Upload runtime env if it's not there. + + It'll check whether the runtime environment exists in the cluster or not. + If it doesn't exist, a package will be created based on the working + directory and modules defined in job config. The package will be + uploaded to the cluster after this. + + Args: + job_config (JobConfig): The job config of driver. + """ + pkg_uri = job_config.get_package_uri() + if not pkg_uri: + return + if not package_exists(pkg_uri): + file_path = _get_local_path(pkg_uri) + pkg_file = Path(file_path) + working_dir = job_config.runtime_env.get("working_dir") + required_modules = job_config.runtime_env.get("local_modules") + logger.info(f"{pkg_uri} doesn't exist. Create new package with" + f" {working_dir} and {required_modules}") + if not pkg_file.exists(): + create_project_package(working_dir, required_modules, file_path) + # Push the data to remote storage + pkg_size = push_package(pkg_uri, pkg_file) + logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes") + + +def ensure_runtime_env_setup(runtime_env: RuntimeEnv) -> None: + """Make sure all required packages are downloaded it local. + + Necessary packages required to run the job will be downloaded + into local file system if it doesn't exist. + + Args: + runtime_env (RuntimeEnv): Runtime environment of the job + """ + pkg_uri = runtime_env.working_dir_uri + if not pkg_uri: + return + pkg_file = Path(_get_local_path(pkg_uri)) + # For each node, the package will only be downloaded one time + # Locking to avoid multiple process download concurrently + lock = FileLock(str(pkg_file) + ".lock") + with lock: + # TODO(yic): checksum calculation is required + if pkg_file.exists(): + logger.info(f"{pkg_uri} has existed locally, skip downloading") + else: + pkg_size = fetch_package(pkg_uri, pkg_file) + logger.info(f"Downloaded {pkg_size} bytes into {pkg_file}") + sys.path.insert(0, str(pkg_file)) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 36dae66df..514859416 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -871,7 +871,6 @@ cdef class CoreWorker: options.terminate_asyncio_thread = terminate_asyncio_thread options.serialized_job_config = serialized_job_config options.metrics_agent_port = metrics_agent_port - CCoreWorkerProcess.Initialize(options) def __dealloc__(self): diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index deb33902d..7835a3fb0 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -17,6 +17,12 @@ def _internal_kv_get(key: Union[str, bytes]) -> bytes: return ray.worker.global_worker.redis_client.hget(key, "value") +@client_mode_hook +def _internal_kv_exists(key: Union[str, bytes]) -> bool: + """Check key exists or not.""" + return ray.worker.global_worker.redis_client.hexists(key, "value") + + @client_mode_hook def _internal_kv_put(key: Union[str, bytes], value: Union[str, bytes], diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 0350e0ed0..c6699eb36 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -264,8 +264,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: (c_bool() nogil) kill_main CCoreWorkerOptions() (void() nogil) terminate_asyncio_thread - int metrics_agent_port c_string serialized_job_config + int metrics_agent_port cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess": @staticmethod diff --git a/python/ray/job_config.py b/python/ray/job_config.py index 92474a693..520e86433 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -13,28 +13,31 @@ class JobConfig: code_search_path (list): A list of directories or jar files that specify the search path for user code. This will be used as `CLASSPATH` in Java and `PYTHONPATH` in Python. + runtime_env (dict): A path to a local directory that will be zipped + up and unpackaged in the working directory of the task/actor. + There are three important fields. + - `working_dir (str)`: A path to a local directory that will be + zipped up and unpackaged in the working directory of the + task/actor. + - `working_dir_uri (str)`: Same as `working_dir` but a URI + referencing a stored archive instead of a local path. + Takes precedence over working_dir. + - `local_modules (list[module])`: A list of local Python modules + that will be zipped up and unpacked in a directory prepended + to the sys.path of tasks/actors. """ - def __init__( - self, - worker_env=None, - num_java_workers_per_process=1, - jvm_options=None, - code_search_path=None, - ): - if worker_env is None: - self.worker_env = dict() - else: - self.worker_env = worker_env + def __init__(self, + worker_env=None, + num_java_workers_per_process=1, + jvm_options=None, + code_search_path=None, + runtime_env=None): + self.worker_env = worker_env or dict() self.num_java_workers_per_process = num_java_workers_per_process - if jvm_options is None: - self.jvm_options = [] - else: - self.jvm_options = jvm_options - if code_search_path is None: - self.code_search_path = [] - else: - self.code_search_path = code_search_path + self.jvm_options = jvm_options or [] + self.code_search_path = code_search_path or [] + self.runtime_env = runtime_env or dict() def serialize(self): job_config = ray.gcs_utils.JobConfig() @@ -44,4 +47,15 @@ class JobConfig: self.num_java_workers_per_process) job_config.jvm_options.extend(self.jvm_options) job_config.code_search_path.extend(self.code_search_path) + job_config.runtime_env.CopyFrom(self._get_proto_runtime()) return job_config.SerializeToString() + + def get_package_uri(self): + return self.runtime_env.get("working_dir_uri") + + def _get_proto_runtime(self): + from ray.core.generated.common_pb2 import RuntimeEnv + runtime_env = RuntimeEnv() + if self.get_package_uri(): + runtime_env.working_dir_uri = self.get_package_uri() + return runtime_env diff --git a/python/ray/node.py b/python/ray/node.py index dc73348e2..838b81066 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -286,6 +286,12 @@ class Node: try_to_create_directory(self._logs_dir) old_logs_dir = os.path.join(self._logs_dir, "old") try_to_create_directory(old_logs_dir) + # Create a directory to be used for runtime environment. + self._runtime_env_dir = os.path.join(self._session_dir, + "runtime_resources") + try_to_create_directory(self._runtime_env_dir) + import ray._private.runtime_env as runtime_env + runtime_env.PKG_DIR = self._runtime_env_dir def get_resource_spec(self): """Resolve and return the current resource spec for the node.""" @@ -434,6 +440,10 @@ class Node: """Get the path of the temporary directory.""" return self._temp_dir + def get_runtime_env_dir_path(self): + """Get the path of the runtime env.""" + return self._runtime_env_dir + def get_session_dir_path(self): """Get the path of the session directory.""" return self._session_dir diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 10c83e66d..b3981d0da 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -62,6 +62,7 @@ py_test_module_list( "test_reference_counting.py", "test_reference_counting_2.py", "test_resource_demand_scheduler.py", + "test_runtime_env.py", "test_scheduling.py", "test_serialization.py", "test_stress.py", diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py new file mode 100644 index 000000000..8c309f7b2 --- /dev/null +++ b/python/ray/tests/test_runtime_env.py @@ -0,0 +1,112 @@ +import pytest +import sys +import unittest +from ray.test_utils import run_string_as_driver + +driver_script = """ +import sys +sys.path.insert(0, "{working_dir}") +import test_module +import ray + +job_config = ray.job_config.JobConfig( + runtime_env={runtime_env} +) + +ray.init(address="{redis_address}", job_config=job_config) + +@ray.remote +def run_test(): + return test_module.one() + +print(sum(ray.get([run_test.remote()] * 1000))) + +ray.shutdown()""" + + +@pytest.fixture +def working_dir(): + import tempfile + from pathlib import Path + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) + module_path = path / "test_module" + module_path.mkdir(parents=True) + init_file = module_path / "__init__.py" + test_file = module_path / "test.py" + with test_file.open(mode="w") as f: + f.write(""" +def one(): + return 1 +""") + with init_file.open(mode="w") as f: + f.write(""" +from test_module.test import one +""") + yield tmp_dir + + +@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") +def test_single_node(ray_start_cluster_head, working_dir): + cluster = ray_start_cluster_head + redis_address = cluster.address + runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" + script = driver_script.format( + redis_address=redis_address, + working_dir=working_dir, + runtime_env=runtime_env) + + out = run_string_as_driver(script) + assert out.strip().split()[-1] == "1000" + + +@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") +def test_two_node(two_node_cluster, working_dir): + cluster, _ = two_node_cluster + redis_address = cluster.address + runtime_env = f"""{{ "working_dir": "{working_dir}" }}""" + script = driver_script.format( + redis_address=redis_address, + working_dir=working_dir, + runtime_env=runtime_env) + out = run_string_as_driver(script) + assert out.strip().split()[-1] == "1000" + + +@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") +def test_two_node_module(two_node_cluster, working_dir): + cluster, _ = two_node_cluster + redis_address = cluster.address + runtime_env = """{ "local_modules": [test_module] }""" + script = driver_script.format( + redis_address=redis_address, + working_dir=working_dir, + runtime_env=runtime_env) + print(script) + out = run_string_as_driver(script) + assert out.strip().split()[-1] == "1000" + + +@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.") +def test_two_node_uri(two_node_cluster, working_dir): + cluster, _ = two_node_cluster + redis_address = cluster.address + import ray._private.runtime_env as runtime_env + import tempfile + with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file: + pkg_name = runtime_env.get_project_package_name(working_dir, []) + pkg_uri = runtime_env.Protocol.PIN_GCS.value + "://" + pkg_name + runtime_env.create_project_package(working_dir, [], tmp_file.name) + runtime_env.push_package(pkg_uri, tmp_file.name) + runtime_env = f"""{{ "working_dir_uri": "{pkg_uri}" }}""" + script = driver_script.format( + redis_address=redis_address, + working_dir=working_dir, + runtime_env=runtime_env) + out = run_string_as_driver(script) + assert out.strip().split()[-1] == "1000" + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/utils.py b/python/ray/utils.py index 2704e07cc..d31bd8297 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -11,8 +11,8 @@ import tempfile import threading import time import uuid -from inspect import signature +from inspect import signature import numpy as np import psutil import ray diff --git a/python/ray/worker.py b/python/ray/worker.py index 8fff0e8b4..3c66aede4 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -28,6 +28,7 @@ import ray.ray_constants as ray_constants import ray.remote_function import ray.serialization as serialization import ray._private.services as services +import ray._private.runtime_env as runtime_env import ray import setproctitle import ray.signature @@ -646,7 +647,6 @@ def init( if "RAY_ADDRESS" in os.environ: if address is None or address == "auto": address = os.environ["RAY_ADDRESS"] - # Convert hostnames to numerical IP address. if _node_ip_address is not None: node_ip_address = services.address_to_ip(_node_ip_address) @@ -765,6 +765,11 @@ def init( spawn_reaper=False, connect_only=True) + if driver_mode == SCRIPT_MODE and job_config: + # Rewrite the URI. Note the package isn't uploaded to the URI until + # later in the connect + runtime_env.rewrite_working_dir_uri(job_config) + connect( _global_node, mode=driver_mode, @@ -1218,6 +1223,7 @@ def connect(node, ) if job_config is None: job_config = ray.job_config.JobConfig() + serialized_job_config = job_config.serialize() worker.core_worker = ray._raylet.CoreWorker( mode, node.plasma_store_socket_name, node.raylet_socket_name, job_id, @@ -1231,6 +1237,11 @@ def connect(node, # will use glog, which is intialized in `CoreWorker`. ray.state.state._initialize_global_state( node.redis_address, redis_password=node.redis_password) + if mode == SCRIPT_MODE: + runtime_env.upload_runtime_env_package_if_needed(job_config) + elif mode == WORKER_MODE: + runtime_env.ensure_runtime_env_setup( + worker.core_worker.get_job_config().runtime_env) if driver_object_store_memory is not None: worker.core_worker.set_object_store_client_options( diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 7178fe715..faebdc2e7 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -133,6 +133,10 @@ message RayException { string formatted_exception_string = 3; } +message RuntimeEnv { + string working_dir_uri = 1; +} + /// The task specification encapsulates all immutable information about the /// task. These fields are determined at submission time, converse to the /// `TaskExecutionSpec` may change at execution time. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 3c3117af1..08dfeedf5 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -268,6 +268,8 @@ message JobConfig { // code. This will be used as `CLASSPATH` in Java, and `PYTHONPATH` in // Python. repeated string code_search_path = 4; + // Runtime environment to run the code + RuntimeEnv runtime_env = 5; } message JobTableData {