[runtime_env] Support working_dir and py_modules from HTTPS and Google Cloud Storage (#20280)
This commit is contained in:
parent
6c3bad52b6
commit
c0aeb4a236
10 changed files with 374 additions and 65 deletions
@@ -83,7 +83,7 @@ def working_dir_option(request):
     elif request.param == "s3_working_dir":
         yield {
             "runtime_env": {
-                "working_dir": "s3://runtime-env-test/script.zip",
+                "working_dir": "s3://runtime-env-test/script_runtime_env.zip",
             },
             "entrypoint": "python script.py",
             "expected_logs": "Executing main() from script.py !!\n",
@@ -124,7 +124,7 @@ def test_http_bad_request(job_sdk_client):

     # 500 - HTTPInternalServerError
     with pytest.raises(
-            RuntimeError, match="Only .zip files supported for S3 URIs"):
+            RuntimeError, match="Only .zip files supported for remote URIs"):
         r = client.submit_job(
             entrypoint="echo hello",
             runtime_env={"working_dir": "s3://does_not_exist"})
@@ -115,7 +115,9 @@ class TestShellScriptExecution:
     def test_submit_with_s3_runtime_env(self, job_manager):
         job_id = job_manager.submit_job(
             entrypoint="python script.py",
-            runtime_env={"working_dir": "s3://runtime-env-test/script.zip"})
+            runtime_env={
+                "working_dir": "s3://runtime-env-test/script_runtime_env.zip"
+            })

         wait_for_condition(
             check_job_succeeded, job_manager=job_manager, job_id=job_id)
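Note: the three test hunks above pin down the job-submission surface this commit touches. A minimal sketch of that flow, assuming a local cluster with the dashboard on its default port; the address and the zip's contents are illustrative, not taken from the diff:

    from ray.dashboard.modules.job.sdk import JobSubmissionClient

    # Hypothetical address; the tests above get a client from a fixture.
    client = JobSubmissionClient("http://127.0.0.1:8265")
    job_id = client.submit_job(
        # script.py runs at the root of the unpacked remote zip.
        entrypoint="python script.py",
        runtime_env={
            "working_dir": "s3://runtime-env-test/script_runtime_env.zip",
        })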
@@ -39,8 +39,18 @@ class Protocol(Enum):
         return self

     GCS = "gcs", "For packages dynamically uploaded and managed by the GCS."
-    S3 = "s3", "Remote s3 path, assumes everything packed in one zip file."
     CONDA = "conda", "For conda environments installed locally on each node."
+    HTTPS = "https", ("Remote https path, "
+                      "assumes everything packed in one zip file.")
+    S3 = "s3", "Remote s3 path, assumes everything packed in one zip file."
+    GS = "gs", ("Remote google storage path, "
+                "assumes everything packed in one zip file.")
+
+    @classmethod
+    def remote_protocols(cls):
+        # Returns a list of protocols that support remote storage
+        # These protocols should only be used with paths that end in ".zip"
+        return [cls.HTTPS, cls.S3, cls.GS]


 def _xor_bytes(left: bytes, right: bytes) -> bytes:
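Note: Protocol.remote_protocols() gives the later hunks a single switch for "downloadable" URI schemes. A minimal sketch of the check the validation hunks below build on it (the helper name here is hypothetical, for illustration only):

    from urllib.parse import urlparse

    def _require_remote_zip(uri: str) -> None:
        # Remote packages must be a single .zip, per the comment above.
        protocol = Protocol(urlparse(uri).scheme)
        if protocol in Protocol.remote_protocols() and not uri.endswith(".zip"):
            raise ValueError("Only .zip files supported for remote URIs.")

    _require_remote_zip("gs://bucket/pkg.zip")  # OK
    _require_remote_zip("s3://bucket/pkg.tar")  # raises ValueError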
@@ -122,19 +132,49 @@ def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
             netloc='_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip'
         )
         -> ("gcs", "_ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip")

-    For S3 URIs, the bucket and path will have '/' replaced with '_'.
+    For HTTPS URIs, the netloc will have '.' replaced with '_', and
+    the path will have '/' replaced with '_'. The package name will be the
+    adjusted path with 'https_' prepended.
+        urlparse(
+            "https://github.com/shrekris-anyscale/test_module/archive/HEAD.zip"
+        )
+        -> ParseResult(
+            scheme='https',
+            netloc='github.com',
+            path='/shrekris-anyscale/test_repo/archive/HEAD.zip'
+        )
+        -> ("https",
+            "github_com_shrekris-anyscale_test_repo_archive_HEAD.zip")
+
+    For S3 URIs, the bucket and path will have '/' replaced with '_'. The
+    package name will be the adjusted path with 's3_' prepended.
         urlparse("s3://bucket/dir/file.zip")
         -> ParseResult(
             scheme='s3',
             netloc='bucket',
             path='/dir/file.zip'
         )
-        -> ("s3", "bucket_dir_file.zip")
+        -> ("s3", "s3_bucket_dir_file.zip")
+
+    For GS URIs, the path will have '/' replaced with '_'. The package name
+    will be the adjusted path with 'gs_' prepended.
+        urlparse("gs://public-runtime-env-test/test_module.zip")
+        -> ParseResult(
+            scheme='gs',
+            netloc='public-runtime-env-test',
+            path='/test_module.zip'
+        )
+        -> ("gs",
+            "gs_public-runtime-env-test_test_module.zip")
     """
     uri = urlparse(pkg_uri)
     protocol = Protocol(uri.scheme)
-    if protocol == Protocol.S3:
-        return (protocol, f"s3_{uri.netloc}_" + "_".join(uri.path.split("/")))
+    if protocol == Protocol.S3 or protocol == Protocol.GS:
+        return (protocol,
+                f"{protocol.value}_{uri.netloc}{uri.path.replace('/', '_')}")
+    elif protocol == Protocol.HTTPS:
+        return (
+            protocol,
+            f"https_{uri.netloc.replace('.', '_')}{uri.path.replace('/', '_')}"
+        )
     else:
         return (protocol, uri.netloc)
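Note: for quick reference, the package names the new parse_uri produces; these exact input/output pairs appear verbatim in the parametrized test_parsing test added later in this diff:

    parse_uri("gcs://file.zip")             # -> (Protocol.GCS, "file.zip")
    parse_uri("s3://bucket/file.zip")       # -> (Protocol.S3, "s3_bucket_file.zip")
    parse_uri("https://test.com/file.zip")  # -> (Protocol.HTTPS, "https_test_com_file.zip")
    parse_uri("gs://bucket/file.zip")       # -> (Protocol.GS, "gs_bucket_file.zip")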
@@ -290,8 +330,9 @@ def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes):
     protocol, pkg_name = parse_uri(pkg_uri)
     if protocol == Protocol.GCS:
         _store_package_in_gcs(pkg_uri, pkg_bytes)
-    elif protocol == Protocol.S3:
-        raise RuntimeError("push_package should not be called with s3 path.")
+    elif protocol in Protocol.remote_protocols():
+        raise RuntimeError(
+            "push_package should not be called with remote path.")
     else:
         raise NotImplementedError(f"Protocol {protocol} is not supported")
@@ -393,34 +434,152 @@ def download_and_unpack_package(
             raise IOError(f"Failed to fetch URI {pkg_uri} from GCS.")
         code = code or b""
         pkg_file.write_bytes(code)
-    elif protocol == Protocol.S3:
-        # Download package from S3.
-        try:
-            from smart_open import open
-            import boto3
-        except ImportError:
-            raise ImportError(
-                "You must `pip install smart_open` and "
-                "`pip install boto3` to fetch URIs in s3 "
-                "bucket.")
-        tp = {"client": boto3.client("s3")}
+        unzip_package(
+            package_path=pkg_file,
+            target_dir=local_dir,
+            remove_top_level_directory=False,
+            unlink_zip=True,
+            logger=logger)
+    elif protocol in Protocol.remote_protocols():
+        # Download package from remote URI
+        tp = None
+
+        if protocol == Protocol.S3:
+            try:
+                from smart_open import open
+                import boto3
+            except ImportError:
+                raise ImportError(
+                    "You must `pip install smart_open` and "
+                    "`pip install boto3` to fetch URIs in s3 "
+                    "bucket.")
+            tp = {"client": boto3.client("s3")}
+        elif protocol == Protocol.GS:
+            try:
+                from smart_open import open
+                from google.cloud import storage  # noqa: F401
+            except ImportError:
+                raise ImportError(
+                    "You must `pip install smart_open` and "
+                    "`pip install google-cloud-storage` "
+                    "to fetch URIs in Google Cloud Storage bucket.")
+        else:
+            try:
+                from smart_open import open
+            except ImportError:
+                raise ImportError(
+                    "You must `pip install smart_open` "
+                    f"to fetch {protocol.value.upper()} URIs.")
+
         with open(pkg_uri, "rb", transport_params=tp) as package_zip:
             with open(pkg_file, "wb") as fin:
                 fin.write(package_zip.read())
+
+        unzip_package(
+            package_path=pkg_file,
+            target_dir=local_dir,
+            remove_top_level_directory=True,
+            unlink_zip=True,
+            logger=logger)
     else:
-        raise NotImplementedError(f"Protocol {protocol} is not supported")
-
-    os.mkdir(local_dir)
-    logger.debug(f"Unpacking {pkg_file} to {local_dir}")
-    with ZipFile(str(pkg_file), "r") as zip_ref:
-        zip_ref.extractall(local_dir)
-    pkg_file.unlink()
+        raise NotImplementedError(
+            f"Protocol {protocol} is not supported")

     return str(local_dir)


+def get_top_level_dir_from_compressed_package(package_path: str):
+    """
+    If compressed package at package_path contains a single top-level
+    directory, returns the name of the top-level directory. Otherwise,
+    returns None.
+    """
+
+    package_zip = ZipFile(package_path, "r")
+    top_level_directory = None
+
+    for file_name in package_zip.namelist():
+        if top_level_directory is None:
+            # Cache the top_level_directory name when checking
+            # the first file in the zipped package
+            if "/" in file_name:
+                top_level_directory = file_name.split("/")[0]
+            else:
+                return None
+        else:
+            # Confirm that all other files
+            # belong to the same top_level_directory
+            if "/" not in file_name or \
+                    file_name.split("/")[0] != top_level_directory:
+                return None
+
+    return top_level_directory
+
+
+def extract_file_and_remove_top_level_dir(base_dir: str, fname: str,
+                                          zip_ref: ZipFile):
+    """
+    Extracts fname file from zip_ref zip file, removes the top level directory
+    from fname's file path, and stores fname in the base_dir.
+    """
+
+    fname_without_top_level_dir = "/".join(fname.split("/")[1:])
+
+    # If this condition is false, it means there was no top-level directory,
+    # so we do nothing
+    if fname_without_top_level_dir:
+        zip_ref.extract(fname, base_dir)
+        os.rename(
+            os.path.join(base_dir, fname),
+            os.path.join(base_dir, fname_without_top_level_dir))
+
+
+def unzip_package(package_path: str,
+                  target_dir: str,
+                  remove_top_level_directory: bool,
+                  unlink_zip: bool,
+                  logger: Optional[logging.Logger] = default_logger):
+    """
+    Unzip the compressed package contained at package_path and store the
+    contents in target_dir. If remove_top_level_directory is True, the
+    function will automatically remove the top_level_directory and store
+    the contents directly in target_dir. If unlink_zip is True, the
+    function will unlink the zip file stored at package_path.
+    """
+    try:
+        os.mkdir(target_dir)
+    except FileExistsError:
+        logger.info(f"Directory at {target_dir} already exists")
+
+    logger.debug(f"Unpacking {package_path} to {target_dir}")
+
+    if remove_top_level_directory:
+        top_level_directory = get_top_level_dir_from_compressed_package(
+            package_path)
+        if top_level_directory is None:
+            raise ValueError("The package at package_path must contain "
+                             "a single top level directory. Make sure there "
+                             "are no hidden files at the same level as the "
+                             "top level directory.")
+        with ZipFile(str(package_path), "r") as zip_ref:
+            for fname in zip_ref.namelist():
+                extract_file_and_remove_top_level_dir(
+                    base_dir=target_dir, fname=fname, zip_ref=zip_ref)
+
+        # Remove now-empty top_level_directory and any empty subdirectories
+        # left over from extract_file_and_remove_top_level_dir operations
+        leftover_top_level_directory = os.path.join(
+            target_dir, top_level_directory)
+        if os.path.isdir(leftover_top_level_directory):
+            shutil.rmtree(leftover_top_level_directory)
+    else:
+        with ZipFile(str(package_path), "r") as zip_ref:
+            zip_ref.extractall(target_dir)
+
+    if unlink_zip:
+        Path(package_path).unlink()
+
+
 def delete_package(pkg_uri: str, base_directory: str) -> bool:
     """Deletes a specific URI from the local filesystem.
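Note: unzip_package with remove_top_level_directory=True is what makes GitHub-style archives usable directly: a HEAD.zip wraps everything in a single "<repo>-HEAD/" folder, which would otherwise end up inside the working_dir. A self-contained sketch of that behavior, assuming unzip_package is imported from ray._private.runtime_env.packaging ("repo-HEAD" is a made-up stand-in for a real archive root):

    import os
    import tempfile
    from pathlib import Path
    from shutil import make_archive

    from ray._private.runtime_env.packaging import unzip_package

    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "repo-HEAD"  # mimic a GitHub archive's wrapper dir
        src.mkdir()
        (src / "script.py").write_text("print('hello')")
        make_archive(os.path.join(tmp, "archive"), "zip", tmp, "repo-HEAD")

        target = os.path.join(tmp, "working_dir")
        unzip_package(
            package_path=os.path.join(tmp, "archive.zip"),
            target_dir=target,
            remove_top_level_directory=True,
            unlink_zip=True)
        # The wrapper directory is stripped; only its contents remain.
        assert os.listdir(target) == ["script.py"]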
@@ -18,8 +18,8 @@ def _check_is_uri(s: str) -> bool:
     except ValueError:
         protocol, path = None, None

-    if protocol == Protocol.S3 and not path.endswith(".zip"):
-        raise ValueError("Only .zip files supported for S3 URIs.")
+    if protocol in Protocol.remote_protocols() and not path.endswith(".zip"):
+        raise ValueError("Only .zip files supported for remote URIs.")

     return protocol is not None
@@ -40,8 +40,8 @@ def validate_uri(uri: str):
                 "be dynamically uploaded is only supported at the job level "
                 "(i.e., passed to `ray.init`).")

-    if protocol == Protocol.S3 and not path.endswith(".zip"):
-        raise ValueError("Only .zip files supported for S3 URIs.")
+    if protocol in Protocol.remote_protocols() and not path.endswith(".zip"):
+        raise ValueError("Only .zip files supported for remote URIs.")


 def parse_and_validate_py_modules(py_modules: List[str]) -> List[str]:
@@ -35,8 +35,9 @@ def upload_working_dir_if_needed(
         protocol, path = None, None

     if protocol is not None:
-        if protocol == Protocol.S3 and not path.endswith(".zip"):
-            raise ValueError("Only .zip files supported for S3 URIs.")
+        if protocol in Protocol.remote_protocols(
+        ) and not path.endswith(".zip"):
+            raise ValueError("Only .zip files supported for remote URIs.")
         return runtime_env

     excludes = runtime_env.get("excludes", None)
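Note: with _check_is_uri, validate_uri, and upload_working_dir_if_needed all switched to remote_protocols(), the user-facing effect is that ray.init accepts any of the three remote schemes for working_dir and py_modules. A minimal end-to-end sketch using the HTTPS test package defined further down in this diff:

    import ray

    ray.init(runtime_env={
        "working_dir": ("https://github.com/shrekris-anyscale/"
                        "test_module/archive/HEAD.zip"),
    })

    @ray.remote
    def ls():
        import os
        return sorted(os.listdir("."))  # files from the unpacked archive

    print(ray.get(ls.remote()))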
@@ -1,10 +1,12 @@
 import os
+from pathlib import Path
 import random
-from shutil import rmtree
+from shutil import rmtree, make_archive
 import string
 import sys
 import tempfile
+from filecmp import dircmp
+from zipfile import ZipFile
 import uuid

 import pytest
@@ -13,7 +15,12 @@ from ray.experimental.internal_kv import (_internal_kv_del,
                                           _internal_kv_exists)
 from ray._private.runtime_env.packaging import (
     _dir_travel, get_uri_for_directory, _get_excludes,
-    upload_package_if_needed)
+    upload_package_if_needed, parse_uri, Protocol,
+    get_top_level_dir_from_compressed_package,
+    extract_file_and_remove_top_level_dir, unzip_package)
+
+TOP_LEVEL_DIR_NAME = "top_level"
+ARCHIVE_NAME = "archive.zip"


 def random_string(size: int = 10):
@@ -42,6 +49,37 @@ def random_dir():
         yield tmp_dir


+@pytest.fixture
+def random_zip_file_without_top_level_dir(random_dir):
+    path = Path(random_dir)
+    make_archive(path / ARCHIVE_NAME[:ARCHIVE_NAME.rfind(".")], "zip", path)
+    yield str(path / ARCHIVE_NAME)
+
+
+@pytest.fixture
+def random_zip_file_with_top_level_dir():
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        path = Path(tmp_dir)
+        top_level_dir = path / TOP_LEVEL_DIR_NAME
+        top_level_dir.mkdir(parents=True)
+        next_level_dir = top_level_dir
+        for _ in range(10):
+            p1 = next_level_dir / random_string(10)
+            with p1.open("w") as f1:
+                f1.write(random_string(100))
+            p2 = next_level_dir / random_string(10)
+            with p2.open("w") as f2:
+                f2.write(random_string(200))
+            dir1 = next_level_dir / random_string(15)
+            dir1.mkdir(parents=True)
+            dir2 = next_level_dir / random_string(15)
+            dir2.mkdir(parents=True)
+            next_level_dir = dir2
+        make_archive(path / ARCHIVE_NAME[:ARCHIVE_NAME.rfind(".")], "zip",
+                     path, TOP_LEVEL_DIR_NAME)
+        yield str(path / ARCHIVE_NAME)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
 class TestGetURIForDirectory:
     def test_invalid_directory(self):
@@ -115,6 +153,77 @@ class TestUploadPackageIfNeeded:
         assert uploaded


+@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
+class TestGetTopLevelDirFromCompressedPackage:
+    def test_get_top_level_valid(self, random_zip_file_with_top_level_dir):
+        top_level_dir_name = get_top_level_dir_from_compressed_package(
+            str(random_zip_file_with_top_level_dir))
+        assert top_level_dir_name == TOP_LEVEL_DIR_NAME
+
+    def test_get_top_level_invalid(self,
+                                   random_zip_file_without_top_level_dir):
+        top_level_dir_name = get_top_level_dir_from_compressed_package(
+            str(random_zip_file_without_top_level_dir))
+        assert top_level_dir_name is None
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
+class TestExtractFileAndRemoveTopLevelDir:
+    def test_valid_extraction(self, random_zip_file_with_top_level_dir):
+        archive_path = random_zip_file_with_top_level_dir
+        tmp_path = archive_path[:archive_path.rfind("/")]
+        rmtree(os.path.join(tmp_path, TOP_LEVEL_DIR_NAME))
+        with ZipFile(archive_path, "r") as zf:
+            for fname in zf.namelist():
+                extract_file_and_remove_top_level_dir(
+                    base_dir=tmp_path, fname=fname, zip_ref=zf)
+        rmtree(os.path.join(tmp_path, TOP_LEVEL_DIR_NAME))
+        with ZipFile(archive_path, "r") as zf:
+            zf.extractall(tmp_path)
+        dcmp = dircmp(tmp_path, f"{tmp_path}/{TOP_LEVEL_DIR_NAME}")
+
+        # Since this test uses the tmp_path as the target directory, and since
+        # the tmp_path also contains the zip file and the top level directory,
+        # make sure that the only difference between the tmp_path's contents
+        # and the top level directory's contents are the zip file and the top
+        # level directory itself. This implies that all files have been
+        # extracted from the top level directory and moved into the tmp_path.
+        assert set(dcmp.left_only) == {ARCHIVE_NAME, TOP_LEVEL_DIR_NAME}
+
+        # Make sure that all the subdirectories and files have been moved to
+        # the target directory
+        assert len(dcmp.right_only) == 0
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
+@pytest.mark.parametrize("remove_top_level_directory", [False, True])
+@pytest.mark.parametrize("unlink_zip", [False, True])
+def test_unzip_package(random_zip_file_with_top_level_dir,
+                       remove_top_level_directory, unlink_zip):
+    archive_path = random_zip_file_with_top_level_dir
+    tmp_path = archive_path[:archive_path.rfind("/")]
+    tmp_subdir = f"{tmp_path}/{TOP_LEVEL_DIR_NAME}_tmp"
+    unzip_package(
+        package_path=archive_path,
+        target_dir=tmp_subdir,
+        remove_top_level_directory=remove_top_level_directory,
+        unlink_zip=unlink_zip)
+
+    dcmp = None
+    if remove_top_level_directory:
+        dcmp = dircmp(f"{tmp_subdir}", f"{tmp_path}/{TOP_LEVEL_DIR_NAME}")
+    else:
+        dcmp = dircmp(f"{tmp_subdir}/{TOP_LEVEL_DIR_NAME}",
+                      f"{tmp_path}/{TOP_LEVEL_DIR_NAME}")
+    assert len(dcmp.left_only) == 0
+    assert len(dcmp.right_only) == 0
+
+    if unlink_zip:
+        assert not Path(archive_path).is_file()
+    else:
+        assert Path(archive_path).is_file()
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
 def test_travel():
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -177,5 +286,19 @@ def test_travel():
     assert dir_paths == visited_dir_paths


+@pytest.mark.parametrize(
+    "parsing_tuple",
+    [("gcs://file.zip", Protocol.GCS, "file.zip"),
+     ("s3://bucket/file.zip", Protocol.S3, "s3_bucket_file.zip"),
+     ("https://test.com/file.zip", Protocol.HTTPS, "https_test_com_file.zip"),
+     ("gs://bucket/file.zip", Protocol.GS, "gs_bucket_file.zip")])
+def test_parsing(parsing_tuple):
+    uri, protocol, package_name = parsing_tuple
+    parsed_protocol, parsed_package_name = parse_uri(uri)
+
+    assert protocol == parsed_protocol
+    assert package_name == parsed_package_name
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-sv", __file__]))
@@ -67,14 +67,23 @@ class TestValidateWorkingDir:
         with pytest.raises(TypeError):
             parse_and_validate_working_dir(1)

-    def test_validate_s3_invalid_extension(self):
-        with pytest.raises(
-                ValueError, match="Only .zip files supported for S3 URIs."):
-            parse_and_validate_working_dir("s3://bucket/file")
+    def test_validate_remote_invalid_extensions(self):
+        for uri in [
+                "https://some_domain.com/path/file", "s3://bucket/file",
+                "gs://bucket/file"
+        ]:
+            with pytest.raises(
+                    ValueError,
+                    match="Only .zip files supported for remote URIs."):
+                parse_and_validate_working_dir(uri)

-    def test_validate_s3_valid_input(self):
-        working_dir = parse_and_validate_working_dir("s3://bucket/file.zip")
-        assert working_dir == "s3://bucket/file.zip"
+    def test_validate_remote_valid_input(self):
+        for uri in [
+                "https://some_domain.com/path/file.zip",
+                "s3://bucket/file.zip", "gs://bucket/file.zip"
+        ]:
+            working_dir = parse_and_validate_working_dir(uri)
+            assert working_dir == uri


 class TestValidatePyModules:
@@ -90,14 +99,23 @@ class TestValidatePyModules:
         with pytest.raises(TypeError):
             parse_and_validate_py_modules([1])

-    def test_validate_s3_invalid_extension(self):
+    def test_validate_remote_invalid_extension(self):
+        uris = [
+            "https://some_domain.com/path/file", "s3://bucket/file",
+            "gs://bucket/file"
+        ]
         with pytest.raises(
-                ValueError, match="Only .zip files supported for S3 URIs."):
-            parse_and_validate_py_modules(["s3://bucket/file"])
+                ValueError,
+                match="Only .zip files supported for remote URIs."):
+            parse_and_validate_py_modules(uris)

-    def test_validate_s3_valid_input(self):
-        py_modules = parse_and_validate_py_modules(["s3://bucket/file.zip"])
-        assert py_modules == ["s3://bucket/file.zip"]
+    def test_validate_remote_valid_input(self):
+        uris = [
+            "https://some_domain.com/path/file.zip", "s3://bucket/file.zip",
+            "gs://bucket/file.zip"
+        ]
+        py_modules = parse_and_validate_py_modules(uris)
+        assert py_modules == uris


 class TestValidateExcludes:
@@ -15,7 +15,11 @@ import ray
 # This package contains a subdirectory called `test_module`.
 # Calling `test_module.one()` should return `2`.
 # If you find that confusing, take it up with @jiaodong...
-S3_PACKAGE_URI = "s3://runtime-env-test/remote_runtime_env.zip"
+HTTPS_PACKAGE_URI = ("https://github.com/shrekris-anyscale/"
+                     "test_module/archive/HEAD.zip")
+S3_PACKAGE_URI = "s3://runtime-env-test/test_runtime_env.zip"
+GS_PACKAGE_URI = "gs://public-runtime-env-test/test_module.zip"
+REMOTE_URIS = [HTTPS_PACKAGE_URI, S3_PACKAGE_URI]


 @pytest.fixture(scope="function")
@@ -234,13 +238,14 @@ def test_input_validation(start_cluster, option: str):

     ray.shutdown()

-    with pytest.raises(ValueError):
-        if option == "working_dir":
-            ray.init(address, runtime_env={"working_dir": "s3://no_dot_zip"})
-        else:
-            ray.init(address, runtime_env={"py_modules": ["s3://no_dot_zip"]})
+    for uri in ["https://no_dot_zip", "s3://no_dot_zip", "gs://no_dot_zip"]:
+        with pytest.raises(ValueError):
+            if option == "working_dir":
+                ray.init(address, runtime_env={"working_dir": uri})
+            else:
+                ray.init(address, runtime_env={"py_modules": [uri]})

-    ray.shutdown()
+        ray.shutdown()

     if option == "py_modules":
         with pytest.raises(TypeError):
@@ -249,12 +254,13 @@ def test_input_validation(start_cluster, option: str):


 @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
+@pytest.mark.parametrize("remote_uri", REMOTE_URIS)
 @pytest.mark.parametrize("option", ["failure", "working_dir", "py_modules"])
 @pytest.mark.parametrize("per_task_actor", [True, False])
-def test_s3_uri(start_cluster, option, per_task_actor):
+def test_remote_package_uri(start_cluster, remote_uri, option, per_task_actor):
     """Tests the case where we lazily read files or import inside a task/actor.

-    In this case, the files come from an S3 URI.
+    In this case, the files come from a remote location.

     This tests both that this fails *without* the working_dir and that it
     passes with it.
@@ -262,9 +268,9 @@ def test_s3_uri(start_cluster, option, per_task_actor):
     cluster, address = start_cluster

     if option == "working_dir":
-        env = {"working_dir": S3_PACKAGE_URI}
+        env = {"working_dir": remote_uri}
     elif option == "py_modules":
-        env = {"py_modules": [S3_PACKAGE_URI]}
+        env = {"py_modules": [remote_uri]}

     if option == "failure" or per_task_actor:
         ray.init(address)
@@ -305,7 +311,7 @@ def test_s3_uri(start_cluster, option, per_task_actor):
 @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
 @pytest.mark.parametrize("option", ["working_dir", "py_modules"])
 @pytest.mark.parametrize(
-    "source", [S3_PACKAGE_URI, lazy_fixture("tmp_working_dir")])
+    "source", [*REMOTE_URIS, lazy_fixture("tmp_working_dir")])
 def test_multi_node(start_cluster, option: str, source: str):
     """Tests that the working_dir is propagated across multi-node clusters."""
     NUM_NODES = 3
@@ -317,7 +323,7 @@ def test_multi_node(start_cluster, option: str, source: str):
     if option == "working_dir":
         ray.init(address, runtime_env={"working_dir": source})
     elif option == "py_modules":
-        if source != S3_PACKAGE_URI:
+        if source not in REMOTE_URIS:
             source = str(Path(source) / "test_module")
         ray.init(address, runtime_env={"py_modules": [source]})
@@ -489,7 +495,7 @@ cache/
 @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
 @pytest.mark.parametrize(
     "working_dir",
-    [S3_PACKAGE_URI, lazy_fixture("tmp_working_dir")])
+    [*REMOTE_URIS, lazy_fixture("tmp_working_dir")])
 def test_runtime_context(start_cluster, working_dir):
     """Tests that the working_dir is propagated in the runtime_context."""
     cluster, address = start_cluster
@@ -497,8 +503,8 @@ def test_runtime_context(start_cluster, working_dir):

     def check():
         wd = ray.get_runtime_context().runtime_env["working_dir"]
-        if working_dir == S3_PACKAGE_URI:
-            assert wd == S3_PACKAGE_URI
+        if working_dir in REMOTE_URIS:
+            assert wd == working_dir
         else:
             assert wd.startswith("gcs://_ray_pkg_")
@@ -22,7 +22,7 @@ from ray._private.runtime_env.packaging import GCS_STORAGE_MAX_SIZE
 # This package contains a subdirectory called `test_module`.
 # Calling `test_module.one()` should return `2`.
 # If you find that confusing, take it up with @jiaodong...
-S3_PACKAGE_URI = "s3://runtime-env-test/remote_runtime_env.zip"
+S3_PACKAGE_URI = "s3://runtime-env-test/test_runtime_env.zip"


 @contextmanager