Fix checkpoint crash for actor creation task. (#4327)

* Fix checkpoint crash for actor creation task.

* Lint

* Move test to test_actor.py

* Revert unused code in test_failure.py

* Refine test according to Raul's suggestion.
This commit is contained in:
Yuhong Guo 2019-03-14 23:42:57 +08:00 committed by GitHub
parent 2f37cd7e27
commit becffc6cef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 15 deletions

View file

@ -783,8 +783,10 @@ class FunctionActorManager(object):
method_returns = method(actor, *args)
except Exception as e:
# Save the checkpoint before allowing the method exception
# to be thrown.
if isinstance(actor, ray.actor.Checkpointable):
# to be thrown, but don't save the checkpoint for actor
# creation task.
if (isinstance(actor, ray.actor.Checkpointable)
and self._worker.actor_task_counter != 1):
self._save_and_log_checkpoint(actor)
raise e
else:

View file

@ -16,6 +16,10 @@ import ray
import ray.ray_constants as ray_constants
import ray.tests.utils
import ray.tests.cluster_utils
from ray.tests.utils import (
wait_for_errors,
relevant_errors,
)
@pytest.fixture
@ -2563,3 +2567,38 @@ def test_bad_checkpointable_actor_class():
class BadCheckpointableActor(ray.actor.Checkpointable):
def should_checkpoint(self, checkpoint_context):
return True
def test_init_exception_in_checkpointable_actor(ray_start_regular,
ray_checkpointable_actor_cls):
# This test is similar to test_failure.py::test_failed_actor_init.
# This test is used to guarantee that checkpointable actor does not
# break the same logic.
error_message1 = "actor constructor failed"
error_message2 = "actor method failed"
@ray.remote
class CheckpointableFailedActor(ray_checkpointable_actor_cls):
def __init__(self):
raise Exception(error_message1)
def fail_method(self):
raise Exception(error_message2)
def should_checkpoint(self, checkpoint_context):
return True
a = CheckpointableFailedActor.remote()
# Make sure that we get errors from a failed constructor.
wait_for_errors(ray_constants.TASK_PUSH_ERROR, 1, timeout=2)
errors = relevant_errors(ray_constants.TASK_PUSH_ERROR)
assert len(errors) == 1
assert error_message1 in errors[0]["message"]
# Make sure that we get errors from a failed method.
a.fail_method.remote()
wait_for_errors(ray_constants.TASK_PUSH_ERROR, 2, timeout=2)
errors = relevant_errors(ray_constants.TASK_PUSH_ERROR)
assert len(errors) == 2
assert error_message1 in errors[1]["message"]

View file

@ -17,19 +17,7 @@ import ray
import ray.ray_constants as ray_constants
from ray.utils import _random_string
from ray.tests.cluster_utils import Cluster
def relevant_errors(error_type):
return [info for info in ray.error_info() if info["type"] == error_type]
def wait_for_errors(error_type, num_errors, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
if len(relevant_errors(error_type)) >= num_errors:
return
time.sleep(0.1)
raise Exception("Timing out of wait.")
from ray.tests.utils import (relevant_errors, wait_for_errors)
@pytest.fixture

View file

@ -81,3 +81,16 @@ def run_string_as_driver_nonblocking(driver_script):
f.flush()
return subprocess.Popen(
[sys.executable, f.name], stdout=subprocess.PIPE)
def relevant_errors(error_type):
return [info for info in ray.error_info() if info["type"] == error_type]
def wait_for_errors(error_type, num_errors, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
if len(relevant_errors(error_type)) >= num_errors:
return
time.sleep(0.1)
raise Exception("Timing out of wait.")