[runtime env] Put runtime env into runtime context; (#15895)

This commit is contained in:
Yi Cheng 2021-05-20 08:08:45 -07:00 committed by GitHub
parent d042aa6d73
commit 874558e813
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 50 additions and 7 deletions

View file

@ -27,7 +27,7 @@ def dispatch(request: str, args: List[str]):
"""
if request == "DEL_FILE" and len(args) == 1:
path = pathlib.Path(args[0])
if path.is_dir():
if path.is_dir() and not path.is_symlink():
shutil.rmtree(str(path))
else:
path.unlink()

View file

@ -1,5 +1,8 @@
import ray
import uuid
import json
from ray.core.generated.common_pb2 import RuntimeEnv as RuntimeEnvPB
class JobConfig:
@ -103,8 +106,8 @@ class JobConfig:
"""Return the JSON-serialized parsed runtime env dict"""
return self._parsed_runtime_env.serialize()
def _get_proto_runtime(self):
from ray.core.generated.common_pb2 import RuntimeEnv
runtime_env = RuntimeEnv()
def _get_proto_runtime(self) -> RuntimeEnvPB:
runtime_env = RuntimeEnvPB()
runtime_env.uris[:] = self.get_runtime_env_uris()
runtime_env.raw_json = json.dumps(self.runtime_env)
return runtime_env

View file

@ -151,6 +151,15 @@ class RuntimeContext(object):
"""
return self.worker.should_capture_child_tasks_in_placement_group
@property
def runtime_env(self):
"""Get the runtime env passed to job_config
Returns:
The runtime env currently using by this worker.
"""
return self.worker.runtime_env
_runtime_context = None

View file

@ -435,6 +435,21 @@ cache/
"FAILED,Test,Test,FAILED,FAILED,Test,Test,FAILED,FAILED,FAILED,FAILED"
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
@pytest.mark.parametrize("client_mode", [True, False])
def test_runtime_env_getter(ray_start_cluster_head, working_dir, client_mode):
cluster = ray_start_cluster_head
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
# Execute the following cmd in driver with runtime_env
execute_statement = """
print(ray.get_runtime_context().runtime_env["working_dir"])
"""
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
assert out.strip().split()[-1] == working_dir
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
@pytest.mark.parametrize("client_mode", [True, False])
def test_two_node_uri(two_node_cluster, working_dir, client_mode):

View file

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
import json
from typing import TYPE_CHECKING, Dict
if TYPE_CHECKING:
from ray.runtime_context import RuntimeContext
from ray import JobID
@ -41,3 +42,7 @@ class ClientWorkerPropertyAPI:
@property
def should_capture_child_tasks_in_placement_group(self) -> bool:
return self._fetch_runtime_context().capture_client_tasks
@property
def runtime_env(self) -> Dict:
return json.loads(self._fetch_runtime_context().runtime_env)

View file

@ -5,6 +5,7 @@ import base64
from collections import defaultdict
import os
import queue
import pickle
import threading
from typing import Any
@ -60,7 +61,6 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
self.ray_connect_handler = ray_connect_handler
def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
import pickle
if request.job_config:
job_config = pickle.loads(request.job_config)
job_config.client_job = True
@ -158,6 +158,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ctx.node_id = rtc.node_id.binary()
ctx.capture_client_tasks = \
rtc.should_capture_child_tasks_in_placement_group
ctx.runtime_env = json.dumps(rtc.runtime_env)
resp.runtime_context.CopyFrom(ctx)
else:
with disable_client_hook():

View file

@ -176,6 +176,12 @@ class Worker:
assert isinstance(self.current_job_id, ray.JobID)
return self._session_index, self.current_job_id
@property
def runtime_env(self):
"""Get the runtime env in json format"""
return json.loads(
self.core_worker.get_job_config().runtime_env.raw_json)
def get_serialization_context(self, job_id=None):
"""Get the SerializationContext of the job that this worker is processing.

View file

@ -134,7 +134,10 @@ message RayException {
/// The runtime environment describes all the runtime packages needed to
/// run some task or actor.
message RuntimeEnv {
repeated string uris = 1;
/// The raw json passed from user
string raw_json = 1;
/// Uris used in this runtime env
repeated string uris = 2;
}
/// The task specification encapsulates all immutable information about the

View file

@ -174,6 +174,7 @@ message ClusterInfoResponse {
bytes job_id = 1;
bytes node_id = 2;
bool capture_client_tasks = 3;
string runtime_env = 4;
}
ClusterInfoType.TypeEnum type = 1;