mirror of
https://github.com/vale981/ray
synced 2025-03-08 11:31:40 -05:00
[runtime env] Put runtime env into runtime context; (#15895)
This commit is contained in:
parent
d042aa6d73
commit
874558e813
9 changed files with 50 additions and 7 deletions
|
@ -27,7 +27,7 @@ def dispatch(request: str, args: List[str]):
|
|||
"""
|
||||
if request == "DEL_FILE" and len(args) == 1:
|
||||
path = pathlib.Path(args[0])
|
||||
if path.is_dir():
|
||||
if path.is_dir() and not path.is_symlink():
|
||||
shutil.rmtree(str(path))
|
||||
else:
|
||||
path.unlink()
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import ray
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from ray.core.generated.common_pb2 import RuntimeEnv as RuntimeEnvPB
|
||||
|
||||
|
||||
class JobConfig:
|
||||
|
@ -103,8 +106,8 @@ class JobConfig:
|
|||
"""Return the JSON-serialized parsed runtime env dict"""
|
||||
return self._parsed_runtime_env.serialize()
|
||||
|
||||
def _get_proto_runtime(self):
|
||||
from ray.core.generated.common_pb2 import RuntimeEnv
|
||||
runtime_env = RuntimeEnv()
|
||||
def _get_proto_runtime(self) -> RuntimeEnvPB:
|
||||
runtime_env = RuntimeEnvPB()
|
||||
runtime_env.uris[:] = self.get_runtime_env_uris()
|
||||
runtime_env.raw_json = json.dumps(self.runtime_env)
|
||||
return runtime_env
|
||||
|
|
|
@ -151,6 +151,15 @@ class RuntimeContext(object):
|
|||
"""
|
||||
return self.worker.should_capture_child_tasks_in_placement_group
|
||||
|
||||
@property
|
||||
def runtime_env(self):
|
||||
"""Get the runtime env passed to job_config
|
||||
|
||||
Returns:
|
||||
The runtime env currently using by this worker.
|
||||
"""
|
||||
return self.worker.runtime_env
|
||||
|
||||
|
||||
_runtime_context = None
|
||||
|
||||
|
|
|
@ -435,6 +435,21 @@ cache/
|
|||
"FAILED,Test,Test,FAILED,FAILED,Test,Test,FAILED,FAILED,FAILED,FAILED"
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
|
||||
@pytest.mark.parametrize("client_mode", [True, False])
|
||||
def test_runtime_env_getter(ray_start_cluster_head, working_dir, client_mode):
|
||||
cluster = ray_start_cluster_head
|
||||
(address, env, PKG_DIR) = start_client_server(cluster, client_mode)
|
||||
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
|
||||
# Execute the following cmd in driver with runtime_env
|
||||
execute_statement = """
|
||||
print(ray.get_runtime_context().runtime_env["working_dir"])
|
||||
"""
|
||||
script = driver_script.format(**locals())
|
||||
out = run_string_as_driver(script, env)
|
||||
assert out.strip().split()[-1] == working_dir
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
|
||||
@pytest.mark.parametrize("client_mode", [True, False])
|
||||
def test_two_node_uri(two_node_cluster, working_dir, client_mode):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import TYPE_CHECKING
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
if TYPE_CHECKING:
|
||||
from ray.runtime_context import RuntimeContext
|
||||
from ray import JobID
|
||||
|
@ -41,3 +42,7 @@ class ClientWorkerPropertyAPI:
|
|||
@property
|
||||
def should_capture_child_tasks_in_placement_group(self) -> bool:
|
||||
return self._fetch_runtime_context().capture_client_tasks
|
||||
|
||||
@property
|
||||
def runtime_env(self) -> Dict:
|
||||
return json.loads(self._fetch_runtime_context().runtime_env)
|
||||
|
|
|
@ -5,6 +5,7 @@ import base64
|
|||
from collections import defaultdict
|
||||
import os
|
||||
import queue
|
||||
import pickle
|
||||
|
||||
import threading
|
||||
from typing import Any
|
||||
|
@ -60,7 +61,6 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
self.ray_connect_handler = ray_connect_handler
|
||||
|
||||
def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
|
||||
import pickle
|
||||
if request.job_config:
|
||||
job_config = pickle.loads(request.job_config)
|
||||
job_config.client_job = True
|
||||
|
@ -158,6 +158,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
ctx.node_id = rtc.node_id.binary()
|
||||
ctx.capture_client_tasks = \
|
||||
rtc.should_capture_child_tasks_in_placement_group
|
||||
ctx.runtime_env = json.dumps(rtc.runtime_env)
|
||||
resp.runtime_context.CopyFrom(ctx)
|
||||
else:
|
||||
with disable_client_hook():
|
||||
|
|
|
@ -176,6 +176,12 @@ class Worker:
|
|||
assert isinstance(self.current_job_id, ray.JobID)
|
||||
return self._session_index, self.current_job_id
|
||||
|
||||
@property
|
||||
def runtime_env(self):
|
||||
"""Get the runtime env in json format"""
|
||||
return json.loads(
|
||||
self.core_worker.get_job_config().runtime_env.raw_json)
|
||||
|
||||
def get_serialization_context(self, job_id=None):
|
||||
"""Get the SerializationContext of the job that this worker is processing.
|
||||
|
||||
|
|
|
@ -134,7 +134,10 @@ message RayException {
|
|||
/// The runtime environment describes all the runtime packages needed to
|
||||
/// run some task or actor.
|
||||
message RuntimeEnv {
|
||||
repeated string uris = 1;
|
||||
/// The raw json passed from user
|
||||
string raw_json = 1;
|
||||
/// Uris used in this runtime env
|
||||
repeated string uris = 2;
|
||||
}
|
||||
|
||||
/// The task specification encapsulates all immutable information about the
|
||||
|
|
|
@ -174,6 +174,7 @@ message ClusterInfoResponse {
|
|||
bytes job_id = 1;
|
||||
bytes node_id = 2;
|
||||
bool capture_client_tasks = 3;
|
||||
string runtime_env = 4;
|
||||
}
|
||||
|
||||
ClusterInfoType.TypeEnum type = 1;
|
||||
|
|
Loading…
Add table
Reference in a new issue