Prototype actor checkpointing. (#814)

* Initial testing of checkpointing functions.

* Save checkpoints in Redis.

* Pipe checkpoint_interval through remote decorator.

* Add a test.

* Small cleanups.

* Submit dummy tasks when replaying tasks that precede the most recent checkpoint so that we don't end up reconstructing the arguments for those tasks.

* Remove old checkpoints to save space.

* Fix linting.
Robert Nishihara 2017-08-07 17:52:39 -07:00 committed by Philipp Moritz
parent d7b10a84b6
commit dbe3d9351c
3 changed files with 229 additions and 24 deletions
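
Before the diff, a condensed usage sketch of the feature this commit adds. It mirrors the new testCheckpointing test below; the Counter class is illustrative, and __ray_save__/__ray_restore__ are optional hooks (when they are absent, the actor object itself is pickled):

    import ray

    ray.init()

    # Checkpoint the actor's state every 5 method calls. If the actor's
    # process is lost, Ray restores the actor from the most recent checkpoint
    # and only replays the method calls made after that checkpoint.
    @ray.remote(checkpoint_interval=5)
    class Counter(object):
        def __init__(self):
            self.x = 0

        def inc(self):
            self.x += 1
            return self.x

        # Optional hook: return the object that should be pickled into the
        # checkpoint instead of the actor itself.
        def __ray_save__(self):
            return self.x

        # Optional hook: rebuild the actor's state from a restored checkpoint.
        def __ray_restore__(self, checkpoint):
            self.x = checkpoint

    counter = Counter.remote()
    print(ray.get([counter.inc.remote() for _ in range(10)]))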


@@ -39,6 +39,33 @@ def get_actor_method_function_id(attr):
     return ray.local_scheduler.ObjectID(function_id)


+def get_actor_checkpoint(actor_id, worker):
+    """Get the most recent checkpoint associated with a given actor ID.
+
+    Args:
+        actor_id: The actor ID of the actor to get the checkpoint for.
+        worker: The worker to use to get the checkpoint.
+
+    Returns:
+        If a checkpoint exists, this returns a tuple of the checkpoint index
+            and the checkpoint. Otherwise it returns (-1, None). The checkpoint
+            index is the actor counter of the last task that was executed on
+            the actor before the checkpoint was made.
+    """
+    # Get all of the keys associated with checkpoints for this actor.
+    actor_key = b"Actor:" + actor_id
+    checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                          for key in worker.redis_client.hkeys(actor_key)
+                          if key.startswith(b"checkpoint_")]
+    if len(checkpoint_indices) == 0:
+        return -1, None
+    most_recent_checkpoint_index = max(checkpoint_indices)
+    # Get the most recent checkpoint.
+    checkpoint = worker.redis_client.hget(
+        actor_key, "checkpoint_{}".format(most_recent_checkpoint_index))
+    return most_recent_checkpoint_index, checkpoint
+
+
 def fetch_and_register_actor(actor_class_key, worker):
     """Import an actor.
@@ -48,12 +75,15 @@ def fetch_and_register_actor(actor_class_key, worker):
     """
     actor_id_str = worker.actor_id
     (driver_id, class_id, class_name,
-     module, pickled_class, actor_method_names) = worker.redis_client.hmget(
+     module, pickled_class, checkpoint_interval,
+     actor_method_names) = worker.redis_client.hmget(
         actor_class_key, ["driver_id", "class_id", "class_name", "module",
-                          "class", "actor_method_names"])
+                          "class", "checkpoint_interval",
+                          "actor_method_names"])
     actor_name = class_name.decode("ascii")
     module = module.decode("ascii")
+    checkpoint_interval = int(checkpoint_interval)
     actor_method_names = json.loads(actor_method_names.decode("ascii"))

     # Create a temporary actor with some temporary methods so that if the actor
@@ -62,6 +92,7 @@ def fetch_and_register_actor(actor_class_key, worker):
     class TemporaryActor(object):
         pass
     worker.actors[actor_id_str] = TemporaryActor()
+    worker.actor_checkpoint_interval = checkpoint_interval

     def temporary_actor_method(*xs):
         raise Exception("The actor with name {} failed to be imported, and so "
@@ -79,6 +110,7 @@ def fetch_and_register_actor(actor_class_key, worker):
     try:
         unpickled_class = pickle.loads(pickled_class)
+        worker.actor_class = unpickled_class
     except Exception:
         # If an exception was thrown when the actor was imported, we record the
         # traceback and notify the scheduler of the failure.
@@ -100,7 +132,8 @@ def fetch_and_register_actor(actor_class_key, worker):
     # for the actor.


-def export_actor_class(class_id, Class, actor_method_names, worker):
+def export_actor_class(class_id, Class, actor_method_names,
+                       checkpoint_interval, worker):
     if worker.mode is None:
         raise NotImplemented("TODO(pcm): Cache actors")
     key = b"ActorClass:" + class_id
@@ -108,6 +141,7 @@ def export_actor_class(class_id, Class, actor_method_names, worker):
          "class_name": Class.__name__,
          "module": Class.__module__,
          "class": pickle.dumps(Class),
+         "checkpoint_interval": checkpoint_interval,
          "actor_method_names": json.dumps(list(actor_method_names))}
     worker.redis_client.hmset(key, d)
     worker.redis_client.rpush("Exports", key)
@@ -173,6 +207,18 @@ def reconstruct_actor_state(actor_id, worker):
         actor_id: The ID of the actor being reconstructed.
         worker: The worker object that is running the actor.
     """
+    # Get the most recent actor checkpoint.
+    checkpoint_index, checkpoint = get_actor_checkpoint(actor_id, worker)
+    if checkpoint is not None:
+        print("Loading actor state from checkpoint {}"
+              .format(checkpoint_index))
+        # Wait for the actor to have been defined.
+        worker._wait_for_actor()
+        # TODO(rkn): Restoring from the checkpoint may fail, so this should be
+        # in a try-except block and we should give a good error message.
+        worker.actors[actor_id] = (
+            worker.actor_class.__ray_restore_from_checkpoint__(checkpoint))
+
     # TODO(rkn): This call is expensive. It'd be nice to find a way to get only
     # the tasks that are relevant to this actor.
     tasks = ray.global_state.task_table()
@@ -238,10 +284,18 @@ def reconstruct_actor_state(actor_id, worker):
         # local scheduler does bookkeeping about this actor's resource
         # utilization and things like that. It's also important for updating
         # some state on the worker.
-        worker.submit_task(
-            hex_to_object_id(task_spec_info["FunctionID"]),
-            task_spec_info["Args"],
-            actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        if task_spec_info["ActorCounter"] > checkpoint_index:
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                task_spec_info["Args"],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        else:
+            # Pass in a dummy task with no arguments to avoid having to
+            # unnecessarily reconstruct past arguments.
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                [],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))

         # Clear the extra state that we set.
         del worker.task_driver_id
@@ -250,18 +304,22 @@ def reconstruct_actor_state(actor_id, worker):
         # Get the task from the local scheduler.
         retrieved_task = worker._get_next_task_from_local_scheduler()

-        # Assert that the retrieved task is the same as the constructed task.
-        assert (ray.local_scheduler.task_to_string(task_spec) ==
-                ray.local_scheduler.task_to_string(retrieved_task))
-        # Wait for the task to be ready and execute the task.
-        worker._wait_for_and_process_task(retrieved_task)
+        # If the task happened before the most recent checkpoint, ignore it.
+        # Otherwise, execute it.
+        if retrieved_task.actor_counter() > checkpoint_index:
+            # Assert that the retrieved task is the same as the constructed
+            # task.
+            assert (ray.local_scheduler.task_to_string(task_spec) ==
+                    ray.local_scheduler.task_to_string(retrieved_task))
+            # Wait for the task to be ready and then execute it.
+            worker._wait_for_and_process_task(retrieved_task)

     # Enter the main loop to receive and process tasks.
     worker.main_loop()


-def make_actor(cls, num_cpus, num_gpus):
+def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
     # Modify the class to have an additional method that will be used for
     # terminating the worker.
     class Class(cls):
@@ -278,6 +336,26 @@ def make_actor(cls, num_cpus, num_gpus):
             import os
             os._exit(0)

+        def __ray_save_checkpoint__(self):
+            if hasattr(self, "__ray_save__"):
+                object_to_serialize = self.__ray_save__()
+            else:
+                object_to_serialize = self
+            return pickle.dumps(object_to_serialize)
+
+        @classmethod
+        def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
+            checkpoint = pickle.loads(pickled_checkpoint)
+            if hasattr(cls, "__ray_restore__"):
+                actor_object = cls.__new__(cls)
+                actor_object.__ray_restore__(checkpoint)
+            else:
+                # TODO(rkn): It's possible that this will cause problems. When
+                # you unpickle the same object twice, the two objects will not
+                # have the same class.
+                actor_object = pickle.loads(checkpoint)
+            return actor_object
+
     Class.__module__ = cls.__module__
     Class.__name__ = cls.__name__
@@ -363,6 +441,7 @@ def make_actor(cls, num_cpus, num_gpus):
             if len(exported) == 0:
                 export_actor_class(class_id, Class,
                                    self._ray_actor_methods.keys(),
+                                   checkpoint_interval,
                                    ray.worker.global_worker)
                 exported.append(0)
             # Export the actor.
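
The reconstruction logic above reduces to a single rule: tasks whose actor counter is at or below the checkpoint index are resubmitted as argument-less dummies (so their arguments never need to be reconstructed), while tasks above it are resubmitted and executed for real. A standalone illustration of that rule follows; replay_plan is a made-up helper, not part of this commit:

    def replay_plan(task_counters, checkpoint_index):
        """Describe how each recorded task would be resubmitted during actor
        reconstruction, given the most recent checkpoint index."""
        plan = []
        for counter in sorted(task_counters):
            if counter > checkpoint_index:
                # Executed for real after the checkpoint has been restored.
                plan.append((counter, "execute"))
            else:
                # Submitted as a dummy task with no arguments, purely so the
                # local scheduler's bookkeeping stays consistent.
                plan.append((counter, "dummy"))
        return plan

    # Example: a checkpoint was taken after task 5 and tasks 0-9 are in the
    # task table; tasks 0-5 become dummies and tasks 6-9 are re-executed.
    print(replay_plan(range(10), checkpoint_index=5))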


@@ -581,6 +581,13 @@ class Worker(object):
                              "data": data})
         self.redis_client.rpush("ErrorKeys", error_key)

+    def _wait_for_actor(self):
+        """Wait until the actor has been imported."""
+        assert self.actor_id != NIL_ACTOR_ID
+        # Wait until the actor has been imported.
+        while self.actor_id not in self.actors:
+            time.sleep(0.001)
+
     def _wait_for_function(self, function_id, driver_id, timeout=10):
         """Wait until the function to be executed is present on this worker.
@@ -764,6 +771,35 @@ class Worker(object):
                 data={"function_id": function_id.id(),
                       "function_name": function_name})

+    def _checkpoint_actor_state(self, actor_counter):
+        """Checkpoint the actor state.
+
+        This currently saves the checkpoint to Redis, but the checkpoint really
+        needs to go somewhere else.
+
+        Args:
+            actor_counter: The index of the most recent task that ran on this
+                actor.
+        """
+        print("Saving actor checkpoint. actor_counter = {}."
+              .format(actor_counter))
+        actor_key = b"Actor:" + self.actor_id
+        checkpoint = self.actors[self.actor_id].__ray_save_checkpoint__()
+        # Save the checkpoint in Redis. TODO(rkn): Checkpoints should not
+        # be stored in Redis. Fix this.
+        self.redis_client.hset(
+            actor_key,
+            "checkpoint_{}".format(actor_counter),
+            checkpoint)
+        # Remove the previous checkpoints if there are any.
+        checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                              for key in self.redis_client.hkeys(actor_key)
+                              if key.startswith(b"checkpoint_")]
+        for index in checkpoint_indices:
+            if index < actor_counter:
+                self.redis_client.hdel(actor_key,
+                                       "checkpoint_{}".format(index))
+
     def _wait_for_and_process_task(self, task):
         """Wait for a task to be ready and process the task.
@@ -811,6 +847,13 @@ class Worker(object):
             ray.worker.global_worker.local_scheduler_client.disconnect()
             os._exit(0)

+        # Checkpoint the actor state if it is the right time to do so.
+        actor_counter = task.actor_counter()
+        if (self.actor_id != NIL_ACTOR_ID and
+                self.actor_checkpoint_interval != -1 and
+                actor_counter % self.actor_checkpoint_interval == 0):
+            self._checkpoint_actor_state(actor_counter)
+
     def _get_next_task_from_local_scheduler(self):
         """Get the next task from the local scheduler.
@@ -2118,11 +2161,13 @@ def remote(*args, **kwargs):
             the driver.
         max_calls (int): The maximum number of tasks of this kind that can be
             run on a worker before the worker needs to be restarted.
+        checkpoint_interval (int): The number of tasks to run between
+            checkpoints of the actor state.
     """
     worker = global_worker

     def make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                              max_calls, func_id=None):
+                              max_calls, checkpoint_interval, func_id=None):
         def remote_decorator(func_or_class):
             if inspect.isfunction(func_or_class):
                 function_properties = FunctionProperties(
@@ -2133,7 +2178,8 @@ def remote(*args, **kwargs):
                 return remote_function_decorator(func_or_class,
                                                  function_properties)
             if inspect.isclass(func_or_class):
-                return worker.make_actor(func_or_class, num_cpus, num_gpus)
+                return worker.make_actor(func_or_class, num_cpus, num_gpus,
+                                         checkpoint_interval)
             raise Exception("The @ray.remote decorator must be applied to "
                             "either a function or to a class.")
@@ -2203,17 +2249,21 @@ def remote(*args, **kwargs):
     num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1
     num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0
     max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
+    checkpoint_interval = (kwargs["checkpoint_interval"]
+                           if "checkpoint_interval" in kwargs else -1)

     if _mode() == WORKER_MODE:
         if "function_id" in kwargs:
             function_id = kwargs["function_id"]
             return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                         max_calls, function_id)
+                                         max_calls, checkpoint_interval,
+                                         function_id)

     if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
         # This is the case where the decorator is just @ray.remote.
-        return make_remote_decorator(num_return_vals, num_cpus,
-                                     num_gpus, max_calls)(args[0])
+        return make_remote_decorator(
+            num_return_vals, num_cpus,
+            num_gpus, max_calls, checkpoint_interval)(args[0])
     else:
         # This is the case where the decorator is something like
         # @ray.remote(num_return_vals=2).
@@ -2223,13 +2273,16 @@ def remote(*args, **kwargs):
                         "the arguments 'num_return_vals', 'num_cpus', "
                         "'num_gpus', or 'max_calls', like "
                         "'@ray.remote(num_return_vals=2)'.")
-        assert len(args) == 0 and ("num_return_vals" in kwargs or
-                                   "num_cpus" in kwargs or
-                                   "num_gpus" in kwargs or
-                                   "max_calls" in kwargs), error_string
+        assert (len(args) == 0 and
+                ("num_return_vals" in kwargs or
+                 "num_cpus" in kwargs or
+                 "num_gpus" in kwargs or
+                 "max_calls" in kwargs or
+                 "checkpoint_interval" in kwargs)), error_string
         for key in kwargs:
             assert key in ["num_return_vals", "num_cpus",
-                           "num_gpus", "max_calls"], error_string
+                           "num_gpus", "max_calls",
+                           "checkpoint_interval"], error_string
         assert "function_id" not in kwargs
         return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                     max_calls)
+                                     max_calls, checkpoint_interval)
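
For reference, the storage layout used by _checkpoint_actor_state above is a per-actor Redis hash (Actor:<actor_id>) holding one checkpoint_<actor_counter> field at a time, since older fields are deleted after each save. A minimal sketch of inspecting that layout with a plain Redis client; the connection parameters and the placeholder actor ID are assumptions, and only the key and field names come from the diff:

    import redis

    # Assumed connection parameters; in practice the Redis address comes from
    # the Ray cluster that ray.init()/ray.worker._init() connected to.
    client = redis.StrictRedis(host="127.0.0.1", port=6379)

    actor_id = b"..."  # Placeholder for the binary actor ID.
    actor_key = b"Actor:" + actor_id

    # Each checkpoint is stored as a field named "checkpoint_<actor_counter>".
    for field in client.hkeys(actor_key):
        if field.startswith(b"checkpoint_"):
            index = int(field[len(b"checkpoint_"):])
            blob = client.hget(actor_key, field)
            print("checkpoint at actor_counter {}: {} bytes"
                  .format(index, len(blob)))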


@@ -1214,6 +1214,79 @@ class ActorReconstruction(unittest.TestCase):

         ray.worker.cleanup()

+    def testCheckpointing(self):
+        ray.worker._init(start_ray_local=True, num_local_schedulers=2,
+                         num_workers=0, redirect_output=True)
+
+        @ray.remote(checkpoint_interval=5)
+        class Counter(object):
+            def __init__(self):
+                self.x = 0
+                # The number of times that inc has been called. We won't
+                # bother restoring this in the checkpoint.
+                self.num_inc_calls = 0
+
+            def local_plasma(self):
+                return ray.worker.global_worker.plasma_client.store_socket_name
+
+            def inc(self, *xs):
+                self.num_inc_calls += 1
+                self.x += 1
+                return self.x
+
+            def get_num_inc_calls(self):
+                return self.num_inc_calls
+
+            def test_restore(self):
+                # This method will only work if __ray_restore__ has been run.
+                return self.y
+
+            def __ray_save__(self):
+                return self.x, -1
+
+            def __ray_restore__(self, checkpoint):
+                self.x, val = checkpoint
+                self.num_inc_calls = 0
+                # Test that __ray_save__ has been run.
+                assert val == -1
+                self.y = self.x
+
+        local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
+
+        # Create an actor that is not on the local scheduler.
+        actor = Counter.remote()
+        while ray.get(actor.local_plasma.remote()) == local_plasma:
+            actor = Counter.remote()
+
+        args = [ray.put(0) for _ in range(100)]
+        ids = [actor.inc.remote(*args[i:]) for i in range(100)]
+
+        # Wait for the last task to finish running.
+        ray.get(ids[-1])
+
+        # Kill the second local scheduler.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][1]
+        process.kill()
+        process.wait()
+
+        # Kill the corresponding plasma store to get rid of the cached
+        # objects.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_PLASMA_STORE][1]
+        process.kill()
+        process.wait()
+
+        # Get all of the results. TODO(rkn): This currently doesn't work.
+        # results = ray.get(ids)
+        # self.assertEqual(results, list(range(1, 1 + len(results))))
+
+        self.assertEqual(ray.get(actor.test_restore.remote()), 99)
+
+        # The inc method should only have executed once on the new actor (for
+        # the one method call since the most recent checkpoint).
+        self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 1)
+
+        ray.worker.cleanup()
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)