[core] Minimal support for runtime env (#14270)

Yi Cheng 2021-03-09 11:53:58 -08:00 committed by GitHub
parent ba6cebe30f
commit ed8935406b
12 changed files with 483 additions and 23 deletions


@@ -0,0 +1,301 @@
import hashlib
import logging
import inspect
from filelock import FileLock
from pathlib import Path
from zipfile import ZipFile
from ray.job_config import JobConfig
from enum import Enum
from ray.experimental import internal_kv
from ray.core.generated.common_pb2 import RuntimeEnv
from typing import List, Tuple
from types import ModuleType
from urllib.parse import urlparse
import os
import sys
# We need to set up this variable before using this module.
PKG_DIR = None
logger = logging.getLogger(__name__)
FILE_SIZE_WARNING = 10 * 1024 * 1024 # 10MB
FILE_SIZE_LIMIT = 50 * 1024 * 1024 # 50MB
class Protocol(Enum):
"""A enum for supported backend storage."""
# For docstring
def __new__(cls, value, doc=None):
self = object.__new__(cls)
self._value_ = value
if doc is not None:
self.__doc__ = doc
return self
GCS = "gcs", "For packages created and managed by the system."
PIN_GCS = "pingcs", "For packages created and managed by the users."
def _xor_bytes(left: bytes, right: bytes) -> bytes:
if left and right:
return bytes(a ^ b for (a, b) in zip(left, right))
return left or right
def _zip_module(path: Path, relative_path: Path, zip_handler: ZipFile) -> None:
"""Go through all files and zip them into a zip file"""
for from_file_name in path.glob("**/*"):
file_size = from_file_name.stat().st_size
if file_size >= FILE_SIZE_LIMIT:
            raise RuntimeError(f"File {from_file_name} is too big, "
                               "which currently is not allowed.")
        if file_size >= FILE_SIZE_WARNING:
            logger.warning(
                f"File {from_file_name} is too big ({file_size} bytes). "
                "Consider excluding this file from the working directory.")
to_file_name = from_file_name.relative_to(relative_path)
zip_handler.write(from_file_name, to_file_name)
def _hash_modules(path: Path) -> bytes:
"""Helper function to create hash of a directory.
It'll go through all the files in the directory and xor
hash(file_name, file_content) to create a hash value.
"""
hash_val = None
BUF_SIZE = 4096 * 1024
for from_file_name in path.glob("**/*"):
md5 = hashlib.md5()
md5.update(str(from_file_name).encode())
        if not Path(from_file_name).is_dir():
            with open(from_file_name, mode="rb") as f:
                # Read the file in chunks to bound memory usage.
                while True:
                    data = f.read(BUF_SIZE)
                    if not data:
                        break
                    md5.update(data)
        hash_val = _xor_bytes(hash_val, md5.digest())
return hash_val
def _get_local_path(pkg_uri: str) -> str:
assert PKG_DIR, "Please set PKG_DIR in the module first."
(_, pkg_name) = _parse_uri(pkg_uri)
return os.path.join(PKG_DIR, pkg_name)
def _parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
uri = urlparse(pkg_uri)
protocol = Protocol(uri.scheme)
return (protocol, uri.netloc)
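# A minimal sketch of the URI layout this relies on: urlparse (imported
# above) places the package name in the netloc slot. Hash shortened here
# for readability.
_uri = urlparse("gcs://_ray_pkg_029f88d5.zip")
assert _uri.scheme == "gcs"  # maps to Protocol.GCS
assert _uri.netloc == "_ray_pkg_029f88d5.zip"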
# TODO(yic): Fix this later to handle big directories in a better way.
def get_project_package_name(working_dir: str, modules: List[str]) -> str:
"""Get the name of the package by working dir and modules.
This function will generate the name of the package by the working
directory and modules. It'll go through all the files in working_dir
and modules and hash the contents of these files to get the hash value
of this package. The final package name is: _ray_pkg_<HASH_VAL>.zip
Right now, only the modules given will be included. The dependencies
are not included automatically.
Examples:
.. code-block:: python
>>> import any_module
>>> get_project_package_name("/working_dir", [any_module])
.... _ray_pkg_af2734982a741.zip
e.g., _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip
Args:
working_dir (str): The working directory.
        modules (list[module]): The python modules to be included.
Returns:
Package name as a string.
"""
RAY_PKG_PREFIX = "_ray_pkg_"
hash_val = None
if working_dir:
assert isinstance(working_dir, str)
assert Path(working_dir).exists()
hash_val = _xor_bytes(hash_val, _hash_modules(Path(working_dir)))
for module in modules or []:
assert inspect.ismodule(module)
hash_val = _xor_bytes(hash_val,
_hash_modules(Path(module.__file__).parent))
return RAY_PKG_PREFIX + hash_val.hex() + ".zip" if hash_val else None
def create_project_package(working_dir: str, modules: List[ModuleType],
output_path: str) -> None:
"""Create a pckage that will be used by workers.
This function is used to create a package file based on working directory
and python local modules.
Args:
working_dir (str): The working directory.
modules (list[module]): The python modules to be included.
output_path (str): The path of file to be created.
"""
pkg_file = Path(output_path)
with ZipFile(pkg_file, "w") as zip_handler:
if working_dir:
# put all files in /path/working_dir into zip
working_path = Path(working_dir)
_zip_module(working_path, working_path, zip_handler)
for module in modules or []:
logger.info(module.__file__)
# we only take care of modules with path like this for now:
# /path/module_name/__init__.py
# module_path should be: /path/module_name
module_path = Path(module.__file__).parent
_zip_module(module_path, module_path.parent, zip_handler)
def fetch_package(pkg_uri: str, pkg_file: Path) -> int:
"""Fetch a package from a given uri.
    This function is used to fetch a package from the given uri to the
    local filesystem.
Args:
pkg_uri (str): The uri of the package to download.
pkg_file (pathlib.Path): The path in local filesystem to download the
package.
Returns:
The number of bytes downloaded.
"""
(protocol, pkg_name) = _parse_uri(pkg_uri)
if protocol in (Protocol.GCS, Protocol.PIN_GCS):
code = internal_kv._internal_kv_get(pkg_uri)
code = code or b""
pkg_file.write_bytes(code)
return len(code)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
def _store_package_in_gcs(gcs_key: str, data: bytes) -> int:
internal_kv._internal_kv_put(gcs_key, data)
return len(data)
def push_package(pkg_uri: str, pkg_path: str) -> int:
    """Push a package to a remote uri.
    This function pushes a local file to a remote uri. Right now, only GCS
    is supported.
Args:
pkg_uri (str): The uri of the package to upload to.
pkg_path (str): Path of the local file.
Returns:
The number of bytes uploaded.
"""
(protocol, pkg_name) = _parse_uri(pkg_uri)
data = Path(pkg_path).read_bytes()
    if protocol in (Protocol.GCS, Protocol.PIN_GCS):
        return _store_package_in_gcs(pkg_uri, data)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
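# A hedged end-to-end sketch of the driver-side flow (the helper name and
# paths are hypothetical; a connected Ray session is assumed so that the
# internal KV store is reachable):
def _example_pin_package(working_dir: str) -> str:
    pkg_name = get_project_package_name(working_dir, [])
    pkg_uri = Protocol.PIN_GCS.value + "://" + pkg_name
    create_project_package(working_dir, [], "/tmp/" + pkg_name)
    push_package(pkg_uri, "/tmp/" + pkg_name)
    return pkg_uri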
def package_exists(pkg_uri: str) -> bool:
"""Check whether the package with given uri exists or not.
Args:
pkg_uri (str): The uri of the package
    Returns:
        True if the package exists, False otherwise.
"""
(protocol, pkg_name) = _parse_uri(pkg_uri)
if protocol in (Protocol.GCS, Protocol.PIN_GCS):
return internal_kv._internal_kv_exists(pkg_uri)
else:
raise NotImplementedError(f"Protocol {protocol} is not supported")
def rewrite_working_dir_uri(job_config: JobConfig) -> None:
"""Rewrite the working dir uri field in job_config.
    This function is used to update the runtime_env field in job_config.
    The working_dir_uri field will be generated based on the hash of the
    required files and modules.
Args:
job_config (JobConfig): The job config.
"""
    # For now, we only support local directories and packages.
working_dir = job_config.runtime_env.get("working_dir")
required_modules = job_config.runtime_env.get("local_modules")
if (not job_config.runtime_env.get("working_dir_uri")) and (
working_dir or required_modules):
pkg_name = get_project_package_name(working_dir, required_modules)
job_config.runtime_env[
"working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
"""Upload runtime env if it's not there.
    It'll check whether the runtime environment exists in the cluster or not.
    If it doesn't exist, a package will be created based on the working
    directory and modules defined in the job config, and then uploaded to
    the cluster.
Args:
job_config (JobConfig): The job config of driver.
"""
pkg_uri = job_config.get_package_uri()
if not pkg_uri:
return
if not package_exists(pkg_uri):
file_path = _get_local_path(pkg_uri)
pkg_file = Path(file_path)
working_dir = job_config.runtime_env.get("working_dir")
required_modules = job_config.runtime_env.get("local_modules")
logger.info(f"{pkg_uri} doesn't exist. Create new package with"
f" {working_dir} and {required_modules}")
if not pkg_file.exists():
create_project_package(working_dir, required_modules, file_path)
# Push the data to remote storage
pkg_size = push_package(pkg_uri, pkg_file)
logger.info(f"{pkg_uri} has been pushed with {pkg_size} bytes")
def ensure_runtime_env_setup(runtime_env: RuntimeEnv) -> None:
"""Make sure all required packages are downloaded it local.
Necessary packages required to run the job will be downloaded
into local file system if it doesn't exist.
Args:
runtime_env (RuntimeEnv): Runtime environment of the job
"""
pkg_uri = runtime_env.working_dir_uri
if not pkg_uri:
return
pkg_file = Path(_get_local_path(pkg_uri))
    # For each node, the package will only be downloaded once.
    # Lock to avoid multiple processes downloading concurrently.
lock = FileLock(str(pkg_file) + ".lock")
with lock:
# TODO(yic): checksum calculation is required
if pkg_file.exists():
logger.info(f"{pkg_uri} has existed locally, skip downloading")
else:
pkg_size = fetch_package(pkg_uri, pkg_file)
logger.info(f"Downloaded {pkg_size} bytes into {pkg_file}")
sys.path.insert(0, str(pkg_file))
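# Worker-side note: because the zip archive itself is inserted into sys.path,
# Python's built-in zipimport machinery resolves imports directly from the
# archive; nothing is unpacked on disk. E.g., a package containing
# test_module/ makes `import test_module` work on the worker.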


@@ -871,7 +871,6 @@ cdef class CoreWorker:
options.terminate_asyncio_thread = terminate_asyncio_thread
options.serialized_job_config = serialized_job_config
options.metrics_agent_port = metrics_agent_port
CCoreWorkerProcess.Initialize(options)
def __dealloc__(self):


@@ -17,6 +17,12 @@ def _internal_kv_get(key: Union[str, bytes]) -> bytes:
return ray.worker.global_worker.redis_client.hget(key, "value")
@client_mode_hook
def _internal_kv_exists(key: Union[str, bytes]) -> bool:
"""Check key exists or not."""
return ray.worker.global_worker.redis_client.hexists(key, "value")
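# A hedged sketch of how the runtime env package store uses these helpers:
# packages live in the KV store keyed by their full URI (key is made up).
#
#   _internal_kv_put("gcs://_ray_pkg_<HASH>.zip", zip_bytes)
#   _internal_kv_exists("gcs://_ray_pkg_<HASH>.zip")  # -> True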
@client_mode_hook
def _internal_kv_put(key: Union[str, bytes],
value: Union[str, bytes],


@@ -264,8 +264,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(c_bool() nogil) kill_main
CCoreWorkerOptions()
(void() nogil) terminate_asyncio_thread
int metrics_agent_port
c_string serialized_job_config
int metrics_agent_port
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
@staticmethod


@@ -13,28 +13,31 @@ class JobConfig:
code_search_path (list): A list of directories or jar files that
specify the search path for user code. This will be used as
`CLASSPATH` in Java and `PYTHONPATH` in Python.
        runtime_env (dict): A runtime environment dictionary for the job.
            There are three important fields.
- `working_dir (str)`: A path to a local directory that will be
zipped up and unpackaged in the working directory of the
task/actor.
- `working_dir_uri (str)`: Same as `working_dir` but a URI
referencing a stored archive instead of a local path.
Takes precedence over working_dir.
- `local_modules (list[module])`: A list of local Python modules
that will be zipped up and unpacked in a directory prepended
to the sys.path of tasks/actors.
"""
def __init__(
self,
worker_env=None,
num_java_workers_per_process=1,
jvm_options=None,
code_search_path=None,
):
if worker_env is None:
self.worker_env = dict()
else:
self.worker_env = worker_env
def __init__(self,
worker_env=None,
num_java_workers_per_process=1,
jvm_options=None,
code_search_path=None,
runtime_env=None):
self.worker_env = worker_env or dict()
self.num_java_workers_per_process = num_java_workers_per_process
if jvm_options is None:
self.jvm_options = []
else:
self.jvm_options = jvm_options
if code_search_path is None:
self.code_search_path = []
else:
self.code_search_path = code_search_path
self.jvm_options = jvm_options or []
self.code_search_path = code_search_path or []
self.runtime_env = runtime_env or dict()
def serialize(self):
job_config = ray.gcs_utils.JobConfig()
@@ -44,4 +47,15 @@ class JobConfig:
self.num_java_workers_per_process)
job_config.jvm_options.extend(self.jvm_options)
job_config.code_search_path.extend(self.code_search_path)
job_config.runtime_env.CopyFrom(self._get_proto_runtime())
return job_config.SerializeToString()
def get_package_uri(self):
return self.runtime_env.get("working_dir_uri")
def _get_proto_runtime(self):
from ray.core.generated.common_pb2 import RuntimeEnv
runtime_env = RuntimeEnv()
if self.get_package_uri():
runtime_env.working_dir_uri = self.get_package_uri()
return runtime_env
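# A hedged sketch of how the new field travels into the serialized proto
# (the URI is made up):
#
#   job_config = JobConfig(
#       runtime_env={"working_dir_uri": "pingcs://_ray_pkg_abc.zip"})
#   job_config.get_package_uri()  # -> "pingcs://_ray_pkg_abc.zip"
#   job_config.serialize()        # embeds the RuntimeEnv proto built by
#                                 # _get_proto_runtime()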


@@ -286,6 +286,12 @@ class Node:
try_to_create_directory(self._logs_dir)
old_logs_dir = os.path.join(self._logs_dir, "old")
try_to_create_directory(old_logs_dir)
# Create a directory to be used for runtime environment.
self._runtime_env_dir = os.path.join(self._session_dir,
"runtime_resources")
try_to_create_directory(self._runtime_env_dir)
import ray._private.runtime_env as runtime_env
runtime_env.PKG_DIR = self._runtime_env_dir
def get_resource_spec(self):
"""Resolve and return the current resource spec for the node."""
@@ -434,6 +440,10 @@ class Node:
"""Get the path of the temporary directory."""
return self._temp_dir
def get_runtime_env_dir_path(self):
"""Get the path of the runtime env."""
return self._runtime_env_dir
def get_session_dir_path(self):
"""Get the path of the session directory."""
return self._session_dir


@@ -62,6 +62,7 @@ py_test_module_list(
"test_reference_counting.py",
"test_reference_counting_2.py",
"test_resource_demand_scheduler.py",
"test_runtime_env.py",
"test_scheduling.py",
"test_serialization.py",
"test_stress.py",


@@ -0,0 +1,112 @@
import pytest
import sys
import unittest
from ray.test_utils import run_string_as_driver
driver_script = """
import sys
sys.path.insert(0, "{working_dir}")
import test_module
import ray
job_config = ray.job_config.JobConfig(
runtime_env={runtime_env}
)
ray.init(address="{redis_address}", job_config=job_config)
@ray.remote
def run_test():
return test_module.one()
print(sum(ray.get([run_test.remote()] * 1000)))
ray.shutdown()"""
@pytest.fixture
def working_dir():
import tempfile
from pathlib import Path
with tempfile.TemporaryDirectory() as tmp_dir:
path = Path(tmp_dir)
module_path = path / "test_module"
module_path.mkdir(parents=True)
init_file = module_path / "__init__.py"
test_file = module_path / "test.py"
with test_file.open(mode="w") as f:
f.write("""
def one():
return 1
""")
with init_file.open(mode="w") as f:
f.write("""
from test_module.test import one
""")
yield tmp_dir
@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_single_node(ray_start_cluster_head, working_dir):
cluster = ray_start_cluster_head
redis_address = cluster.address
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
script = driver_script.format(
redis_address=redis_address,
working_dir=working_dir,
runtime_env=runtime_env)
out = run_string_as_driver(script)
assert out.strip().split()[-1] == "1000"
@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_two_node(two_node_cluster, working_dir):
cluster, _ = two_node_cluster
redis_address = cluster.address
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
script = driver_script.format(
redis_address=redis_address,
working_dir=working_dir,
runtime_env=runtime_env)
out = run_string_as_driver(script)
assert out.strip().split()[-1] == "1000"
@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_two_node_module(two_node_cluster, working_dir):
cluster, _ = two_node_cluster
redis_address = cluster.address
runtime_env = """{ "local_modules": [test_module] }"""
script = driver_script.format(
redis_address=redis_address,
working_dir=working_dir,
runtime_env=runtime_env)
print(script)
out = run_string_as_driver(script)
assert out.strip().split()[-1] == "1000"
@unittest.skipIf(sys.platform == "win32", "Fails to create temp dir.")
def test_two_node_uri(two_node_cluster, working_dir):
cluster, _ = two_node_cluster
redis_address = cluster.address
import ray._private.runtime_env as runtime_env
import tempfile
with tempfile.NamedTemporaryFile(suffix="zip") as tmp_file:
pkg_name = runtime_env.get_project_package_name(working_dir, [])
pkg_uri = runtime_env.Protocol.PIN_GCS.value + "://" + pkg_name
runtime_env.create_project_package(working_dir, [], tmp_file.name)
runtime_env.push_package(pkg_uri, tmp_file.name)
runtime_env = f"""{{ "working_dir_uri": "{pkg_uri}" }}"""
script = driver_script.format(
redis_address=redis_address,
working_dir=working_dir,
runtime_env=runtime_env)
out = run_string_as_driver(script)
assert out.strip().split()[-1] == "1000"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-sv", __file__]))


@@ -11,8 +11,8 @@ import tempfile
import threading
import time
import uuid
from inspect import signature
import numpy as np
import psutil
import ray


@@ -28,6 +28,7 @@ import ray.ray_constants as ray_constants
import ray.remote_function
import ray.serialization as serialization
import ray._private.services as services
import ray._private.runtime_env as runtime_env
import ray
import setproctitle
import ray.signature
@@ -646,7 +647,6 @@ def init(
if "RAY_ADDRESS" in os.environ:
if address is None or address == "auto":
address = os.environ["RAY_ADDRESS"]
# Convert hostnames to numerical IP address.
if _node_ip_address is not None:
node_ip_address = services.address_to_ip(_node_ip_address)
@@ -765,6 +765,11 @@ def init(
spawn_reaper=False,
connect_only=True)
if driver_mode == SCRIPT_MODE and job_config:
        # Rewrite the URI. Note that the package isn't uploaded to the URI
        # until later in `connect`.
runtime_env.rewrite_working_dir_uri(job_config)
connect(
_global_node,
mode=driver_mode,
@@ -1218,6 +1223,7 @@ def connect(node,
)
if job_config is None:
job_config = ray.job_config.JobConfig()
serialized_job_config = job_config.serialize()
worker.core_worker = ray._raylet.CoreWorker(
mode, node.plasma_store_socket_name, node.raylet_socket_name, job_id,
@@ -1231,6 +1237,11 @@
    # will use glog, which is initialized in `CoreWorker`.
ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password)
if mode == SCRIPT_MODE:
runtime_env.upload_runtime_env_package_if_needed(job_config)
elif mode == WORKER_MODE:
runtime_env.ensure_runtime_env_setup(
worker.core_worker.get_job_config().runtime_env)
if driver_object_store_memory is not None:
worker.core_worker.set_object_store_client_options(


@@ -133,6 +133,10 @@ message RayException {
string formatted_exception_string = 3;
}
message RuntimeEnv {
string working_dir_uri = 1;
}
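// Note: on the Python side this message is built by
// JobConfig._get_proto_runtime(), roughly
// RuntimeEnv(working_dir_uri="gcs://_ray_pkg_<HASH>.zip").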
/// The task specification encapsulates all immutable information about the
/// task. These fields are determined at submission time, in contrast to the
/// `TaskExecutionSpec`, which may change at execution time.


@@ -268,6 +268,8 @@ message JobConfig {
// code. This will be used as `CLASSPATH` in Java, and `PYTHONPATH` in
// Python.
repeated string code_search_path = 4;
// Runtime environment to run the code
RuntimeEnv runtime_env = 5;
}
message JobTableData {