mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[job submission] Support local py_modules in jobs (#22843)
This commit is contained in:
parent
85598d9d10
commit
c78bd809ce
5 changed files with 87 additions and 53 deletions
|
@ -18,8 +18,10 @@ except ImportError:
|
|||
from ray._private.runtime_env.packaging import (
|
||||
create_package,
|
||||
get_uri_for_directory,
|
||||
parse_uri,
|
||||
)
|
||||
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
|
||||
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
|
||||
|
||||
from ray.dashboard.modules.job.common import uri_to_http_components
|
||||
|
||||
from ray.ray_constants import DEFAULT_DASHBOARD_PORT
|
||||
|
@ -280,32 +282,40 @@ class SubmissionClient:
|
|||
package_file.unlink()
|
||||
|
||||
def _upload_package_if_needed(
|
||||
self, package_path: str, excludes: Optional[List[str]] = None
|
||||
self,
|
||||
package_path: str,
|
||||
include_parent_dir: Optional[bool] = False,
|
||||
excludes: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
package_uri = get_uri_for_directory(package_path, excludes=excludes)
|
||||
|
||||
if not self._package_exists(package_uri):
|
||||
self._upload_package(package_uri, package_path, excludes=excludes)
|
||||
self._upload_package(
|
||||
package_uri,
|
||||
package_path,
|
||||
include_parent_dir=include_parent_dir,
|
||||
excludes=excludes,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Package {package_uri} already exists, skipping upload.")
|
||||
|
||||
return package_uri
|
||||
|
||||
def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]):
|
||||
if "working_dir" in runtime_env:
|
||||
working_dir = runtime_env["working_dir"]
|
||||
try:
|
||||
parse_uri(working_dir)
|
||||
is_uri = True
|
||||
logger.debug("working_dir is already a valid URI.")
|
||||
except ValueError:
|
||||
is_uri = False
|
||||
def _upload_fn(working_dir, excludes):
|
||||
self._upload_package_if_needed(
|
||||
working_dir, include_parent_dir=False, excludes=excludes
|
||||
)
|
||||
|
||||
if not is_uri:
|
||||
logger.debug("working_dir is not a URI, attempting to upload.")
|
||||
package_uri = self._upload_package_if_needed(
|
||||
working_dir, excludes=runtime_env.get("excludes", None)
|
||||
)
|
||||
runtime_env["working_dir"] = package_uri
|
||||
upload_working_dir_if_needed(runtime_env, upload_fn=_upload_fn)
|
||||
|
||||
def _upload_py_modules_if_needed(self, runtime_env: Dict[str, Any]):
|
||||
def _upload_fn(module_path, excludes):
|
||||
self._upload_package_if_needed(
|
||||
module_path, include_parent_dir=True, excludes=excludes
|
||||
)
|
||||
|
||||
upload_py_modules_if_needed(runtime_env, "", upload_fn=_upload_fn)
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
def get_version(self) -> str:
|
||||
|
|
|
@ -67,6 +67,7 @@ class JobSubmissionClient(SubmissionClient):
|
|||
metadata.update(self._default_metadata)
|
||||
|
||||
self._upload_working_dir_if_needed(runtime_env)
|
||||
self._upload_py_modules_if_needed(runtime_env)
|
||||
req = JobSubmitRequest(
|
||||
entrypoint=entrypoint,
|
||||
job_id=job_id,
|
||||
|
|
|
@ -108,16 +108,22 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool:
|
|||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module", params=["no_working_dir", "local_working_dir", "s3_working_dir"]
|
||||
scope="module",
|
||||
params=[
|
||||
"no_working_dir",
|
||||
"local_working_dir",
|
||||
"s3_working_dir",
|
||||
"local_py_modules",
|
||||
],
|
||||
)
|
||||
def working_dir_option(request):
|
||||
def runtime_env_option(request):
|
||||
if request.param == "no_working_dir":
|
||||
yield {
|
||||
"runtime_env": {},
|
||||
"entrypoint": "echo hello",
|
||||
"expected_logs": "hello\n",
|
||||
}
|
||||
elif request.param == "local_working_dir":
|
||||
elif request.param == "local_working_dir" or request.param == "local_py_modules":
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
path = Path(tmp_dir)
|
||||
|
||||
|
@ -138,11 +144,23 @@ def working_dir_option(request):
|
|||
with init_file.open(mode="w") as f:
|
||||
f.write("from test_module.test import run_test\n")
|
||||
|
||||
yield {
|
||||
"runtime_env": {"working_dir": tmp_dir},
|
||||
"entrypoint": "python test.py",
|
||||
"expected_logs": "Hello from test_module!\n",
|
||||
}
|
||||
if request.param == "local_working_dir":
|
||||
yield {
|
||||
"runtime_env": {"working_dir": tmp_dir},
|
||||
"entrypoint": "python test.py",
|
||||
"expected_logs": "Hello from test_module!\n",
|
||||
}
|
||||
elif request.param == "local_py_modules":
|
||||
yield {
|
||||
"runtime_env": {"py_modules": [str(Path(tmp_dir) / "test_module")]},
|
||||
"entrypoint": (
|
||||
"python -c 'import test_module;"
|
||||
"print(test_module.run_test())'"
|
||||
),
|
||||
"expected_logs": "Hello from test_module!\n",
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unexpected pytest fixture option {request.param}")
|
||||
elif request.param == "s3_working_dir":
|
||||
yield {
|
||||
"runtime_env": {
|
||||
|
@ -155,18 +173,18 @@ def working_dir_option(request):
|
|||
assert False, f"Unrecognized option: {request.param}."
|
||||
|
||||
|
||||
def test_submit_job(job_sdk_client, working_dir_option):
|
||||
def test_submit_job(job_sdk_client, runtime_env_option):
|
||||
client = job_sdk_client
|
||||
|
||||
job_id = client.submit_job(
|
||||
entrypoint=working_dir_option["entrypoint"],
|
||||
runtime_env=working_dir_option["runtime_env"],
|
||||
entrypoint=runtime_env_option["entrypoint"],
|
||||
runtime_env=runtime_env_option["runtime_env"],
|
||||
)
|
||||
|
||||
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
|
||||
|
||||
logs = client.get_job_logs(job_id)
|
||||
assert logs == working_dir_option["expected_logs"]
|
||||
assert logs == runtime_env_option["expected_logs"]
|
||||
|
||||
|
||||
def test_http_bad_request(job_sdk_client):
|
||||
|
@ -189,13 +207,10 @@ def test_http_bad_request(job_sdk_client):
|
|||
|
||||
def test_invalid_runtime_env(job_sdk_client):
|
||||
client = job_sdk_client
|
||||
job_id = client.submit_job(
|
||||
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
|
||||
)
|
||||
|
||||
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
|
||||
data = client.get_job_info(job_id)
|
||||
assert "Only .zip files supported for remote URIs" in data.message
|
||||
with pytest.raises(ValueError, match="Only .zip files supported"):
|
||||
client.submit_job(
|
||||
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_env_setup_failure(job_sdk_client):
|
||||
|
|
|
@ -38,6 +38,7 @@ def upload_py_modules_if_needed(
|
|||
runtime_env: Dict[str, Any],
|
||||
scratch_dir: str,
|
||||
logger: Optional[logging.Logger] = default_logger,
|
||||
upload_fn=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Uploads the entries in py_modules and replaces them with a list of URIs.
|
||||
|
||||
|
@ -82,14 +83,17 @@ def upload_py_modules_if_needed(
|
|||
# module_path is a local path.
|
||||
excludes = runtime_env.get("excludes", None)
|
||||
module_uri = get_uri_for_directory(module_path, excludes=excludes)
|
||||
upload_package_if_needed(
|
||||
module_uri,
|
||||
scratch_dir,
|
||||
module_path,
|
||||
excludes=excludes,
|
||||
include_parent_dir=True,
|
||||
logger=logger,
|
||||
)
|
||||
if upload_fn is None:
|
||||
upload_package_if_needed(
|
||||
module_uri,
|
||||
scratch_dir,
|
||||
module_path,
|
||||
excludes=excludes,
|
||||
include_parent_dir=True,
|
||||
logger=logger,
|
||||
)
|
||||
else:
|
||||
upload_fn(module_path, excludes=excludes)
|
||||
|
||||
py_modules_uris.append(module_uri)
|
||||
|
||||
|
|
|
@ -24,8 +24,9 @@ default_logger = logging.getLogger(__name__)
|
|||
|
||||
def upload_working_dir_if_needed(
|
||||
runtime_env: Dict[str, Any],
|
||||
scratch_dir: str,
|
||||
scratch_dir: Optional[str] = os.getcwd(),
|
||||
logger: Optional[logging.Logger] = default_logger,
|
||||
upload_fn=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Uploads the working_dir and replaces it with a URI.
|
||||
|
||||
|
@ -70,15 +71,18 @@ def upload_working_dir_if_needed(
|
|||
upload_package_to_gcs(pkg_uri, package_path.read_bytes())
|
||||
runtime_env["working_dir"] = pkg_uri
|
||||
return runtime_env
|
||||
if upload_fn is None:
|
||||
upload_package_if_needed(
|
||||
working_dir_uri,
|
||||
scratch_dir,
|
||||
working_dir,
|
||||
include_parent_dir=False,
|
||||
excludes=excludes,
|
||||
logger=logger,
|
||||
)
|
||||
else:
|
||||
upload_fn(working_dir, excludes=excludes)
|
||||
|
||||
upload_package_if_needed(
|
||||
working_dir_uri,
|
||||
scratch_dir,
|
||||
working_dir,
|
||||
include_parent_dir=False,
|
||||
excludes=excludes,
|
||||
logger=logger,
|
||||
)
|
||||
runtime_env["working_dir"] = working_dir_uri
|
||||
return runtime_env
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue