[runtime_env] Align the interface to use multiple uris for runtime env (#15562)

This commit is contained in:
Yi Cheng 2021-04-29 08:45:25 -07:00 committed by GitHub
parent b6f593b53e
commit b5c9780a3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 11 deletions

View file

@ -380,8 +380,8 @@ def package_exists(pkg_uri: str) -> bool:
raise NotImplementedError(f"Protocol {protocol} is not supported")
def rewrite_working_dir_uri(job_config: JobConfig) -> None:
"""Rewrite the working dir uri field in job_config.
def rewrite_runtime_env_uris(job_config: JobConfig) -> None:
"""Rewrite the uris field in job_config.
This function is used to update the runtime field in job_config. The
runtime field will be generated based on the hash of required files and
@ -391,17 +391,20 @@ def rewrite_working_dir_uri(job_config: JobConfig) -> None:
job_config (JobConfig): The job config.
"""
# For now, we only support local directory and packages
working_dir_uri = job_config.runtime_env.get("working_dir_uri")
if working_dir_uri is not None:
job_config.runtime_env["uris"] = [working_dir_uri]
return
working_dir = job_config.runtime_env.get("working_dir")
py_modules = job_config.runtime_env.get("py_modules")
excludes = job_config.runtime_env.get("excludes")
if (not job_config.runtime_env.get("working_dir_uri")) and (working_dir
or py_modules):
if working_dir or py_modules:
if excludes is None:
excludes = []
pkg_name = get_project_package_name(working_dir, py_modules, excludes)
job_config.runtime_env[
"working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name
job_config.runtime_env["uris"] = [
Protocol.GCS.value + "://" + pkg_name
]
def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None:

View file

@ -72,8 +72,8 @@ class JobConfig:
def get_runtime_env_uris(self):
"""Get the uris of runtime environment"""
if self.runtime_env.get("working_dir_uri"):
return [self.runtime_env.get("working_dir_uri")]
if self.runtime_env.get("uris"):
return self.runtime_env.get("uris")
return []
def _get_proto_runtime(self):

View file

@ -464,7 +464,7 @@ class Worker:
with tempfile.TemporaryDirectory() as tmp_dir:
(old_dir, runtime_env.PKG_DIR) = (runtime_env.PKG_DIR, tmp_dir)
# Generate the uri for runtime env
runtime_env.rewrite_working_dir_uri(job_config)
runtime_env.rewrite_runtime_env_uris(job_config)
init_req = ray_client_pb2.InitRequest(
job_config=pickle.dumps(job_config))
init_resp = self.data_client.Init(init_req)

View file

@ -758,7 +758,7 @@ def init(
if driver_mode == SCRIPT_MODE and job_config:
# Rewrite the URI. Note the package isn't uploaded to the URI until
# later in the connect
runtime_env.rewrite_working_dir_uri(job_config)
runtime_env.rewrite_runtime_env_uris(job_config)
connect(
_global_node,