mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[runtime_env] Align the interface to use multiple uris for runtime env (#15562)
This commit is contained in:
parent
b6f593b53e
commit
b5c9780a3d
4 changed files with 14 additions and 11 deletions
|
@ -380,8 +380,8 @@ def package_exists(pkg_uri: str) -> bool:
|
|||
raise NotImplementedError(f"Protocol {protocol} is not supported")
|
||||
|
||||
|
||||
def rewrite_working_dir_uri(job_config: JobConfig) -> None:
|
||||
"""Rewrite the working dir uri field in job_config.
|
||||
def rewrite_runtime_env_uris(job_config: JobConfig) -> None:
|
||||
"""Rewrite the uris field in job_config.
|
||||
|
||||
This function is used to update the runtime field in job_config. The
|
||||
runtime field will be generated based on the hash of required files and
|
||||
|
@ -391,17 +391,20 @@ def rewrite_working_dir_uri(job_config: JobConfig) -> None:
|
|||
job_config (JobConfig): The job config.
|
||||
"""
|
||||
# For now, we only support local directory and packages
|
||||
working_dir_uri = job_config.runtime_env.get("working_dir_uri")
|
||||
if working_dir_uri is not None:
|
||||
job_config.runtime_env["uris"] = [working_dir_uri]
|
||||
return
|
||||
working_dir = job_config.runtime_env.get("working_dir")
|
||||
py_modules = job_config.runtime_env.get("py_modules")
|
||||
excludes = job_config.runtime_env.get("excludes")
|
||||
|
||||
if (not job_config.runtime_env.get("working_dir_uri")) and (working_dir
|
||||
or py_modules):
|
||||
if working_dir or py_modules:
|
||||
if excludes is None:
|
||||
excludes = []
|
||||
pkg_name = get_project_package_name(working_dir, py_modules, excludes)
|
||||
job_config.runtime_env[
|
||||
"working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
|
||||
job_config.runtime_env["uris"] = [
|
||||
Protocol.GCS.value + "://" + pkg_name
|
||||
]
|
||||
|
||||
|
||||
def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:
|
||||
|
|
|
@ -72,8 +72,8 @@ class JobConfig:
|
|||
|
||||
def get_runtime_env_uris(self):
|
||||
"""Get the uris of runtime environment"""
|
||||
if self.runtime_env.get("working_dir_uri"):
|
||||
return [self.runtime_env.get("working_dir_uri")]
|
||||
if self.runtime_env.get("uris"):
|
||||
return self.runtime_env.get("uris")
|
||||
return []
|
||||
|
||||
def _get_proto_runtime(self):
|
||||
|
|
|
@ -464,7 +464,7 @@ class Worker:
|
|||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
(old_dir, runtime_env.PKG_DIR) = (runtime_env.PKG_DIR, tmp_dir)
|
||||
# Generate the uri for runtime env
|
||||
runtime_env.rewrite_working_dir_uri(job_config)
|
||||
runtime_env.rewrite_runtime_env_uris(job_config)
|
||||
init_req = ray_client_pb2.InitRequest(
|
||||
job_config=pickle.dumps(job_config))
|
||||
init_resp = self.data_client.Init(init_req)
|
||||
|
|
|
@ -758,7 +758,7 @@ def init(
|
|||
if driver_mode == SCRIPT_MODE and job_config:
|
||||
# Rewrite the URI. Note the package isn't uploaded to the URI until
|
||||
# later in the connect
|
||||
runtime_env.rewrite_working_dir_uri(job_config)
|
||||
runtime_env.rewrite_runtime_env_uris(job_config)
|
||||
|
||||
connect(
|
||||
_global_node,
|
||||
|
|
Loading…
Add table
Reference in a new issue