diff --git a/python/ray/_private/runtime_env.py b/python/ray/_private/runtime_env.py index 725338fa8..3e6301711 100644 --- a/python/ray/_private/runtime_env.py +++ b/python/ray/_private/runtime_env.py @@ -380,8 +380,8 @@ def package_exists(pkg_uri: str) -> bool: raise NotImplementedError(f"Protocol {protocol} is not supported") -def rewrite_working_dir_uri(job_config: JobConfig) -> None: - """Rewrite the working dir uri field in job_config. +def rewrite_runtime_env_uris(job_config: JobConfig) -> None: + """Rewrite the uris field in job_config. This function is used to update the runtime field in job_config. The runtime field will be generated based on the hash of required files and @@ -391,17 +391,20 @@ def rewrite_working_dir_uri(job_config: JobConfig) -> None: job_config (JobConfig): The job config. """ # For now, we only support local directory and packages + working_dir_uri = job_config.runtime_env.get("working_dir_uri") + if working_dir_uri is not None: + job_config.runtime_env["uris"] = [working_dir_uri] + return working_dir = job_config.runtime_env.get("working_dir") py_modules = job_config.runtime_env.get("py_modules") excludes = job_config.runtime_env.get("excludes") - - if (not job_config.runtime_env.get("working_dir_uri")) and (working_dir - or py_modules): + if working_dir or py_modules: if excludes is None: excludes = [] pkg_name = get_project_package_name(working_dir, py_modules, excludes) - job_config.runtime_env[ - "working_dir_uri"] = Protocol.GCS.value + "://" + pkg_name + job_config.runtime_env["uris"] = [ + Protocol.GCS.value + "://" + pkg_name + ] def upload_runtime_env_package_if_needed(job_config: JobConfig) -> None: diff --git a/python/ray/job_config.py b/python/ray/job_config.py index e9fe07d4a..21025df53 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -72,8 +72,8 @@ class JobConfig: def get_runtime_env_uris(self): """Get the uris of runtime environment""" - if self.runtime_env.get("working_dir_uri"): - return [self.runtime_env.get("working_dir_uri")] + if self.runtime_env.get("uris"): + return self.runtime_env.get("uris") return [] def _get_proto_runtime(self): diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 527fd3c98..ba813365d 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -464,7 +464,7 @@ class Worker: with tempfile.TemporaryDirectory() as tmp_dir: (old_dir, runtime_env.PKG_DIR) = (runtime_env.PKG_DIR, tmp_dir) # Generate the uri for runtime env - runtime_env.rewrite_working_dir_uri(job_config) + runtime_env.rewrite_runtime_env_uris(job_config) init_req = ray_client_pb2.InitRequest( job_config=pickle.dumps(job_config)) init_resp = self.data_client.Init(init_req) diff --git a/python/ray/worker.py b/python/ray/worker.py index 53b427d8f..0c9fa13c3 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -758,7 +758,7 @@ def init( if driver_mode == SCRIPT_MODE and job_config: # Rewrite the URI. Note the package isn't uploaded to the URI until # later in the connect - runtime_env.rewrite_working_dir_uri(job_config) + runtime_env.rewrite_runtime_env_uris(job_config) connect( _global_node,