Prototype actor checkpointing. (#814)
* Initial testing of checkpointing functions.
* Save checkpoints in Redis.
* Pipe checkpoint_interval through the remote decorator.
* Add a test.
* Small cleanups.
* Submit dummy tasks when reconstructing the tasks that precede the most recent checkpoint, so that we don't end up reconstructing the arguments for those tasks.
* Remove old checkpoints to save space.
* Fix linting.
Parent: d7b10a84b6
Commit: dbe3d9351c
3 changed files with 229 additions and 24 deletions
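For orientation, the interface this commit prototypes is exercised by the new test at the bottom of the diff. The following is a minimal usage sketch along those lines, illustrative only: the ray.init() call and this simplified Counter class are not part of the commit, and the __ray_save__/__ray_restore__ hooks are optional (without them the whole actor object is pickled).

import ray

ray.init()

# Opt an actor class into periodic checkpointing: a checkpoint is written
# whenever the actor's task counter is a multiple of 5.
@ray.remote(checkpoint_interval=5)
class Counter(object):
    def __init__(self):
        self.value = 0

    def inc(self):
        self.value += 1
        return self.value

    # Optional hook: the object returned here is what gets pickled into the
    # checkpoint instead of the whole actor.
    def __ray_save__(self):
        return self.value

    # Optional hook: rebuild the actor's state from the unpickled checkpoint.
    def __ray_restore__(self, checkpoint):
        self.value = checkpoint

counter = Counter.remote()
print(ray.get([counter.inc.remote() for _ in range(10)]))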
@@ -39,6 +39,33 @@ def get_actor_method_function_id(attr):
     return ray.local_scheduler.ObjectID(function_id)
 
 
+def get_actor_checkpoint(actor_id, worker):
+    """Get the most recent checkpoint associated with a given actor ID.
+
+    Args:
+        actor_id: The actor ID of the actor to get the checkpoint for.
+        worker: The worker to use to get the checkpoint.
+
+    Returns:
+        If a checkpoint exists, this returns a tuple of the checkpoint index
+        and the checkpoint. Otherwise it returns (-1, None). The checkpoint
+        index is the actor counter of the last task that was executed on
+        the actor before the checkpoint was made.
+    """
+    # Get all of the keys associated with checkpoints for this actor.
+    actor_key = b"Actor:" + actor_id
+    checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                          for key in worker.redis_client.hkeys(actor_key)
+                          if key.startswith(b"checkpoint_")]
+    if len(checkpoint_indices) == 0:
+        return -1, None
+    most_recent_checkpoint_index = max(checkpoint_indices)
+    # Get the most recent checkpoint.
+    checkpoint = worker.redis_client.hget(
+        actor_key, "checkpoint_{}".format(most_recent_checkpoint_index))
+    return most_recent_checkpoint_index, checkpoint
+
+
 def fetch_and_register_actor(actor_class_key, worker):
     """Import an actor.
@@ -48,12 +75,15 @@ def fetch_and_register_actor(actor_class_key, worker):
     """
     actor_id_str = worker.actor_id
     (driver_id, class_id, class_name,
-     module, pickled_class, actor_method_names) = worker.redis_client.hmget(
+     module, pickled_class, checkpoint_interval,
+     actor_method_names) = worker.redis_client.hmget(
          actor_class_key, ["driver_id", "class_id", "class_name", "module",
-                           "class", "actor_method_names"])
+                           "class", "checkpoint_interval",
+                           "actor_method_names"])
     actor_name = class_name.decode("ascii")
     module = module.decode("ascii")
+    checkpoint_interval = int(checkpoint_interval)
     actor_method_names = json.loads(actor_method_names.decode("ascii"))
 
     # Create a temporary actor with some temporary methods so that if the actor
@@ -62,6 +92,7 @@ def fetch_and_register_actor(actor_class_key, worker):
     class TemporaryActor(object):
         pass
     worker.actors[actor_id_str] = TemporaryActor()
+    worker.actor_checkpoint_interval = checkpoint_interval
 
     def temporary_actor_method(*xs):
         raise Exception("The actor with name {} failed to be imported, and so "
@@ -79,6 +110,7 @@ def fetch_and_register_actor(actor_class_key, worker):
 
     try:
         unpickled_class = pickle.loads(pickled_class)
+        worker.actor_class = unpickled_class
     except Exception:
         # If an exception was thrown when the actor was imported, we record the
         # traceback and notify the scheduler of the failure.
@@ -100,7 +132,8 @@ def fetch_and_register_actor(actor_class_key, worker):
     # for the actor.
 
 
-def export_actor_class(class_id, Class, actor_method_names, worker):
+def export_actor_class(class_id, Class, actor_method_names,
+                       checkpoint_interval, worker):
     if worker.mode is None:
         raise NotImplemented("TODO(pcm): Cache actors")
     key = b"ActorClass:" + class_id
@@ -108,6 +141,7 @@ def export_actor_class(class_id, Class, actor_method_names, worker):
          "class_name": Class.__name__,
          "module": Class.__module__,
          "class": pickle.dumps(Class),
+         "checkpoint_interval": checkpoint_interval,
          "actor_method_names": json.dumps(list(actor_method_names))}
     worker.redis_client.hmset(key, d)
     worker.redis_client.rpush("Exports", key)
@@ -173,6 +207,18 @@ def reconstruct_actor_state(actor_id, worker):
         actor_id: The ID of the actor being reconstructed.
         worker: The worker object that is running the actor.
     """
+    # Get the most recent actor checkpoint.
+    checkpoint_index, checkpoint = get_actor_checkpoint(actor_id, worker)
+    if checkpoint is not None:
+        print("Loading actor state from checkpoint {}"
+              .format(checkpoint_index))
+        # Wait for the actor to have been defined.
+        worker._wait_for_actor()
+        # TODO(rkn): Restoring from the checkpoint may fail, so this should be
+        # in a try-except block and we should give a good error message.
+        worker.actors[actor_id] = (
+            worker.actor_class.__ray_restore_from_checkpoint__(checkpoint))
+
     # TODO(rkn): This call is expensive. It'd be nice to find a way to get only
     # the tasks that are relevant to this actor.
     tasks = ray.global_state.task_table()
@@ -238,10 +284,18 @@ def reconstruct_actor_state(actor_id, worker):
         # local scheduler does bookkeeping about this actor's resource
         # utilization and things like that. It's also important for updating
         # some state on the worker.
-        worker.submit_task(
-            hex_to_object_id(task_spec_info["FunctionID"]),
-            task_spec_info["Args"],
-            actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        if task_spec_info["ActorCounter"] > checkpoint_index:
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                task_spec_info["Args"],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        else:
+            # Pass in a dummy task with no arguments to avoid having to
+            # unnecessarily reconstruct past arguments.
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                [],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
 
         # Clear the extra state that we set.
         del worker.task_driver_id
@@ -250,18 +304,22 @@ def reconstruct_actor_state(actor_id, worker):
 
         # Get the task from the local scheduler.
         retrieved_task = worker._get_next_task_from_local_scheduler()
-        # Assert that the retrieved task is the same as the constructed task.
-        assert (ray.local_scheduler.task_to_string(task_spec) ==
-                ray.local_scheduler.task_to_string(retrieved_task))
 
-        # Wait for the task to be ready and execute the task.
-        worker._wait_for_and_process_task(retrieved_task)
+        # If the task happened before the most recent checkpoint, ignore it.
+        # Otherwise, execute it.
+        if retrieved_task.actor_counter() > checkpoint_index:
+            # Assert that the retrieved task is the same as the constructed
+            # task.
+            assert (ray.local_scheduler.task_to_string(task_spec) ==
+                    ray.local_scheduler.task_to_string(retrieved_task))
+            # Wait for the task to be ready and then execute it.
+            worker._wait_for_and_process_task(retrieved_task)
 
     # Enter the main loop to receive and process tasks.
     worker.main_loop()
 
 
-def make_actor(cls, num_cpus, num_gpus):
+def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
     # Modify the class to have an additional method that will be used for
     # terminating the worker.
     class Class(cls):
@@ -278,6 +336,26 @@ def make_actor(cls, num_cpus, num_gpus):
             import os
             os._exit(0)
 
+        def __ray_save_checkpoint__(self):
+            if hasattr(self, "__ray_save__"):
+                object_to_serialize = self.__ray_save__()
+            else:
+                object_to_serialize = self
+            return pickle.dumps(object_to_serialize)
+
+        @classmethod
+        def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
+            checkpoint = pickle.loads(pickled_checkpoint)
+            if hasattr(cls, "__ray_restore__"):
+                actor_object = cls.__new__(cls)
+                actor_object.__ray_restore__(checkpoint)
+            else:
+                # TODO(rkn): It's possible that this will cause problems. When
+                # you unpickle the same object twice, the two objects will not
+                # have the same class.
+                actor_object = pickle.loads(checkpoint)
+            return actor_object
+
     Class.__module__ = cls.__module__
     Class.__name__ = cls.__name__
@@ -363,6 +441,7 @@ def make_actor(cls, num_cpus, num_gpus):
             if len(exported) == 0:
                 export_actor_class(class_id, Class,
                                    self._ray_actor_methods.keys(),
+                                   checkpoint_interval,
                                    ray.worker.global_worker)
                 exported.append(0)
             # Export the actor.
@@ -581,6 +581,13 @@ class Worker(object):
                 "data": data})
         self.redis_client.rpush("ErrorKeys", error_key)
 
+    def _wait_for_actor(self):
+        """Wait until the actor has been imported."""
+        assert self.actor_id != NIL_ACTOR_ID
+        # Wait until the actor has been imported.
+        while self.actor_id not in self.actors:
+            time.sleep(0.001)
+
     def _wait_for_function(self, function_id, driver_id, timeout=10):
         """Wait until the function to be executed is present on this worker.
@@ -764,6 +771,35 @@ class Worker(object):
                 data={"function_id": function_id.id(),
                       "function_name": function_name})
 
+    def _checkpoint_actor_state(self, actor_counter):
+        """Checkpoint the actor state.
+
+        This currently saves the checkpoint to Redis, but the checkpoint really
+        needs to go somewhere else.
+
+        Args:
+            actor_counter: The index of the most recent task that ran on this
+                actor.
+        """
+        print("Saving actor checkpoint. actor_counter = {}."
+              .format(actor_counter))
+        actor_key = b"Actor:" + self.actor_id
+        checkpoint = self.actors[self.actor_id].__ray_save_checkpoint__()
+        # Save the checkpoint in Redis. TODO(rkn): Checkpoints should not
+        # be stored in Redis. Fix this.
+        self.redis_client.hset(
+            actor_key,
+            "checkpoint_{}".format(actor_counter),
+            checkpoint)
+        # Remove the previous checkpoints if there is one.
+        checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                              for key in self.redis_client.hkeys(actor_key)
+                              if key.startswith(b"checkpoint_")]
+        for index in checkpoint_indices:
+            if index < actor_counter:
+                self.redis_client.hdel(actor_key,
+                                       "checkpoint_{}".format(index))
+
     def _wait_for_and_process_task(self, task):
         """Wait for a task to be ready and process the task.
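As the TODO in _checkpoint_actor_state notes, checkpoints currently live in the same Redis hash as the rest of the actor's metadata, under fields named checkpoint_<actor_counter>. A rough way to inspect what gets written is sketched below; this snippet is not part of the commit, the Redis address is a placeholder, and the actor ID is left elided.

import redis

# Placeholder address; in a real setup this would be the Redis server that
# Ray started.
client = redis.StrictRedis(host="127.0.0.1", port=6379)

actor_id = b"..."  # 20-byte binary actor ID, elided here
actor_key = b"Actor:" + actor_id

# Fields look like b"checkpoint_<actor_counter>"; the value is the pickled
# checkpoint produced by __ray_save_checkpoint__.
for field in client.hkeys(actor_key):
    if field.startswith(b"checkpoint_"):
        print(field, len(client.hget(actor_key, field)), "bytes")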
@@ -811,6 +847,13 @@ class Worker(object):
             ray.worker.global_worker.local_scheduler_client.disconnect()
             os._exit(0)
 
+        # Checkpoint the actor state if it is the right time to do so.
+        actor_counter = task.actor_counter()
+        if (self.actor_id != NIL_ACTOR_ID and
+                self.actor_checkpoint_interval != -1 and
+                actor_counter % self.actor_checkpoint_interval == 0):
+            self._checkpoint_actor_state(actor_counter)
+
     def _get_next_task_from_local_scheduler(self):
         """Get the next task from the local scheduler.
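The condition above disables checkpointing when actor_checkpoint_interval is -1 (the default filled in by the decorator changes further down) and otherwise writes a checkpoint after every task whose actor counter is a multiple of the interval. A standalone illustration of that cadence, not part of the commit:

checkpoint_interval = 5
# Actor counters among the first twelve tasks that would trigger a checkpoint.
print([c for c in range(12) if c % checkpoint_interval == 0])  # [0, 5, 10]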
@@ -2118,11 +2161,13 @@ def remote(*args, **kwargs):
             the driver.
         max_calls (int): The maximum number of tasks of this kind that can be
             run on a worker before the worker needs to be restarted.
+        checkpoint_interval (int): The number of tasks to run between
+            checkpoints of the actor state.
     """
     worker = global_worker
 
     def make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                              max_calls, func_id=None):
+                              max_calls, checkpoint_interval, func_id=None):
         def remote_decorator(func_or_class):
             if inspect.isfunction(func_or_class):
                 function_properties = FunctionProperties(
@@ -2133,7 +2178,8 @@ def remote(*args, **kwargs):
                 return remote_function_decorator(func_or_class,
                                                  function_properties)
             if inspect.isclass(func_or_class):
-                return worker.make_actor(func_or_class, num_cpus, num_gpus)
+                return worker.make_actor(func_or_class, num_cpus, num_gpus,
+                                         checkpoint_interval)
             raise Exception("The @ray.remote decorator must be applied to "
                             "either a function or to a class.")
@@ -2203,17 +2249,21 @@ def remote(*args, **kwargs):
     num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1
     num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0
     max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
+    checkpoint_interval = (kwargs["checkpoint_interval"]
+                           if "checkpoint_interval" in kwargs else -1)
     if _mode() == WORKER_MODE:
         if "function_id" in kwargs:
             function_id = kwargs["function_id"]
             return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                         max_calls, function_id)
+                                         max_calls, checkpoint_interval,
+                                         function_id)
 
     if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
         # This is the case where the decorator is just @ray.remote.
-        return make_remote_decorator(num_return_vals, num_cpus,
-                                     num_gpus, max_calls)(args[0])
+        return make_remote_decorator(
+            num_return_vals, num_cpus,
+            num_gpus, max_calls, checkpoint_interval)(args[0])
     else:
         # This is the case where the decorator is something like
         # @ray.remote(num_return_vals=2).
@@ -2223,13 +2273,16 @@ def remote(*args, **kwargs):
                         "the arguments 'num_return_vals', 'num_cpus', "
                         "'num_gpus', or 'max_calls', like "
                         "'@ray.remote(num_return_vals=2)'.")
-        assert len(args) == 0 and ("num_return_vals" in kwargs or
-                                   "num_cpus" in kwargs or
-                                   "num_gpus" in kwargs or
-                                   "max_calls" in kwargs), error_string
+        assert (len(args) == 0 and
+                ("num_return_vals" in kwargs or
+                 "num_cpus" in kwargs or
+                 "num_gpus" in kwargs or
+                 "max_calls" in kwargs or
+                 "checkpoint_interval" in kwargs)), error_string
         for key in kwargs:
             assert key in ["num_return_vals", "num_cpus",
-                           "num_gpus", "max_calls"], error_string
+                           "num_gpus", "max_calls",
+                           "checkpoint_interval"], error_string
         assert "function_id" not in kwargs
         return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                     max_calls)
+                                     max_calls, checkpoint_interval)
@@ -1214,6 +1214,79 @@ class ActorReconstruction(unittest.TestCase):
 
         ray.worker.cleanup()
 
+    def testCheckpointing(self):
+        ray.worker._init(start_ray_local=True, num_local_schedulers=2,
+                         num_workers=0, redirect_output=True)
+
+        @ray.remote(checkpoint_interval=5)
+        class Counter(object):
+            def __init__(self):
+                self.x = 0
+                # The number of times that inc has been called. We won't
+                # bother restoring this in the checkpoint.
+                self.num_inc_calls = 0
+
+            def local_plasma(self):
+                return ray.worker.global_worker.plasma_client.store_socket_name
+
+            def inc(self, *xs):
+                self.num_inc_calls += 1
+                self.x += 1
+                return self.x
+
+            def get_num_inc_calls(self):
+                return self.num_inc_calls
+
+            def test_restore(self):
+                # This method will only work if __ray_restore__ has been run.
+                return self.y
+
+            def __ray_save__(self):
+                return self.x, -1
+
+            def __ray_restore__(self, checkpoint):
+                self.x, val = checkpoint
+                self.num_inc_calls = 0
+                # Test that __ray_save__ has been run.
+                assert val == -1
+                self.y = self.x
+
+        local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
+
+        # Create an actor that is not on the local scheduler.
+        actor = Counter.remote()
+        while ray.get(actor.local_plasma.remote()) == local_plasma:
+            actor = Counter.remote()
+
+        args = [ray.put(0) for _ in range(100)]
+        ids = [actor.inc.remote(*args[i:]) for i in range(100)]
+
+        # Wait for the last task to finish running.
+        ray.get(ids[-1])
+
+        # Kill the second local scheduler.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][1]
+        process.kill()
+        process.wait()
+        # Kill the corresponding plasma store to get rid of the cached objects.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_PLASMA_STORE][1]
+        process.kill()
+        process.wait()
+
+        # Get all of the results. TODO(rkn): This currently doesn't work.
+        # results = ray.get(ids)
+        # self.assertEqual(results, list(range(1, 1 + len(results))))
+
+        self.assertEqual(ray.get(actor.test_restore.remote()), 99)
+
+        # The inc method should only have executed once on the new actor (for
+        # the one method call since the most recent checkpoint).
+        self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 1)
+
+        ray.worker.cleanup()
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)