[runtime env] Test common failure scenarios (#25977)

Tests the following failure scenarios:
- Fail to upload data in `ray.init()` (`working_dir`, `py_modules`)
- Eager install fails in `ray.init()` for some other reason (bad `pip` package)
- Fail to download data from GCS (`working_dir`)

Improves the following error message cases:
- Return RuntimeEnvSetupError on failure to upload working_dir or py_modules
- Return RuntimeEnvSetupError on failure to download files from GCS during runtime env setup

Not covered in this PR:
- RPC to agent fails (This is extremely rare because the Raylet and agent are on the same node.)
- Agent is not started or dead (We don't need to worry about this because the Raylet fate shares with the agent.)

The approach is to use environment variables to induce failures in various places.  The alternative would be to refactor the packaging code to use dependency injection for the Internal KV client so that we can pass in a fake. I'm not sure how much of an improvement this would be.  I think we'd still have to set an environment variable to pass in the fake client, because these are essentially e2e tests of `ray.init()` and we don't have an API to pass it in.
This commit is contained in:
Archit Kulkarni 2022-08-15 09:35:56 -07:00 committed by GitHub
parent eb37bb857c
commit 058c239cf1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 279 additions and 83 deletions

View file

@ -464,8 +464,6 @@ if __name__ == "__main__":
disable_metrics_collection=args.disable_metrics_collection, disable_metrics_collection=args.disable_metrics_collection,
agent_id=args.agent_id, agent_id=args.agent_id,
) )
if os.environ.get("_RAY_AGENT_FAILING"):
raise Exception("Failure injection failure.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(agent.run()) loop.run_until_complete(agent.run())

View file

@ -315,7 +315,7 @@ def test_timeout(job_sdk_client):
wait_for_condition(_check_job_failed, client=client, job_id=job_id, timeout=10) wait_for_condition(_check_job_failed, client=client, job_id=job_id, timeout=10)
data = client.get_job_info(job_id) data = client.get_job_info(job_id)
assert "Failed to setup runtime environment" in data.message assert "Failed to set up runtime environment" in data.message
assert "Timeout" in data.message assert "Timeout" in data.message
assert "consider increasing `setup_timeout_seconds`" in data.message assert "consider increasing `setup_timeout_seconds`" in data.message
@ -375,7 +375,7 @@ def test_runtime_env_setup_failure(job_sdk_client):
wait_for_condition(_check_job_failed, client=client, job_id=job_id) wait_for_condition(_check_job_failed, client=client, job_id=job_id)
data = client.get_job_info(job_id) data = client.get_job_info(job_id)
assert "Failed to setup runtime environment" in data.message assert "Failed to set up runtime environment" in data.message
def test_submit_job_with_exception_in_driver(job_sdk_client): def test_submit_job_with_exception_in_driver(job_sdk_client):

View file

@ -12,6 +12,7 @@ from ray._private.runtime_env.packaging import (
) )
from ray._private.runtime_env.plugin import RuntimeEnvPlugin from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.utils import get_directory_size_bytes, try_to_create_directory from ray._private.utils import get_directory_size_bytes, try_to_create_directory
from ray.exceptions import RuntimeEnvSetupError
default_logger = logging.getLogger(__name__) default_logger = logging.getLogger(__name__)
@ -49,9 +50,14 @@ class JavaJarsPlugin(RuntimeEnvPlugin):
self, uri: str, logger: Optional[logging.Logger] = default_logger self, uri: str, logger: Optional[logging.Logger] = default_logger
): ):
"""Download a jar URI.""" """Download a jar URI."""
jar_file = await download_and_unpack_package( try:
uri, self._resources_dir, self._gcs_aio_client, logger=logger jar_file = await download_and_unpack_package(
) uri, self._resources_dir, self._gcs_aio_client, logger=logger
)
except Exception as e:
raise RuntimeEnvSetupError(
"Failed to download jar file: {}".format(e)
) from e
module_dir = self._get_local_dir_from_uri(uri) module_dir = self._get_local_dir_from_uri(uri)
logger.debug(f"Succeeded to download jar file {jar_file} .") logger.debug(f"Succeeded to download jar file {jar_file} .")
return module_dir return module_dir
@ -68,9 +74,14 @@ class JavaJarsPlugin(RuntimeEnvPlugin):
if is_jar_uri(uri): if is_jar_uri(uri):
module_dir = await self._download_jars(uri=uri, logger=logger) module_dir = await self._download_jars(uri=uri, logger=logger)
else: else:
module_dir = await download_and_unpack_package( try:
uri, self._resources_dir, self._gcs_aio_client, logger=logger module_dir = await download_and_unpack_package(
) uri, self._resources_dir, self._gcs_aio_client, logger=logger
)
except Exception as e:
raise RuntimeEnvSetupError(
"Failed to download jar file: {}".format(e)
) from e
return get_directory_size_bytes(module_dir) return get_directory_size_bytes(module_dir)

View file

@ -34,6 +34,13 @@ GCS_STORAGE_MAX_SIZE = int(
) )
RAY_PKG_PREFIX = "_ray_pkg_" RAY_PKG_PREFIX = "_ray_pkg_"
RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR = (
"RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING"
)
RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR = (
"RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING"
)
def _mib_string(num_bytes: float) -> str: def _mib_string(num_bytes: float) -> str:
size_mib = float(num_bytes / 1024 ** 2) size_mib = float(num_bytes / 1024 ** 2)
@ -341,6 +348,10 @@ def _store_package_in_gcs(
logger.info(f"Pushing file package '{pkg_uri}' ({size_str}) to Ray cluster...") logger.info(f"Pushing file package '{pkg_uri}' ({size_str}) to Ray cluster...")
try: try:
if os.environ.get(RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR):
raise RuntimeError(
"Simulating failure to upload package for testing purposes."
)
_internal_kv_put(pkg_uri, data) _internal_kv_put(pkg_uri, data)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
@ -469,13 +480,26 @@ def get_uri_for_directory(directory: str, excludes: Optional[List[str]] = None)
) )
def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes): def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes) -> None:
"""Upload a local package to GCS.
Args:
pkg_uri: The URI of the package, e.g. gcs://my_package.zip
pkg_bytes: The data to be uploaded.
Raises:
RuntimeError: If the upload fails.
ValueError: If the pkg_uri is a remote path or if the data's
size exceeds GCS_STORAGE_MAX_SIZE.
NotImplementedError: If the protocol of the URI is not supported.
"""
protocol, pkg_name = parse_uri(pkg_uri) protocol, pkg_name = parse_uri(pkg_uri)
if protocol == Protocol.GCS: if protocol == Protocol.GCS:
_store_package_in_gcs(pkg_uri, pkg_bytes) _store_package_in_gcs(pkg_uri, pkg_bytes)
elif protocol in Protocol.remote_protocols(): elif protocol in Protocol.remote_protocols():
raise RuntimeError( raise ValueError(
"upload_package_to_gcs should not be called with remote path." "upload_package_to_gcs should not be called with a remote path."
) )
else: else:
raise NotImplementedError(f"Protocol {protocol} is not supported") raise NotImplementedError(f"Protocol {protocol} is not supported")
@ -527,6 +551,12 @@ def upload_package_if_needed(
include_parent_dir: If true, includes the top-level directory as a include_parent_dir: If true, includes the top-level directory as a
directory inside the zip file. directory inside the zip file.
excludes: List specifying files to exclude. excludes: List specifying files to exclude.
Raises:
RuntimeError: If the upload fails.
ValueError: If the pkg_uri is a remote path or if the data's
size exceeds GCS_STORAGE_MAX_SIZE.
NotImplementedError: If the protocol of the URI is not supported.
""" """
if excludes is None: if excludes is None:
excludes = [] excludes = []
@ -572,7 +602,26 @@ async def download_and_unpack_package(
Will be written to a file or directory named {base_directory}/{uri}. Will be written to a file or directory named {base_directory}/{uri}.
Returns the path to this file or directory. Returns the path to this file or directory.
Args:
pkg_uri: URI of the package to download.
base_directory: Directory to use as the parent directory of the target
directory for the unpacked files.
gcs_aio_client: Client to use for downloading from the GCS.
logger: The logger to use.
Returns:
Path to the local directory containing the unpacked package files.
Raises:
IOError: If the download fails.
ImportError: If smart_open is not installed and a remote URI is used.
NotImplementedError: If the protocol of the URI is not supported.
""" """
if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR):
raise IOError("Failed to download package. (Simulated failure for testing)")
pkg_file = Path(_get_local_path(base_directory, pkg_uri)) pkg_file = Path(_get_local_path(base_directory, pkg_uri))
with FileLock(str(pkg_file) + ".lock"): with FileLock(str(pkg_file) + ".lock"):
if logger is None: if logger is None:
@ -592,7 +641,13 @@ async def download_and_unpack_package(
pkg_uri.encode(), namespace=None, timeout=None pkg_uri.encode(), namespace=None, timeout=None
) )
if code is None: if code is None:
raise IOError(f"Failed to fetch URI {pkg_uri} from GCS.") raise IOError(
f"Failed to download runtime_env file package {pkg_uri} "
"from the GCS to the Ray worker node. The package may "
"have prematurely been deleted from the GCS due to a "
"problem with Ray. Try re-running the statement or "
"restarting the Ray cluster."
)
code = code or b"" code = code or b""
pkg_file.write_bytes(code) pkg_file.write_bytes(code)

View file

@ -23,6 +23,7 @@ from ray._private.runtime_env.packaging import (
from ray._private.runtime_env.plugin import RuntimeEnvPlugin from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.runtime_env.working_dir import set_pythonpath_in_context from ray._private.runtime_env.working_dir import set_pythonpath_in_context
from ray._private.utils import get_directory_size_bytes, try_to_create_directory from ray._private.utils import get_directory_size_bytes, try_to_create_directory
from ray.exceptions import RuntimeEnvSetupError
default_logger = logging.getLogger(__name__) default_logger = logging.getLogger(__name__)
@ -90,23 +91,35 @@ def upload_py_modules_if_needed(
excludes = runtime_env.get("excludes", None) excludes = runtime_env.get("excludes", None)
module_uri = get_uri_for_directory(module_path, excludes=excludes) module_uri = get_uri_for_directory(module_path, excludes=excludes)
if upload_fn is None: if upload_fn is None:
upload_package_if_needed( try:
module_uri, upload_package_if_needed(
scratch_dir, module_uri,
module_path, scratch_dir,
excludes=excludes, module_path,
include_parent_dir=True, excludes=excludes,
logger=logger, include_parent_dir=True,
) logger=logger,
)
except Exception as e:
raise RuntimeEnvSetupError(
f"Failed to upload module {module_path} to the Ray "
f"cluster: {e}"
) from e
else: else:
upload_fn(module_path, excludes=excludes) upload_fn(module_path, excludes=excludes)
elif Path(module_path).suffix == ".whl": elif Path(module_path).suffix == ".whl":
module_uri = get_uri_for_package(Path(module_path)) module_uri = get_uri_for_package(Path(module_path))
if upload_fn is None: if upload_fn is None:
if not package_exists(module_uri): if not package_exists(module_uri):
upload_package_to_gcs( try:
module_uri, Path(module_path).read_bytes() upload_package_to_gcs(
) module_uri, Path(module_path).read_bytes()
)
except Exception as e:
raise RuntimeEnvSetupError(
f"Failed to upload {module_path} to the Ray "
f"cluster: {e}"
) from e
else: else:
upload_fn(module_path, excludes=None, is_file=True) upload_fn(module_path, excludes=None, is_file=True)
else: else:

View file

@ -18,6 +18,7 @@ from ray._private.runtime_env.packaging import (
) )
from ray._private.runtime_env.plugin import RuntimeEnvPlugin from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.utils import get_directory_size_bytes, try_to_create_directory from ray._private.utils import get_directory_size_bytes, try_to_create_directory
from ray.exceptions import RuntimeEnvSetupError
default_logger = logging.getLogger(__name__) default_logger = logging.getLogger(__name__)
@ -68,18 +69,28 @@ def upload_working_dir_if_needed(
) )
pkg_uri = get_uri_for_package(package_path) pkg_uri = get_uri_for_package(package_path)
upload_package_to_gcs(pkg_uri, package_path.read_bytes()) try:
upload_package_to_gcs(pkg_uri, package_path.read_bytes())
except Exception as e:
raise RuntimeEnvSetupError(
f"Failed to upload package {package_path} to the Ray cluster: {e}"
) from e
runtime_env["working_dir"] = pkg_uri runtime_env["working_dir"] = pkg_uri
return runtime_env return runtime_env
if upload_fn is None: if upload_fn is None:
upload_package_if_needed( try:
working_dir_uri, upload_package_if_needed(
scratch_dir, working_dir_uri,
working_dir, scratch_dir,
include_parent_dir=False, working_dir,
excludes=excludes, include_parent_dir=False,
logger=logger, excludes=excludes,
) logger=logger,
)
except Exception as e:
raise RuntimeEnvSetupError(
f"Failed to upload working_dir {working_dir} to the Ray cluster: {e}"
) from e
else: else:
upload_fn(working_dir, excludes=excludes) upload_fn(working_dir, excludes=excludes)

View file

@ -604,7 +604,7 @@ class RuntimeEnvSetupError(RayError):
self.error_message = error_message self.error_message = error_message
def __str__(self): def __str__(self):
msgs = ["Failed to setup runtime environment."] msgs = ["Failed to set up runtime environment."]
if self.error_message: if self.error_message:
msgs.append(self.error_message) msgs.append(self.error_message)
return "\n".join(msgs) return "\n".join(msgs)

View file

@ -307,6 +307,7 @@ py_test_module_list(
files = [ files = [
"test_runtime_env.py", "test_runtime_env.py",
"test_runtime_env_2.py", "test_runtime_env_2.py",
"test_runtime_env_failure.py",
"test_runtime_env_working_dir.py", "test_runtime_env_working_dir.py",
"test_runtime_env_working_dir_2.py", "test_runtime_env_working_dir_2.py",
"test_runtime_env_working_dir_3.py", "test_runtime_env_working_dir_3.py",

View file

@ -749,7 +749,7 @@ def test_wrapped_actor_creation(call_ray_start):
def test_init_requires_no_resources(call_ray_start, use_client): def test_init_requires_no_resources(call_ray_start, use_client):
import ray import ray
if use_client: if not use_client:
address = call_ray_start address = call_ray_start
ray.init(address) ray.init(address)
else: else:

View file

@ -150,13 +150,6 @@ def test_dashboard(shutdown_only):
) )
@pytest.fixture
def set_agent_failure_env_var():
os.environ["_RAY_AGENT_FAILING"] = "1"
yield
del os.environ["_RAY_AGENT_FAILING"]
conflict_port = 34567 conflict_port = 34567

View file

@ -289,44 +289,6 @@ def test_failed_job_env_no_hang(shutdown_only, runtime_env_class):
ray.get(f.remote()) ray.get(f.remote())
@pytest.fixture
def set_agent_failure_env_var():
os.environ["_RAY_AGENT_FAILING"] = "1"
yield
del os.environ["_RAY_AGENT_FAILING"]
# TODO(SongGuyang): Fail the agent which is in different node from driver.
@pytest.mark.skip(
reason="Agent failure will lead to raylet failure and driver failure."
)
@pytest.mark.parametrize("runtime_env_class", [dict, RuntimeEnv])
def test_runtime_env_broken(
set_agent_failure_env_var, runtime_env_class, ray_start_cluster_head
):
@ray.remote
class A:
def ready(self):
pass
@ray.remote
def f():
pass
runtime_env = runtime_env_class(env_vars={"TF_WARNINGS": "none"})
"""
Test task raises an exception.
"""
with pytest.raises(ray.exceptions.LocalRayletDiedError):
ray.get(f.options(runtime_env=runtime_env).remote())
"""
Test actor task raises an exception.
"""
a = A.options(runtime_env=runtime_env).remote()
with pytest.raises(ray.exceptions.RayActorError):
ray.get(a.ready.remote())
class TestURICache: class TestURICache:
def test_zero_cache_size(self): def test_zero_cache_size(self):
uris_to_sizes = {"5": 5, "3": 3} uris_to_sizes = {"5": 5, "3": 3}

View file

@ -0,0 +1,152 @@
import os
from unittest import mock
import pytest
from ray._private.runtime_env.packaging import (
RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR,
RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR,
)
import ray
from ray.exceptions import RuntimeEnvSetupError
def using_ray_client(address):
return address.startswith("ray://")
# Set scope to "class" to force this to run before start_cluster, whose scope
# is "function". We need this environment variable to be set before Ray is started.
@pytest.fixture(scope="class")
def fail_download():
with mock.patch.dict(
os.environ,
{
RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR: "1",
},
):
print("RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING enabled.")
yield
@pytest.fixture
def client_connection_timeout_1s():
"""Lower Ray Client ray.init() timeout to 1 second (default 30s) to save time"""
with mock.patch.dict(
os.environ,
{
"RAY_CLIENT_RECONNECT_GRACE_PERIOD": "1",
},
):
yield
class TestRuntimeEnvFailure:
@pytest.mark.parametrize("plugin", ["working_dir", "py_modules"])
def test_fail_upload(
self, tmpdir, monkeypatch, start_cluster, plugin, client_connection_timeout_1s
):
"""Simulate failing to upload the working_dir to the GCS.
Test that we raise an exception and don't hang.
"""
monkeypatch.setenv(RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR, "1")
_, address = start_cluster
if plugin == "working_dir":
runtime_env = {"working_dir": str(tmpdir)}
else:
runtime_env = {"py_modules": [str(tmpdir)]}
with pytest.raises(RuntimeEnvSetupError) as e:
ray.init(address, runtime_env=runtime_env)
assert "Failed to upload" in str(e.value)
@pytest.mark.parametrize("plugin", ["working_dir", "py_modules"])
def test_fail_download(
self,
tmpdir,
monkeypatch,
fail_download,
start_cluster,
plugin,
client_connection_timeout_1s,
):
"""Simulate failing to download the working_dir from the GCS.
Test that we raise an exception and don't hang.
"""
_, address = start_cluster
if plugin == "working_dir":
runtime_env = {"working_dir": str(tmpdir)}
else:
runtime_env = {"py_modules": [str(tmpdir)]}
def init_ray():
ray.init(address, runtime_env=runtime_env)
if using_ray_client(address):
# Fails at ray.init() because the working_dir is downloaded for the
# Ray Client server.
with pytest.raises(ConnectionAbortedError) as e:
init_ray()
assert "Failed to download" in str(e.value)
else:
init_ray()
# TODO(architkulkarni): After #25972 is resolved, we should raise an
# exception in ray.init(). Until then, we need to `ray.get` a task
# to raise the exception.
@ray.remote
def f():
pass
with pytest.raises(RuntimeEnvSetupError) as e:
ray.get(f.remote())
assert "Failed to download" in str(e.value)
def test_eager_install_fail(
self, tmpdir, monkeypatch, start_cluster, client_connection_timeout_1s
):
"""Simulate failing to install a runtime_env in ray.init().
By default eager_install is set to True. We should make sure
the driver fails to start if the eager_install fails.
"""
_, address = start_cluster
def init_ray():
# Simulate failure using a nonexistent `pip` package. This will pass
# validation but fail during installation.
ray.init(address, runtime_env={"pip": ["ray-nonexistent-pkg"]})
if using_ray_client(address):
# Fails at ray.init() because the `pip` package is downloaded for the
# Ray Client server.
with pytest.raises(ConnectionAbortedError) as e:
init_ray()
assert "No matching distribution found for ray-nonexistent-pkg" in str(
e.value
)
else:
init_ray()
# TODO(architkulkarni): After #25972 is resolved, we should raise an
# exception in ray.init(). Until then, we need to `ray.get` a task
# to raise the exception.
@ray.remote
def f():
pass
with pytest.raises(RuntimeEnvSetupError) as e:
ray.get(f.remote())
assert "No matching distribution found for ray-nonexistent-pkg" in str(
e.value
)
if __name__ == "__main__":
import sys
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
sys.exit(pytest.main(["-sv", __file__]))

View file

@ -11,7 +11,7 @@ import ray
from ray._private.test_utils import wait_for_condition, chdir, check_local_files_gced from ray._private.test_utils import wait_for_condition, chdir, check_local_files_gced
from ray._private.runtime_env import RAY_WORKER_DEV_EXCLUDES from ray._private.runtime_env import RAY_WORKER_DEV_EXCLUDES
from ray._private.runtime_env.packaging import GCS_STORAGE_MAX_SIZE from ray._private.runtime_env.packaging import GCS_STORAGE_MAX_SIZE
from ray.exceptions import GetTimeoutError from ray.exceptions import GetTimeoutError, RuntimeEnvSetupError
# This test requires you have AWS credentials set up (any AWS credentials will # This test requires you have AWS credentials set up (any AWS credentials will
# do, this test only accesses a public bucket). # do, this test only accesses a public bucket).
@ -109,7 +109,7 @@ def test_large_file_error(shutdown_only, option: str):
with open("test_file_2", "wb") as f: with open("test_file_2", "wb") as f:
f.write(os.urandom(size)) f.write(os.urandom(size))
with pytest.raises(ValueError): with pytest.raises(RuntimeEnvSetupError):
if option == "working_dir": if option == "working_dir":
ray.init(runtime_env={"working_dir": "."}) ray.init(runtime_env={"working_dir": "."})
else: else: