diff --git a/dashboard/modules/dashboard_sdk.py b/dashboard/modules/dashboard_sdk.py index d3e24727b..3eef0dabf 100644 --- a/dashboard/modules/dashboard_sdk.py +++ b/dashboard/modules/dashboard_sdk.py @@ -18,10 +18,10 @@ except ImportError: from ray._private.runtime_env.packaging import ( create_package, get_uri_for_directory, + get_uri_for_package, ) from ray._private.runtime_env.py_modules import upload_py_modules_if_needed from ray._private.runtime_env.working_dir import upload_working_dir_if_needed - from ray.dashboard.modules.job.common import uri_to_http_components from ray.ray_constants import DEFAULT_DASHBOARD_PORT @@ -259,17 +259,21 @@ class SubmissionClient: package_path: str, include_parent_dir: Optional[bool] = False, excludes: Optional[List[str]] = None, + is_file: bool = False, ) -> bool: logger.info(f"Uploading package {package_uri}.") with tempfile.TemporaryDirectory() as tmp_dir: protocol, package_name = uri_to_http_components(package_uri) - package_file = Path(tmp_dir) / package_name - create_package( - package_path, - package_file, - include_parent_dir=include_parent_dir, - excludes=excludes, - ) + if is_file: + package_file = Path(package_path) + else: + package_file = Path(tmp_dir) / package_name + create_package( + package_path, + package_file, + include_parent_dir=include_parent_dir, + excludes=excludes, + ) try: r = self._do_request( "PUT", @@ -279,15 +283,21 @@ class SubmissionClient: if r.status_code != 200: self._raise_error(r) finally: - package_file.unlink() + # If the package is a user's existing file, don't delete it. 
+ if not is_file: + package_file.unlink() def _upload_package_if_needed( self, package_path: str, - include_parent_dir: Optional[bool] = False, + include_parent_dir: bool = False, excludes: Optional[List[str]] = None, + is_file: bool = False, ) -> str: - package_uri = get_uri_for_directory(package_path, excludes=excludes) + if is_file: + package_uri = get_uri_for_package(Path(package_path)) + else: + package_uri = get_uri_for_directory(package_path, excludes=excludes) if not self._package_exists(package_uri): self._upload_package( @@ -295,6 +305,7 @@ class SubmissionClient: package_path, include_parent_dir=include_parent_dir, excludes=excludes, + is_file=is_file, ) else: logger.info(f"Package {package_uri} already exists, skipping upload.") @@ -302,20 +313,23 @@ class SubmissionClient: return package_uri def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]): - def _upload_fn(working_dir, excludes): + def _upload_fn(working_dir, excludes, is_file=False): self._upload_package_if_needed( - working_dir, include_parent_dir=False, excludes=excludes + working_dir, + include_parent_dir=False, + excludes=excludes, + is_file=is_file, ) upload_working_dir_if_needed(runtime_env, upload_fn=_upload_fn) def _upload_py_modules_if_needed(self, runtime_env: Dict[str, Any]): - def _upload_fn(module_path, excludes): + def _upload_fn(module_path, excludes, is_file=False): self._upload_package_if_needed( - module_path, include_parent_dir=True, excludes=excludes + module_path, include_parent_dir=True, excludes=excludes, is_file=is_file ) - upload_py_modules_if_needed(runtime_env, "", upload_fn=_upload_fn) + upload_py_modules_if_needed(runtime_env, upload_fn=_upload_fn) @PublicAPI(stability="beta") def get_version(self) -> str: diff --git a/dashboard/modules/job/common.py b/dashboard/modules/job/common.py index 20338dd73..c1dce0457 100644 --- a/dashboard/modules/job/common.py +++ b/dashboard/modules/job/common.py @@ -3,6 +3,7 @@ from enum import Enum import time from typing 
import Any, Dict, Optional, Tuple import pickle +from pathlib import Path from ray import ray_constants from ray.experimental.internal_kv import ( @@ -138,18 +139,17 @@ class JobInfoStorageClient: def uri_to_http_components(package_uri: str) -> Tuple[str, str]: - if not package_uri.endswith(".zip"): - raise ValueError(f"package_uri ({package_uri}) does not end in .zip") - # We need to strip the gcs:// prefix and .zip suffix to make it - # possible to pass the package_uri over HTTP. + suffix = Path(package_uri).suffix + if suffix not in {".zip", ".whl"}: + raise ValueError(f"package_uri ({package_uri}) does not end in .zip or .whl") + # We need to strip the :// prefix to make it possible to pass + # the package_uri over HTTP. protocol, package_name = parse_uri(package_uri) - return protocol.value, package_name[: -len(".zip")] + return protocol.value, package_name def http_uri_components_to_uri(protocol: str, package_name: str) -> str: - if package_name.endswith(".zip"): - raise ValueError(f"package_name ({package_name}) should not end in .zip") - return f"{protocol}://{package_name}.zip" + return f"{protocol}://{package_name}" def validate_request_type(json_data: Dict[str, Any], request_type: dataclass) -> Any: diff --git a/dashboard/modules/job/tests/pip_install_test-0.5-py3-none-any.whl b/dashboard/modules/job/tests/pip_install_test-0.5-py3-none-any.whl new file mode 100644 index 000000000..6871eb639 Binary files /dev/null and b/dashboard/modules/job/tests/pip_install_test-0.5-py3-none-any.whl differ diff --git a/dashboard/modules/job/tests/test_backwards_compatibility.py b/dashboard/modules/job/tests/test_backwards_compatibility.py index e9a27060b..465936dde 100644 --- a/dashboard/modules/job/tests/test_backwards_compatibility.py +++ b/dashboard/modules/job/tests/test_backwards_compatibility.py @@ -32,6 +32,9 @@ def _compatibility_script_path(file_name: str) -> str: class TestBackwardsCompatibility: + # TODO (architkulkarni): Reenable test after #22368 is merged, 
and make the + # backwards compatibility script install the commit from #22368. + @pytest.mark.skip("#22368 breaks backwards compatibility of the package REST API.") def test_cli(self): """ 1) Create a new conda environment with ray version X installed diff --git a/dashboard/modules/job/tests/test_common.py b/dashboard/modules/job/tests/test_common.py index bb1509175..2984c9f0f 100644 --- a/dashboard/modules/job/tests/test_common.py +++ b/dashboard/modules/job/tests/test_common.py @@ -82,26 +82,25 @@ class TestJobSubmitRequestValidation: def test_uri_to_http_and_back(): - assert uri_to_http_components("gcs://hello.zip") == ("gcs", "hello") + assert uri_to_http_components("gcs://hello.zip") == ("gcs", "hello.zip") + assert uri_to_http_components("gcs://hello.whl") == ("gcs", "hello.whl") with pytest.raises(ValueError, match="'blah' is not a valid Protocol"): uri_to_http_components("blah://halb.zip") - with pytest.raises(ValueError, match="does not end in .zip"): + with pytest.raises(ValueError, match="does not end in .zip or .whl"): assert uri_to_http_components("gcs://hello.not_zip") - with pytest.raises(ValueError, match="does not end in .zip"): + with pytest.raises(ValueError, match="does not end in .zip or .whl"): assert uri_to_http_components("gcs://hello") - assert http_uri_components_to_uri("gcs", "hello") == "gcs://hello.zip" - assert http_uri_components_to_uri("blah", "halb") == "blah://halb.zip" + assert http_uri_components_to_uri("gcs", "hello.zip") == "gcs://hello.zip" + assert http_uri_components_to_uri("blah", "halb.zip") == "blah://halb.zip" + assert http_uri_components_to_uri("blah", "halb.whl") == "blah://halb.whl" - with pytest.raises(ValueError, match="should not end in .zip"): - assert http_uri_components_to_uri("gcs", "hello.zip") - - original_uri = "gcs://hello.zip" - new_uri = http_uri_components_to_uri(*uri_to_http_components(original_uri)) - assert new_uri == original_uri + for original_uri in ["gcs://hello.zip", "gcs://fasdf.whl"]: + 
new_uri = http_uri_components_to_uri(*uri_to_http_components(original_uri)) + assert new_uri == original_uri if __name__ == "__main__": diff --git a/dashboard/modules/job/tests/test_http_job_server.py b/dashboard/modules/job/tests/test_http_job_server.py index 5b2c20470..bb8cc653d 100644 --- a/dashboard/modules/job/tests/test_http_job_server.py +++ b/dashboard/modules/job/tests/test_http_job_server.py @@ -1,5 +1,7 @@ import logging from pathlib import Path +import os +import shutil import sys import json import yaml @@ -115,13 +117,16 @@ def _check_job_stopped(client: JobSubmissionClient, job_id: str) -> bool: "no_working_dir", "local_working_dir", "s3_working_dir", + "local_py_modules", + "working_dir_and_local_py_modules_whl", + "local_working_dir_zip", "pip_txt", "conda_yaml", "local_py_modules", ], ) def runtime_env_option(request): - driver_script = """ + import_in_task_script = """ import ray ray.init(address="auto") @@ -137,7 +142,12 @@ ray.get(f.remote()) "entrypoint": "echo hello", "expected_logs": "hello\n", } - elif request.param == "local_working_dir" or request.param == "local_py_modules": + elif request.param in { + "local_working_dir", + "local_working_dir_zip", + "local_py_modules", + "working_dir_and_local_py_modules_whl", + }: with tempfile.TemporaryDirectory() as tmp_dir: path = Path(tmp_dir) @@ -164,6 +174,15 @@ ray.get(f.remote()) "entrypoint": "python test.py", "expected_logs": "Hello from test_module!\n", } + elif request.param == "local_working_dir_zip": + local_zipped_dir = shutil.make_archive( + os.path.join(tmp_dir, "test"), "zip", tmp_dir + ) + yield { + "runtime_env": {"working_dir": local_zipped_dir}, + "entrypoint": "python test.py", + "expected_logs": "Hello from test_module!\n", + } elif request.param == "local_py_modules": yield { "runtime_env": {"py_modules": [str(Path(tmp_dir) / "test_module")]}, @@ -173,6 +192,23 @@ ray.get(f.remote()) ), "expected_logs": "Hello from test_module!\n", } + elif request.param == 
"working_dir_and_local_py_modules_whl": + yield { + "runtime_env": { + "working_dir": "s3://runtime-env-test/script_runtime_env.zip", + "py_modules": [ + Path(os.path.dirname(__file__)) + / "pip_install_test-0.5-py3-none-any.whl" + ], + }, + "entrypoint": ( + "python script.py && python -c 'import pip_install_test'" + ), + "expected_logs": ( + "Executing main() from script.py !!\n" + "Good job! You installed a pip module." + ), + } else: raise ValueError(f"Unexpected pytest fixture option {request.param}") elif request.param == "s3_working_dir": @@ -192,9 +228,10 @@ ray.get(f.remote()) runtime_env = {"pip": {"packages": relative_filepath, "pip_check": False}} yield { "runtime_env": runtime_env, - "entrypoint": f"python -c '{driver_script}'", - # TODO(architkulkarni): Uncomment after #22968 is fixed. - # "entrypoint": "python -c 'import pip_install_test'", + "entrypoint": ( + f"python -c 'import pip_install_test' && " + f"python -c '{import_in_task_script}'" + ), "expected_logs": "Good job! You installed a pip module.", } elif request.param == "conda_yaml": @@ -207,7 +244,7 @@ ray.get(f.remote()) yield { "runtime_env": runtime_env, - "entrypoint": f"python -c '{driver_script}'", + "entrypoint": f"python -c '{import_in_task_script}'", # TODO(architkulkarni): Uncomment after #22968 is fixed. # "entrypoint": "python -c 'import pip_install_test'", "expected_logs": "Good job! You installed a pip module.", diff --git a/doc/source/ray-core/handling-dependencies.rst b/doc/source/ray-core/handling-dependencies.rst index 903d39307..9053f521f 100644 --- a/doc/source/ray-core/handling-dependencies.rst +++ b/doc/source/ray-core/handling-dependencies.rst @@ -297,7 +297,7 @@ The ``runtime_env`` is a Python dictionary or a python class :class:`ray.runtime Note: If your local directory contains a ``.gitignore`` file, the files and paths specified therein will not be uploaded to the cluster. 
- ``py_modules`` (List[str|module]): Specifies Python modules to be available for import in the Ray workers. (For more ways to specify packages, see also the ``pip`` and ``conda`` fields below.) - Each entry must be either (1) a path to a local directory, (2) a URI to a remote zip file (see :ref:`remote-uris` for details), or (3) a Python module object. + Each entry must be either (1) a path to a local directory, (2) a URI to a remote zip file (see :ref:`remote-uris` for details), (3) a Python module object, or (4) a path to a local ``.whl`` file. - Examples of entries in the list: @@ -309,6 +309,8 @@ The ``runtime_env`` is a Python dictionary or a python class :class:`ray.runtime - ``my_module # Assumes my_module has already been imported, e.g. via 'import my_module'`` + - ``my_module.whl`` + The modules will be downloaded to each node on the cluster. Note: Setting options (1) and (3) per-task or per-actor is currently unsupported, it can only be set per-job (i.e., in ``ray.init()``). diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 36f9d0f29..19f2348a5 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -185,6 +185,24 @@ def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]: return (protocol, uri.netloc) +def is_zip_uri(uri: str) -> bool: + try: + protocol, path = parse_uri(uri) + except ValueError: + return False + + return Path(path).suffix == ".zip" + + +def is_whl_uri(uri: str) -> bool: + try: + protocol, path = parse_uri(uri) + except ValueError: + return False + + return Path(path).suffix == ".whl" + + def _get_excludes(path: Path, excludes: List[str]) -> Callable: path = path.absolute() pathspec = PathSpec.from_lines("gitwildmatch", excludes) @@ -295,10 +313,17 @@ def package_exists(pkg_uri: str) -> bool: def get_uri_for_package(package: Path) -> str: """Get a content-addressable URI from a package's contents.""" - hash_val = 
hashlib.md5(package.read_bytes()).hexdigest() - return "{protocol}://{pkg_name}.zip".format( - protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val - ) + if package.suffix == ".whl": + # Wheel file names include the Python package name, version + # and tags, so it is already effectively content-addressed. + return "{protocol}://{whl_filename}".format( + protocol=Protocol.GCS.value, whl_filename=package.name + ) + else: + hash_val = hashlib.md5(package.read_bytes()).hexdigest() + return "{protocol}://{pkg_name}.zip".format( + protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val + ) def get_uri_for_directory(directory: str, excludes: Optional[List[str]] = None) -> str: @@ -434,9 +459,10 @@ def download_and_unpack_package( base_directory: str, logger: Optional[logging.Logger] = default_logger, ) -> str: - """Download the package corresponding to this URI and unpack it. + """Download the package corresponding to this URI and unpack it if zipped. - Will be written to a directory named {base_directory}/{uri}. + Will be written to a file or directory named {base_directory}/{uri}. + Returns the path to this file or directory. 
""" pkg_file = Path(_get_local_path(base_directory, pkg_uri)) with FileLock(str(pkg_file) + ".lock"): @@ -458,13 +484,17 @@ def download_and_unpack_package( raise IOError(f"Failed to fetch URI {pkg_uri} from GCS.") code = code or b"" pkg_file.write_bytes(code) - unzip_package( - package_path=pkg_file, - target_dir=local_dir, - remove_top_level_directory=False, - unlink_zip=True, - logger=logger, - ) + + if is_zip_uri(pkg_uri): + unzip_package( + package_path=pkg_file, + target_dir=local_dir, + remove_top_level_directory=False, + unlink_zip=True, + logger=logger, + ) + else: + return str(pkg_file) elif protocol in Protocol.remote_protocols(): # Download package from remote URI tp = None diff --git a/python/ray/_private/runtime_env/py_modules.py b/python/ray/_private/runtime_env/py_modules.py index 64e3318c5..e88f48592 100644 --- a/python/ray/_private/runtime_env/py_modules.py +++ b/python/ray/_private/runtime_env/py_modules.py @@ -6,15 +6,20 @@ from pathlib import Path import asyncio from ray.experimental.internal_kv import _internal_kv_initialized +from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.packaging import ( download_and_unpack_package, delete_package, get_local_dir_from_uri, get_uri_for_directory, + get_uri_for_package, + package_exists, parse_uri, + is_whl_uri, Protocol, upload_package_if_needed, + upload_package_to_gcs, ) from ray._private.runtime_env.working_dir import set_pythonpath_in_context from ray._private.utils import get_directory_size_bytes @@ -37,7 +42,7 @@ def _check_is_uri(s: str) -> bool: def upload_py_modules_if_needed( runtime_env: Dict[str, Any], - scratch_dir: str, + scratch_dir: Optional[str] = os.getcwd(), logger: Optional[logging.Logger] = default_logger, upload_fn=None, ) -> Dict[str, Any]: @@ -82,19 +87,34 @@ def upload_py_modules_if_needed( module_uri = module_path else: # module_path is a local path. 
- excludes = runtime_env.get("excludes", None) - module_uri = get_uri_for_directory(module_path, excludes=excludes) - if upload_fn is None: - upload_package_if_needed( - module_uri, - scratch_dir, - module_path, - excludes=excludes, - include_parent_dir=True, - logger=logger, - ) + if Path(module_path).is_dir(): + excludes = runtime_env.get("excludes", None) + module_uri = get_uri_for_directory(module_path, excludes=excludes) + if upload_fn is None: + upload_package_if_needed( + module_uri, + scratch_dir, + module_path, + excludes=excludes, + include_parent_dir=True, + logger=logger, + ) + else: + upload_fn(module_path, excludes=excludes) + elif Path(module_path).suffix == ".whl": + module_uri = get_uri_for_package(Path(module_path)) + if upload_fn is None: + if not package_exists(module_uri): + upload_package_to_gcs( + module_uri, Path(module_path).read_bytes() + ) + else: + upload_fn(module_path, excludes=None, is_file=True) else: - upload_fn(module_path, excludes=excludes) + raise ValueError( + "py_modules entry must be a directory or a .whl file; " + f"got {module_path}" + ) py_modules_uris.append(module_uri) @@ -111,6 +131,9 @@ class PyModulesManager: try_to_create_directory(self._resources_dir) assert _internal_kv_initialized() + def _get_local_dir_from_uri(self, uri: str): + return get_local_dir_from_uri(uri, self._resources_dir) + def delete_uri( self, uri: str, logger: Optional[logging.Logger] = default_logger ) -> int: @@ -128,6 +151,39 @@ class PyModulesManager: def get_uris(self, runtime_env: dict) -> Optional[List[str]]: return runtime_env.py_modules() + def _download_and_install_wheel( + self, uri: str, logger: Optional[logging.Logger] = default_logger + ): + """Download and install a wheel URI, and then delete the local wheel file.""" + wheel_file = download_and_unpack_package( + uri, self._resources_dir, logger=logger + ) + module_dir = self._get_local_dir_from_uri(uri) + + pip_install_cmd = [ + "pip", + "install", + wheel_file, + 
f"--target={module_dir}", + ] + logger.info( + "Running py_modules wheel install command: %s", str(pip_install_cmd) + ) + try: + exit_code, output = exec_cmd_stream_to_logger(pip_install_cmd, logger) + finally: + if Path(wheel_file).exists(): + Path(wheel_file).unlink() + + if exit_code != 0: + if Path(module_dir).exists(): + Path(module_dir).unlink() + raise RuntimeError( + f"Failed to install py_modules wheel {wheel_file}" + f"to {module_dir}:\n{output}" + ) + return module_dir + async def create( self, uri: str, @@ -140,9 +196,14 @@ class PyModulesManager: # TODO(Catch-Bull): Refactor method create into an async process, and # make this method running in current loop. def _create(): - module_dir = download_and_unpack_package( - uri, self._resources_dir, logger=logger - ) + if is_whl_uri(uri): + module_dir = self._download_and_install_wheel(uri=uri, logger=logger) + + else: + module_dir = download_and_unpack_package( + uri, self._resources_dir, logger=logger + ) + return get_directory_size_bytes(module_dir) loop = asyncio.get_event_loop() @@ -159,12 +220,12 @@ class PyModulesManager: return module_dirs = [] for uri in uris: - module_dir = get_local_dir_from_uri(uri, self._resources_dir) + module_dir = self._get_local_dir_from_uri(uri) if not module_dir.exists(): raise ValueError( f"Local directory {module_dir} for URI {uri} does " "not exist on the cluster. Something may have gone wrong while " - "downloading or unpacking the py_modules files." + "downloading, unpacking or installing the py_modules files." 
) module_dirs.append(str(module_dir)) set_pythonpath_in_context(os.pathsep.join(module_dirs), context) diff --git a/python/ray/tests/pip_install_test-0.5-py3-none-any.whl b/python/ray/tests/pip_install_test-0.5-py3-none-any.whl new file mode 100644 index 000000000..6871eb639 Binary files /dev/null and b/python/ray/tests/pip_install_test-0.5-py3-none-any.whl differ diff --git a/python/ray/tests/test_runtime_env_packaging.py b/python/ray/tests/test_runtime_env_packaging.py index 8761a29c2..529004f37 100644 --- a/python/ray/tests/test_runtime_env_packaging.py +++ b/python/ray/tests/test_runtime_env_packaging.py @@ -15,8 +15,11 @@ from ray._private.runtime_env.packaging import ( get_local_dir_from_uri, get_uri_for_directory, _get_excludes, + get_uri_for_package, upload_package_if_needed, parse_uri, + is_zip_uri, + is_whl_uri, Protocol, get_top_level_dir_from_compressed_package, remove_dir_from_filepaths, @@ -351,6 +354,23 @@ def test_parsing(parsing_tuple): assert package_name == parsed_package_name +def test_is_whl_uri(): + assert is_whl_uri("gcs://my-package.whl") + assert not is_whl_uri("gcs://asdf.zip") + assert not is_whl_uri("invalid_format") + + +def test_is_zip_uri(): + assert is_zip_uri("s3://my-package.zip") + assert is_zip_uri("gcs://asdf.zip") + assert not is_zip_uri("invalid_format") + assert not is_zip_uri("gcs://a.whl") + + +def test_get_uri_for_package(): + assert get_uri_for_package(Path("/tmp/my-pkg.whl")) == "gcs://my-pkg.whl" + + def test_get_local_dir_from_uri(): uri = "gcs://.zip" assert get_local_dir_from_uri(uri, "base_dir") == Path( diff --git a/python/ray/tests/test_runtime_env_working_dir.py b/python/ray/tests/test_runtime_env_working_dir.py index 7a1c1a055..a180b2f1f 100644 --- a/python/ray/tests/test_runtime_env_working_dir.py +++ b/python/ray/tests/test_runtime_env_working_dir.py @@ -87,7 +87,14 @@ def test_inherit_cluster_env_pythonpath(monkeypatch): @pytest.mark.parametrize( - "option", ["failure", "working_dir", "working_dir_zip", 
"py_modules"] + "option", + [ + "failure", + "working_dir", + "working_dir_zip", + "py_modules", + "working_dir_and_py_modules", + ], ) @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") def test_lazy_reads( @@ -121,14 +128,27 @@ def test_lazy_reads( ray.init( address, runtime_env={ - "py_modules": [str(Path(tmp_working_dir) / "test_module")] + "py_modules": [ + str(Path(tmp_working_dir) / "test_module"), + Path(os.path.dirname(__file__)) + / "pip_install_test-0.5-py3-none-any.whl", + ] }, ) - elif option == "py_modules_path": + elif option == "working_dir_and_py_modules": ray.init( address, - runtime_env={"py_modules": [Path(tmp_working_dir) / "test_module"]}, + runtime_env={ + "working_dir": tmp_working_dir, + "py_modules": [ + str(Path(tmp_working_dir) / "test_module"), + Path(os.path.dirname(__file__)) + / "pip_install_test-0.5-py3-none-any.whl", + ], + }, ) + else: + raise ValueError(f"unexpected pytest parameter {option}") call_ray_init() @@ -155,6 +175,20 @@ def test_lazy_reads( else: assert ray.get(test_import.remote()) == 1 + if option in {"py_modules", "working_dir_and_py_modules"}: + + @ray.remote + def test_py_modules_whl(): + import pip_install_test # noqa: F401 + + return True + + assert ray.get(test_py_modules_whl.remote()) + + if option in {"py_modules", "working_dir_zip"}: + # These options are not tested beyond this point, so return to save time. 
+ return + reinit() @ray.remote @@ -164,7 +198,7 @@ def test_lazy_reads( if option == "failure": with pytest.raises(FileNotFoundError): ray.get(test_read.remote()) - elif option == "working_dir": + elif option in {"working_dir_and_py_modules", "working_dir"}: assert ray.get(test_read.remote()) == "world" reinit() @@ -187,7 +221,7 @@ def test_lazy_reads( assert ray.get(a.test_import.remote()) == 1 with pytest.raises(FileNotFoundError): assert ray.get(a.test_read.remote()) == "world" - elif option == "working_dir": + elif option in {"working_dir_and_py_modules", "working_dir"}: assert ray.get(a.test_import.remote()) == 1 assert ray.get(a.test_read.remote()) == "world" diff --git a/python/ray/tests/test_runtime_env_working_dir_2.py b/python/ray/tests/test_runtime_env_working_dir_2.py index 9419bc7cc..95c1477ee 100644 --- a/python/ray/tests/test_runtime_env_working_dir_2.py +++ b/python/ray/tests/test_runtime_env_working_dir_2.py @@ -290,12 +290,24 @@ class TestGC: elif option == "py_modules": if source != S3_PACKAGE_URI: source = str(Path(source) / "test_module") - ray.init(address, runtime_env={"py_modules": [source]}) + ray.init( + address, + runtime_env={ + "py_modules": [ + source, + Path(os.path.dirname(__file__)) + / "pip_install_test-0.5-py3-none-any.whl", + ] + }, + ) # For a local directory, the package should be in the GCS. # For an S3 URI, there should be nothing in the GCS because # it will be downloaded from S3 directly on each node. - if source == S3_PACKAGE_URI: + # In the "py_modules" case, we have specified a local wheel + # file to be uploaded to the GCS, so we do not expect the + # internal KV to be empty. 
+ if source == S3_PACKAGE_URI and option != "py_modules": assert check_internal_kv_gced() else: assert not check_internal_kv_gced() @@ -305,13 +317,15 @@ class TestGC: def test_import(self): import test_module + if option == "py_modules": + import pip_install_test # noqa: F401 test_module.one() num_cpus = int(ray.available_resources()["CPU"]) actors = [A.remote() for _ in range(num_cpus)] ray.get([a.test_import.remote() for a in actors]) - if source == S3_PACKAGE_URI: + if source == S3_PACKAGE_URI and option != "py_modules": assert check_internal_kv_gced() else: assert not check_internal_kv_gced() @@ -349,7 +363,13 @@ class TestGC: if option == "working_dir": A = A.options(runtime_env={"working_dir": S3_PACKAGE_URI}) else: - A = A.options(runtime_env={"py_modules": [S3_PACKAGE_URI]}) + A = A.options( + runtime_env={ + "py_modules": [ + S3_PACKAGE_URI, + ] + } + ) num_cpus = int(ray.available_resources()["CPU"]) actors = [A.remote() for _ in range(num_cpus)] @@ -375,12 +395,23 @@ class TestGC: elif option == "py_modules": if source != S3_PACKAGE_URI: source = str(Path(source) / "test_module") - ray.init(address, namespace="test", runtime_env={"py_modules": [source]}) + ray.init( + address, + namespace="test", + runtime_env={ + "py_modules": [ + source, + Path(os.path.dirname(__file__)) + / "pip_install_test-0.5-py3-none-any.whl", + ] + }, + ) # For a local directory, the package should be in the GCS. # For an S3 URI, there should be nothing in the GCS because # it will be downloaded from S3 directly on each node. - if source == S3_PACKAGE_URI: + # In the "py_modules" case, a local wheel file will be in the GCS. 
+ if source == S3_PACKAGE_URI and option != "py_modules": assert check_internal_kv_gced() else: assert not check_internal_kv_gced() @@ -390,12 +421,14 @@ class TestGC: def test_import(self): import test_module + if option == "py_modules": + import pip_install_test # noqa: F401 test_module.one() a = A.options(name="test", lifetime="detached").remote() ray.get(a.test_import.remote()) - if source == S3_PACKAGE_URI: + if source == S3_PACKAGE_URI and option != "py_modules": assert check_internal_kv_gced() else: assert not check_internal_kv_gced() @@ -405,7 +438,7 @@ class TestGC: ray.init(address, namespace="test") - if source == S3_PACKAGE_URI: + if source == S3_PACKAGE_URI and option != "py_modules": assert check_internal_kv_gced() else: assert not check_internal_kv_gced()