From becffc6cef73d41a81af9240a9431e29c69c5903 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Thu, 14 Mar 2019 23:42:57 +0800 Subject: [PATCH] Fix checkpoint crash for actor creation task. (#4327) * Fix checkpoint crash for actor creation task. * Lint * Move test to test_actor.py * Revert unused code in test_failure.py * Refine test according to Raul's suggestion. --- python/ray/function_manager.py | 6 +++-- python/ray/tests/test_actor.py | 39 ++++++++++++++++++++++++++++++++ python/ray/tests/test_failure.py | 14 +----------- python/ray/tests/utils.py | 13 +++++++++++ 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index bc63a994f..7651f0d62 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -783,8 +783,10 @@ class FunctionActorManager(object): method_returns = method(actor, *args) except Exception as e: # Save the checkpoint before allowing the method exception - # to be thrown. - if isinstance(actor, ray.actor.Checkpointable): + # to be thrown, but don't save the checkpoint for actor + # creation task. + if (isinstance(actor, ray.actor.Checkpointable) + and self._worker.actor_task_counter != 1): self._save_and_log_checkpoint(actor) raise e else: diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index fd5ce7bfb..2dfc35350 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -16,6 +16,10 @@ import ray import ray.ray_constants as ray_constants import ray.tests.utils import ray.tests.cluster_utils +from ray.tests.utils import ( + wait_for_errors, + relevant_errors, +) @pytest.fixture @@ -2563,3 +2567,38 @@ def test_bad_checkpointable_actor_class(): class BadCheckpointableActor(ray.actor.Checkpointable): def should_checkpoint(self, checkpoint_context): return True + + +def test_init_exception_in_checkpointable_actor(ray_start_regular, + ray_checkpointable_actor_cls): + # This test is similar to test_failure.py::test_failed_actor_init. + # This test is used to guarantee that checkpointable actor does not + # break the same logic. + error_message1 = "actor constructor failed" + error_message2 = "actor method failed" + + @ray.remote + class CheckpointableFailedActor(ray_checkpointable_actor_cls): + def __init__(self): + raise Exception(error_message1) + + def fail_method(self): + raise Exception(error_message2) + + def should_checkpoint(self, checkpoint_context): + return True + + a = CheckpointableFailedActor.remote() + + # Make sure that we get errors from a failed constructor. + wait_for_errors(ray_constants.TASK_PUSH_ERROR, 1, timeout=2) + errors = relevant_errors(ray_constants.TASK_PUSH_ERROR) + assert len(errors) == 1 + assert error_message1 in errors[0]["message"] + + # Make sure that we get errors from a failed method. + a.fail_method.remote() + wait_for_errors(ray_constants.TASK_PUSH_ERROR, 2, timeout=2) + errors = relevant_errors(ray_constants.TASK_PUSH_ERROR) + assert len(errors) == 2 + assert error_message1 in errors[1]["message"] diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 21152c353..81d9c9380 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -17,19 +17,7 @@ import ray import ray.ray_constants as ray_constants from ray.utils import _random_string from ray.tests.cluster_utils import Cluster - - -def relevant_errors(error_type): - return [info for info in ray.error_info() if info["type"] == error_type] - - -def wait_for_errors(error_type, num_errors, timeout=10): - start_time = time.time() - while time.time() - start_time < timeout: - if len(relevant_errors(error_type)) >= num_errors: - return - time.sleep(0.1) - raise Exception("Timing out of wait.") +from ray.tests.utils import (relevant_errors, wait_for_errors) @pytest.fixture diff --git a/python/ray/tests/utils.py b/python/ray/tests/utils.py index e4249f89a..3485b2638 100644 --- a/python/ray/tests/utils.py +++ b/python/ray/tests/utils.py @@ -81,3 +81,16 @@ def run_string_as_driver_nonblocking(driver_script): f.flush() return subprocess.Popen( [sys.executable, f.name], stdout=subprocess.PIPE) + + +def relevant_errors(error_type): + return [info for info in ray.error_info() if info["type"] == error_type] + + +def wait_for_errors(error_type, num_errors, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + if len(relevant_errors(error_type)) >= num_errors: + return + time.sleep(0.1) + raise Exception("Timing out of wait.")