mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Prototype actor checkpointing. (#814)
* Initial testing of checkpointing functions.
* Save checkpoints in Redis.
* Pipe checkpoint_interval through the remote decorator.
* Add a test.
* Small cleanups.
* Submit dummy tasks when reconstructing tasks that precede the most recent checkpoint, so that we don't end up reconstructing the arguments for those tasks.
* Remove old checkpoints to save space.
* Fix linting.

(A usage sketch based on the new test follows the changed-files summary below.)
This commit is contained in:
parent d7b10a84b6
commit dbe3d9351c
3 changed files with 229 additions and 24 deletions
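
For orientation, here is a rough sketch of how the new checkpoint_interval option is meant to be used from application code. It is modeled on the test added in this commit; the Counter class and the interval of 5 are illustrative, and the optional __ray_save__ / __ray_restore__ hooks let the actor control exactly what gets pickled into a checkpoint (without them, the whole actor object is pickled).

import ray

ray.init()


# checkpoint_interval=5 asks the worker to checkpoint the actor's state
# every 5 actor tasks; the default of -1 disables checkpointing.
@ray.remote(checkpoint_interval=5)
class Counter(object):
    def __init__(self):
        self.x = 0

    def inc(self):
        self.x += 1
        return self.x

    # Optional hook: return the object that should be pickled as the
    # checkpoint payload.
    def __ray_save__(self):
        return self.x

    # Optional hook: rebuild the actor's state from a checkpoint payload.
    def __ray_restore__(self, checkpoint):
        self.x = checkpoint


counter = Counter.remote()
print(ray.get([counter.inc.remote() for _ in range(12)]))

If the actor later has to be reconstructed (for example after its local scheduler dies), tasks at or before the last checkpoint are replayed as argument-free dummies and the actor's state is reloaded from the checkpoint instead.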
@@ -39,6 +39,33 @@ def get_actor_method_function_id(attr):
     return ray.local_scheduler.ObjectID(function_id)
 
 
+def get_actor_checkpoint(actor_id, worker):
+    """Get the most recent checkpoint associated with a given actor ID.
+
+    Args:
+        actor_id: The actor ID of the actor to get the checkpoint for.
+        worker: The worker to use to get the checkpoint.
+
+    Returns:
+        If a checkpoint exists, this returns a tuple of the checkpoint index
+        and the checkpoint. Otherwise it returns (-1, None). The checkpoint
+        index is the actor counter of the last task that was executed on
+        the actor before the checkpoint was made.
+    """
+    # Get all of the keys associated with checkpoints for this actor.
+    actor_key = b"Actor:" + actor_id
+    checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                          for key in worker.redis_client.hkeys(actor_key)
+                          if key.startswith(b"checkpoint_")]
+    if len(checkpoint_indices) == 0:
+        return -1, None
+    most_recent_checkpoint_index = max(checkpoint_indices)
+    # Get the most recent checkpoint.
+    checkpoint = worker.redis_client.hget(
+        actor_key, "checkpoint_{}".format(most_recent_checkpoint_index))
+    return most_recent_checkpoint_index, checkpoint
+
+
 def fetch_and_register_actor(actor_class_key, worker):
     """Import an actor.
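
Side note on the lookup above: get_actor_checkpoint is just reading fields of the actor's Redis hash. A minimal standalone sketch of the same scan, assuming a redis-py client pointed at a locally running Redis server and a made-up 20-byte actor ID (both are placeholders, not values from this commit):

import redis

redis_client = redis.StrictRedis(host="localhost", port=6379)  # assumed address
actor_id = b"\x00" * 20  # hypothetical actor ID
actor_key = b"Actor:" + actor_id

# Checkpoints live in the actor's hash as fields named "checkpoint_<counter>".
indices = [int(key[len(b"checkpoint_"):])
           for key in redis_client.hkeys(actor_key)
           if key.startswith(b"checkpoint_")]
if indices:
    latest = max(indices)
    checkpoint = redis_client.hget(actor_key, "checkpoint_{}".format(latest))
else:
    latest, checkpoint = -1, None
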
@@ -48,12 +75,15 @@ def fetch_and_register_actor(actor_class_key, worker):
     """
     actor_id_str = worker.actor_id
     (driver_id, class_id, class_name,
-     module, pickled_class, actor_method_names) = worker.redis_client.hmget(
+     module, pickled_class, checkpoint_interval,
+     actor_method_names) = worker.redis_client.hmget(
         actor_class_key, ["driver_id", "class_id", "class_name", "module",
-                          "class", "actor_method_names"])
+                          "class", "checkpoint_interval",
+                          "actor_method_names"])
 
     actor_name = class_name.decode("ascii")
     module = module.decode("ascii")
+    checkpoint_interval = int(checkpoint_interval)
     actor_method_names = json.loads(actor_method_names.decode("ascii"))
 
     # Create a temporary actor with some temporary methods so that if the actor
@@ -62,6 +92,7 @@ def fetch_and_register_actor(actor_class_key, worker):
     class TemporaryActor(object):
         pass
     worker.actors[actor_id_str] = TemporaryActor()
+    worker.actor_checkpoint_interval = checkpoint_interval
 
     def temporary_actor_method(*xs):
         raise Exception("The actor with name {} failed to be imported, and so "
@@ -79,6 +110,7 @@ def fetch_and_register_actor(actor_class_key, worker):
 
     try:
         unpickled_class = pickle.loads(pickled_class)
+        worker.actor_class = unpickled_class
     except Exception:
         # If an exception was thrown when the actor was imported, we record the
         # traceback and notify the scheduler of the failure.
@@ -100,7 +132,8 @@ def fetch_and_register_actor(actor_class_key, worker):
     # for the actor.
 
 
-def export_actor_class(class_id, Class, actor_method_names, worker):
+def export_actor_class(class_id, Class, actor_method_names,
+                       checkpoint_interval, worker):
     if worker.mode is None:
         raise NotImplemented("TODO(pcm): Cache actors")
     key = b"ActorClass:" + class_id
@@ -108,6 +141,7 @@ def export_actor_class(class_id, Class, actor_method_names, worker):
          "class_name": Class.__name__,
          "module": Class.__module__,
          "class": pickle.dumps(Class),
+         "checkpoint_interval": checkpoint_interval,
          "actor_method_names": json.dumps(list(actor_method_names))}
     worker.redis_client.hmset(key, d)
     worker.redis_client.rpush("Exports", key)
@@ -173,6 +207,18 @@ def reconstruct_actor_state(actor_id, worker):
         actor_id: The ID of the actor being reconstructed.
         worker: The worker object that is running the actor.
     """
+    # Get the most recent actor checkpoint.
+    checkpoint_index, checkpoint = get_actor_checkpoint(actor_id, worker)
+    if checkpoint is not None:
+        print("Loading actor state from checkpoint {}"
+              .format(checkpoint_index))
+        # Wait for the actor to have been defined.
+        worker._wait_for_actor()
+        # TODO(rkn): Restoring from the checkpoint may fail, so this should be
+        # in a try-except block and we should give a good error message.
+        worker.actors[actor_id] = (
+            worker.actor_class.__ray_restore_from_checkpoint__(checkpoint))
+
     # TODO(rkn): This call is expensive. It'd be nice to find a way to get only
     # the tasks that are relevant to this actor.
     tasks = ray.global_state.task_table()
@@ -238,10 +284,18 @@ def reconstruct_actor_state(actor_id, worker):
         # local scheduler does bookkeeping about this actor's resource
         # utilization and things like that. It's also important for updating
         # some state on the worker.
-        worker.submit_task(
-            hex_to_object_id(task_spec_info["FunctionID"]),
-            task_spec_info["Args"],
-            actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        if task_spec_info["ActorCounter"] > checkpoint_index:
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                task_spec_info["Args"],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
+        else:
+            # Pass in a dummy task with no arguments to avoid having to
+            # unnecessarily reconstruct past arguments.
+            worker.submit_task(
+                hex_to_object_id(task_spec_info["FunctionID"]),
+                [],
+                actor_id=hex_to_object_id(task_spec_info["ActorID"]))
 
         # Clear the extra state that we set.
         del worker.task_driver_id
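
The branch just added above decides, for each replayed task, whether its real arguments are needed. A small self-contained illustration of that decision in plain Python (no Ray; the task records and the checkpoint_index value are made up for the example):

checkpoint_index = 5  # actor counter covered by the most recent checkpoint

tasks = [{"ActorCounter": i, "Args": ["arg-%d" % i]} for i in range(8)]

for task in tasks:
    if task["ActorCounter"] > checkpoint_index:
        # The task ran after the checkpoint: resubmit it with its real
        # arguments so that it is actually re-executed.
        args = task["Args"]
    else:
        # The task is already covered by the checkpoint: submit a dummy
        # task with no arguments so its inputs are never reconstructed.
        args = []
    print(task["ActorCounter"], args)
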
@@ -250,18 +304,22 @@ def reconstruct_actor_state(actor_id, worker):
 
         # Get the task from the local scheduler.
         retrieved_task = worker._get_next_task_from_local_scheduler()
-        # Assert that the retrieved task is the same as the constructed task.
-        assert (ray.local_scheduler.task_to_string(task_spec) ==
-                ray.local_scheduler.task_to_string(retrieved_task))
-
-        # Wait for the task to be ready and execute the task.
-        worker._wait_for_and_process_task(retrieved_task)
+        # If the task happened before the most recent checkpoint, ignore it.
+        # Otherwise, execute it.
+        if retrieved_task.actor_counter() > checkpoint_index:
+            # Assert that the retrieved task is the same as the constructed
+            # task.
+            assert (ray.local_scheduler.task_to_string(task_spec) ==
+                    ray.local_scheduler.task_to_string(retrieved_task))
+
+            # Wait for the task to be ready and then execute it.
+            worker._wait_for_and_process_task(retrieved_task)
 
     # Enter the main loop to receive and process tasks.
     worker.main_loop()
 
 
-def make_actor(cls, num_cpus, num_gpus):
+def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
     # Modify the class to have an additional method that will be used for
     # terminating the worker.
     class Class(cls):
@@ -278,6 +336,26 @@ def make_actor(cls, num_cpus, num_gpus):
             import os
             os._exit(0)
 
+        def __ray_save_checkpoint__(self):
+            if hasattr(self, "__ray_save__"):
+                object_to_serialize = self.__ray_save__()
+            else:
+                object_to_serialize = self
+            return pickle.dumps(object_to_serialize)
+
+        @classmethod
+        def __ray_restore_from_checkpoint__(cls, pickled_checkpoint):
+            checkpoint = pickle.loads(pickled_checkpoint)
+            if hasattr(cls, "__ray_restore__"):
+                actor_object = cls.__new__(cls)
+                actor_object.__ray_restore__(checkpoint)
+            else:
+                # TODO(rkn): It's possible that this will cause problems. When
+                # you unpickle the same object twice, the two objects will not
+                # have the same class.
+                actor_object = pickle.loads(checkpoint)
+            return actor_object
+
     Class.__module__ = cls.__module__
     Class.__name__ = cls.__name__
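
One detail worth calling out in the restore path above: when the user defines __ray_restore__, the actor instance is created with cls.__new__ so that __init__ never runs again, and the user hook repopulates the state instead. A minimal sketch of that pattern outside of Ray (the Counter class here is hypothetical):

import pickle


class Counter(object):
    def __init__(self):
        # Imagine expensive setup here that we don't want to repeat on restore.
        print("running __init__")
        self.x = 0

    def __ray_save__(self):
        return self.x

    def __ray_restore__(self, checkpoint):
        self.x = checkpoint


original = Counter()  # prints "running __init__"
original.x = 42
blob = pickle.dumps(original.__ray_save__())  # what a checkpoint would contain

restored = Counter.__new__(Counter)  # __init__ is NOT called here
restored.__ray_restore__(pickle.loads(blob))
assert restored.x == 42
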
@@ -363,6 +441,7 @@ def make_actor(cls, num_cpus, num_gpus):
             if len(exported) == 0:
                 export_actor_class(class_id, Class,
                                    self._ray_actor_methods.keys(),
+                                   checkpoint_interval,
                                    ray.worker.global_worker)
                 exported.append(0)
             # Export the actor.
@@ -581,6 +581,13 @@ class Worker(object):
                 "data": data})
         self.redis_client.rpush("ErrorKeys", error_key)
 
+    def _wait_for_actor(self):
+        """Wait until the actor has been imported."""
+        assert self.actor_id != NIL_ACTOR_ID
+        # Wait until the actor has been imported.
+        while self.actor_id not in self.actors:
+            time.sleep(0.001)
+
     def _wait_for_function(self, function_id, driver_id, timeout=10):
         """Wait until the function to be executed is present on this worker.
@@ -764,6 +771,35 @@ class Worker(object):
                           data={"function_id": function_id.id(),
                                 "function_name": function_name})
 
+    def _checkpoint_actor_state(self, actor_counter):
+        """Checkpoint the actor state.
+
+        This currently saves the checkpoint to Redis, but the checkpoint really
+        needs to go somewhere else.
+
+        Args:
+            actor_counter: The index of the most recent task that ran on this
+                actor.
+        """
+        print("Saving actor checkpoint. actor_counter = {}."
+              .format(actor_counter))
+        actor_key = b"Actor:" + self.actor_id
+        checkpoint = self.actors[self.actor_id].__ray_save_checkpoint__()
+        # Save the checkpoint in Redis. TODO(rkn): Checkpoints should not
+        # be stored in Redis. Fix this.
+        self.redis_client.hset(
+            actor_key,
+            "checkpoint_{}".format(actor_counter),
+            checkpoint)
+        # Remove the previous checkpoints if there is one.
+        checkpoint_indices = [int(key[len(b"checkpoint_"):])
+                              for key in self.redis_client.hkeys(actor_key)
+                              if key.startswith(b"checkpoint_")]
+        for index in checkpoint_indices:
+            if index < actor_counter:
+                self.redis_client.hdel(actor_key,
+                                       "checkpoint_{}".format(index))
+
     def _wait_for_and_process_task(self, task):
         """Wait for a task to be ready and process the task.
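
_checkpoint_actor_state above both writes the new checkpoint and prunes the older ones from the same hash, so at most one checkpoint per actor is kept. A compact standalone version of that write-then-prune step, again assuming a redis-py client at a placeholder address and a toy pickled payload:

import pickle

import redis

redis_client = redis.StrictRedis(host="localhost", port=6379)  # assumed address
actor_key = b"Actor:" + b"\x01" * 20  # hypothetical actor ID
actor_counter = 10
checkpoint = pickle.dumps({"x": 10})  # toy checkpoint payload

# Write the new checkpoint under "checkpoint_<actor_counter>".
redis_client.hset(actor_key, "checkpoint_{}".format(actor_counter), checkpoint)

# Delete every older checkpoint field so only the newest one remains.
for key in redis_client.hkeys(actor_key):
    if (key.startswith(b"checkpoint_") and
            int(key[len(b"checkpoint_"):]) < actor_counter):
        redis_client.hdel(actor_key, key)
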
@@ -811,6 +847,13 @@ class Worker(object):
             ray.worker.global_worker.local_scheduler_client.disconnect()
             os._exit(0)
 
+        # Checkpoint the actor state if it is the right time to do so.
+        actor_counter = task.actor_counter()
+        if (self.actor_id != NIL_ACTOR_ID and
+                self.actor_checkpoint_interval != -1 and
+                actor_counter % self.actor_checkpoint_interval == 0):
+            self._checkpoint_actor_state(actor_counter)
+
     def _get_next_task_from_local_scheduler(self):
         """Get the next task from the local scheduler.
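
The guard just added fires a checkpoint whenever the actor counter is a multiple of the configured interval, with -1 meaning "never checkpoint". A tiny arithmetic check of which of the first 20 task counters would trigger a checkpoint for an interval of 5 (the interval value is simply the one used by the test later in this diff):

actor_checkpoint_interval = 5  # -1 would disable checkpointing entirely

triggering = [counter for counter in range(1, 21)
              if actor_checkpoint_interval != -1 and
              counter % actor_checkpoint_interval == 0]
print(triggering)  # [5, 10, 15, 20]
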
@@ -2118,11 +2161,13 @@ def remote(*args, **kwargs):
             the driver.
         max_calls (int): The maximum number of tasks of this kind that can be
             run on a worker before the worker needs to be restarted.
+        checkpoint_interval (int): The number of tasks to run between
+            checkpoints of the actor state.
     """
     worker = global_worker
 
     def make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                              max_calls, func_id=None):
+                              max_calls, checkpoint_interval, func_id=None):
         def remote_decorator(func_or_class):
             if inspect.isfunction(func_or_class):
                 function_properties = FunctionProperties(
@@ -2133,7 +2178,8 @@ def remote(*args, **kwargs):
                 return remote_function_decorator(func_or_class,
                                                  function_properties)
             if inspect.isclass(func_or_class):
-                return worker.make_actor(func_or_class, num_cpus, num_gpus)
+                return worker.make_actor(func_or_class, num_cpus, num_gpus,
+                                         checkpoint_interval)
             raise Exception("The @ray.remote decorator must be applied to "
                             "either a function or to a class.")
@@ -2203,17 +2249,21 @@ def remote(*args, **kwargs):
     num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1
     num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0
     max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
+    checkpoint_interval = (kwargs["checkpoint_interval"]
+                           if "checkpoint_interval" in kwargs else -1)
 
     if _mode() == WORKER_MODE:
         if "function_id" in kwargs:
             function_id = kwargs["function_id"]
             return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                         max_calls, function_id)
+                                         max_calls, checkpoint_interval,
+                                         function_id)
 
     if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
         # This is the case where the decorator is just @ray.remote.
-        return make_remote_decorator(num_return_vals, num_cpus,
-                                     num_gpus, max_calls)(args[0])
+        return make_remote_decorator(
+            num_return_vals, num_cpus,
+            num_gpus, max_calls, checkpoint_interval)(args[0])
     else:
         # This is the case where the decorator is something like
         # @ray.remote(num_return_vals=2).
@@ -2223,13 +2273,16 @@ def remote(*args, **kwargs):
                         "the arguments 'num_return_vals', 'num_cpus', "
                         "'num_gpus', or 'max_calls', like "
                         "'@ray.remote(num_return_vals=2)'.")
-        assert len(args) == 0 and ("num_return_vals" in kwargs or
-                                   "num_cpus" in kwargs or
-                                   "num_gpus" in kwargs or
-                                   "max_calls" in kwargs), error_string
+        assert (len(args) == 0 and
+                ("num_return_vals" in kwargs or
+                 "num_cpus" in kwargs or
+                 "num_gpus" in kwargs or
+                 "max_calls" in kwargs or
+                 "checkpoint_interval" in kwargs)), error_string
         for key in kwargs:
             assert key in ["num_return_vals", "num_cpus",
-                           "num_gpus", "max_calls"], error_string
+                           "num_gpus", "max_calls",
+                           "checkpoint_interval"], error_string
         assert "function_id" not in kwargs
         return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
-                                     max_calls)
+                                     max_calls, checkpoint_interval)
@@ -1214,6 +1214,79 @@ class ActorReconstruction(unittest.TestCase):
 
         ray.worker.cleanup()
 
+    def testCheckpointing(self):
+        ray.worker._init(start_ray_local=True, num_local_schedulers=2,
+                         num_workers=0, redirect_output=True)
+
+        @ray.remote(checkpoint_interval=5)
+        class Counter(object):
+            def __init__(self):
+                self.x = 0
+                # The number of times that inc has been called. We won't bother
+                # restoring this in the checkpoint
+                self.num_inc_calls = 0
+
+            def local_plasma(self):
+                return ray.worker.global_worker.plasma_client.store_socket_name
+
+            def inc(self, *xs):
+                self.num_inc_calls += 1
+                self.x += 1
+                return self.x
+
+            def get_num_inc_calls(self):
+                return self.num_inc_calls
+
+            def test_restore(self):
+                # This method will only work if __ray_restore__ has been run.
+                return self.y
+
+            def __ray_save__(self):
+                return self.x, -1
+
+            def __ray_restore__(self, checkpoint):
+                self.x, val = checkpoint
+                self.num_inc_calls = 0
+                # Test that __ray_save__ has been run.
+                assert val == -1
+                self.y = self.x
+
+        local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
+
+        # Create an actor that is not on the local scheduler.
+        actor = Counter.remote()
+        while ray.get(actor.local_plasma.remote()) == local_plasma:
+            actor = Counter.remote()
+
+        args = [ray.put(0) for _ in range(100)]
+        ids = [actor.inc.remote(*args[i:]) for i in range(100)]
+
+        # Wait for the last task to finish running.
+        ray.get(ids[-1])
+
+        # Kill the second local scheduler.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][1]
+        process.kill()
+        process.wait()
+        # Kill the corresponding plasma store to get rid of the cached objects.
+        process = ray.services.all_processes[
+            ray.services.PROCESS_TYPE_PLASMA_STORE][1]
+        process.kill()
+        process.wait()
+
+        # Get all of the results. TODO(rkn): This currently doesn't work.
+        # results = ray.get(ids)
+        # self.assertEqual(results, list(range(1, 1 + len(results))))
+
+        self.assertEqual(ray.get(actor.test_restore.remote()), 99)
+
+        # The inc method should only have executed once on the new actor (for
+        # the one method call since the most recent checkpoint).
+        self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 1)
+
+        ray.worker.cleanup()
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)