[job submission] Support local py_modules in jobs (#22843)

This commit is contained in:
Archit Kulkarni 2022-03-10 09:42:25 -08:00 committed by GitHub
parent 85598d9d10
commit c78bd809ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 53 deletions

View file

@ -18,8 +18,10 @@ except ImportError:
from ray._private.runtime_env.packaging import ( from ray._private.runtime_env.packaging import (
create_package, create_package,
get_uri_for_directory, get_uri_for_directory,
parse_uri,
) )
from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
from ray.dashboard.modules.job.common import uri_to_http_components from ray.dashboard.modules.job.common import uri_to_http_components
from ray.ray_constants import DEFAULT_DASHBOARD_PORT from ray.ray_constants import DEFAULT_DASHBOARD_PORT
@ -280,32 +282,40 @@ class SubmissionClient:
package_file.unlink() package_file.unlink()
def _upload_package_if_needed( def _upload_package_if_needed(
self, package_path: str, excludes: Optional[List[str]] = None self,
package_path: str,
include_parent_dir: Optional[bool] = False,
excludes: Optional[List[str]] = None,
) -> str: ) -> str:
package_uri = get_uri_for_directory(package_path, excludes=excludes) package_uri = get_uri_for_directory(package_path, excludes=excludes)
if not self._package_exists(package_uri): if not self._package_exists(package_uri):
self._upload_package(package_uri, package_path, excludes=excludes) self._upload_package(
package_uri,
package_path,
include_parent_dir=include_parent_dir,
excludes=excludes,
)
else: else:
logger.info(f"Package {package_uri} already exists, skipping upload.") logger.info(f"Package {package_uri} already exists, skipping upload.")
return package_uri return package_uri
def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]): def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]):
if "working_dir" in runtime_env: def _upload_fn(working_dir, excludes):
working_dir = runtime_env["working_dir"] self._upload_package_if_needed(
try: working_dir, include_parent_dir=False, excludes=excludes
parse_uri(working_dir)
is_uri = True
logger.debug("working_dir is already a valid URI.")
except ValueError:
is_uri = False
if not is_uri:
logger.debug("working_dir is not a URI, attempting to upload.")
package_uri = self._upload_package_if_needed(
working_dir, excludes=runtime_env.get("excludes", None)
) )
runtime_env["working_dir"] = package_uri
upload_working_dir_if_needed(runtime_env, upload_fn=_upload_fn)
def _upload_py_modules_if_needed(self, runtime_env: Dict[str, Any]):
def _upload_fn(module_path, excludes):
self._upload_package_if_needed(
module_path, include_parent_dir=True, excludes=excludes
)
upload_py_modules_if_needed(runtime_env, "", upload_fn=_upload_fn)
@PublicAPI(stability="beta") @PublicAPI(stability="beta")
def get_version(self) -> str: def get_version(self) -> str:

View file

@ -67,6 +67,7 @@ class JobSubmissionClient(SubmissionClient):
metadata.update(self._default_metadata) metadata.update(self._default_metadata)
self._upload_working_dir_if_needed(runtime_env) self._upload_working_dir_if_needed(runtime_env)
self._upload_py_modules_if_needed(runtime_env)
req = JobSubmitRequest( req = JobSubmitRequest(
entrypoint=entrypoint, entrypoint=entrypoint,
job_id=job_id, job_id=job_id,

View file

@ -108,16 +108,22 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool:
@pytest.fixture( @pytest.fixture(
scope="module", params=["no_working_dir", "local_working_dir", "s3_working_dir"] scope="module",
params=[
"no_working_dir",
"local_working_dir",
"s3_working_dir",
"local_py_modules",
],
) )
def working_dir_option(request): def runtime_env_option(request):
if request.param == "no_working_dir": if request.param == "no_working_dir":
yield { yield {
"runtime_env": {}, "runtime_env": {},
"entrypoint": "echo hello", "entrypoint": "echo hello",
"expected_logs": "hello\n", "expected_logs": "hello\n",
} }
elif request.param == "local_working_dir": elif request.param == "local_working_dir" or request.param == "local_py_modules":
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
path = Path(tmp_dir) path = Path(tmp_dir)
@ -138,11 +144,23 @@ def working_dir_option(request):
with init_file.open(mode="w") as f: with init_file.open(mode="w") as f:
f.write("from test_module.test import run_test\n") f.write("from test_module.test import run_test\n")
if request.param == "local_working_dir":
yield { yield {
"runtime_env": {"working_dir": tmp_dir}, "runtime_env": {"working_dir": tmp_dir},
"entrypoint": "python test.py", "entrypoint": "python test.py",
"expected_logs": "Hello from test_module!\n", "expected_logs": "Hello from test_module!\n",
} }
elif request.param == "local_py_modules":
yield {
"runtime_env": {"py_modules": [str(Path(tmp_dir) / "test_module")]},
"entrypoint": (
"python -c 'import test_module;"
"print(test_module.run_test())'"
),
"expected_logs": "Hello from test_module!\n",
}
else:
raise ValueError(f"Unexpected pytest fixture option {request.param}")
elif request.param == "s3_working_dir": elif request.param == "s3_working_dir":
yield { yield {
"runtime_env": { "runtime_env": {
@ -155,18 +173,18 @@ def working_dir_option(request):
assert False, f"Unrecognized option: {request.param}." assert False, f"Unrecognized option: {request.param}."
def test_submit_job(job_sdk_client, working_dir_option): def test_submit_job(job_sdk_client, runtime_env_option):
client = job_sdk_client client = job_sdk_client
job_id = client.submit_job( job_id = client.submit_job(
entrypoint=working_dir_option["entrypoint"], entrypoint=runtime_env_option["entrypoint"],
runtime_env=working_dir_option["runtime_env"], runtime_env=runtime_env_option["runtime_env"],
) )
wait_for_condition(_check_job_succeeded, client=client, job_id=job_id) wait_for_condition(_check_job_succeeded, client=client, job_id=job_id)
logs = client.get_job_logs(job_id) logs = client.get_job_logs(job_id)
assert logs == working_dir_option["expected_logs"] assert logs == runtime_env_option["expected_logs"]
def test_http_bad_request(job_sdk_client): def test_http_bad_request(job_sdk_client):
@ -189,14 +207,11 @@ def test_http_bad_request(job_sdk_client):
def test_invalid_runtime_env(job_sdk_client): def test_invalid_runtime_env(job_sdk_client):
client = job_sdk_client client = job_sdk_client
job_id = client.submit_job( with pytest.raises(ValueError, match="Only .zip files supported"):
client.submit_job(
entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"} entrypoint="echo hello", runtime_env={"working_dir": "s3://not_a_zip"}
) )
wait_for_condition(_check_job_failed, client=client, job_id=job_id)
data = client.get_job_info(job_id)
assert "Only .zip files supported for remote URIs" in data.message
def test_runtime_env_setup_failure(job_sdk_client): def test_runtime_env_setup_failure(job_sdk_client):
client = job_sdk_client client = job_sdk_client

View file

@ -38,6 +38,7 @@ def upload_py_modules_if_needed(
runtime_env: Dict[str, Any], runtime_env: Dict[str, Any],
scratch_dir: str, scratch_dir: str,
logger: Optional[logging.Logger] = default_logger, logger: Optional[logging.Logger] = default_logger,
upload_fn=None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Uploads the entries in py_modules and replaces them with a list of URIs. """Uploads the entries in py_modules and replaces them with a list of URIs.
@ -82,6 +83,7 @@ def upload_py_modules_if_needed(
# module_path is a local path. # module_path is a local path.
excludes = runtime_env.get("excludes", None) excludes = runtime_env.get("excludes", None)
module_uri = get_uri_for_directory(module_path, excludes=excludes) module_uri = get_uri_for_directory(module_path, excludes=excludes)
if upload_fn is None:
upload_package_if_needed( upload_package_if_needed(
module_uri, module_uri,
scratch_dir, scratch_dir,
@ -90,6 +92,8 @@ def upload_py_modules_if_needed(
include_parent_dir=True, include_parent_dir=True,
logger=logger, logger=logger,
) )
else:
upload_fn(module_path, excludes=excludes)
py_modules_uris.append(module_uri) py_modules_uris.append(module_uri)

View file

@ -24,8 +24,9 @@ default_logger = logging.getLogger(__name__)
def upload_working_dir_if_needed( def upload_working_dir_if_needed(
runtime_env: Dict[str, Any], runtime_env: Dict[str, Any],
scratch_dir: str, scratch_dir: Optional[str] = os.getcwd(),
logger: Optional[logging.Logger] = default_logger, logger: Optional[logging.Logger] = default_logger,
upload_fn=None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Uploads the working_dir and replaces it with a URI. """Uploads the working_dir and replaces it with a URI.
@ -70,7 +71,7 @@ def upload_working_dir_if_needed(
upload_package_to_gcs(pkg_uri, package_path.read_bytes()) upload_package_to_gcs(pkg_uri, package_path.read_bytes())
runtime_env["working_dir"] = pkg_uri runtime_env["working_dir"] = pkg_uri
return runtime_env return runtime_env
if upload_fn is None:
upload_package_if_needed( upload_package_if_needed(
working_dir_uri, working_dir_uri,
scratch_dir, scratch_dir,
@ -79,6 +80,9 @@ def upload_working_dir_if_needed(
excludes=excludes, excludes=excludes,
logger=logger, logger=logger,
) )
else:
upload_fn(working_dir, excludes=excludes)
runtime_env["working_dir"] = working_dir_uri runtime_env["working_dir"] = working_dir_uri
return runtime_env return runtime_env