Actor checkpointing for distributed actor handles (#1498)

* Expose calls to get and set the actor frontier

* Remove fields used for old checkpointing prototype, change actor_checkpoint_failed -> succeeded

* Prototype for actor checkpointing

* Filter out duplicate tasks on the local scheduler

* Clean up some of the Python checkpointing code

* More cleanups

* Documentation

* cleanup and fix unit test

* Allow remote checkpoint calls through actor handle

* Check whether object is local before reconstructing

* Enable checkpointing for distributed actor handles, refactor tests

* Fix local scheduler tests

* lint

* Address comments

* lint

* Skip tests that fail on new GCS

* style

* Don't put same object twice when setting the actor frontier

* Address Philipp's comments, cleaner fbs naming
Stephanie Wang 2018-02-07 11:19:32 -08:00 committed by GitHub
parent 0a9dbc84b5
commit ff8e7f8259
13 changed files with 1006 additions and 590 deletions


@ -8,7 +8,6 @@ import inspect
import json
import traceback
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.local_scheduler
import ray.signature as signature
@ -66,23 +65,24 @@ def compute_actor_method_function_id(class_name, attr):
return ray.local_scheduler.ObjectID(function_id)
def get_checkpoint_indices(worker, actor_id):
"""Get the checkpoint indices associated with a given actor ID.
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
frontier):
"""Set the most recent checkpoint associated with a given actor ID.
Args:
worker: The worker to use to get the checkpoint indices.
actor_id: The actor ID of the actor to get the checkpoint indices for.
Returns:
The indices of existing checkpoints as a list of integers.
worker: The worker to use to get the checkpoint.
actor_id: The actor ID of the actor to get the checkpoint for.
checkpoint_index: The number of tasks included in the checkpoint.
checkpoint: The state object to save.
frontier: The task frontier at the time of the checkpoint.
"""
actor_key = b"Actor:" + actor_id
checkpoint_indices = []
for key in worker.redis_client.hkeys(actor_key):
if key.startswith(b"checkpoint_"):
index = int(key[len(b"checkpoint_"):])
checkpoint_indices.append(index)
return checkpoint_indices
worker.redis_client.hmset(
actor_key, {
"checkpoint_index": checkpoint_index,
"checkpoint": checkpoint,
"frontier": frontier,
})
def get_actor_checkpoint(worker, actor_id):
@ -93,30 +93,74 @@ def get_actor_checkpoint(worker, actor_id):
actor_id: The actor ID of the actor to get the checkpoint for.
Returns:
If a checkpoint exists, this returns a tuple of the checkpoint index
and the checkpoint. Otherwise it returns (-1, None). The checkpoint
index is the actor counter of the last task that was executed on
the actor before the checkpoint was made.
If a checkpoint exists, this returns a tuple of the number of tasks
included in the checkpoint, the saved checkpoint state, and the
task frontier at the time of the checkpoint. If no checkpoint
exists, all objects are set to None. The checkpoint index is the number of
tasks executed on the actor before the checkpoint was made.
"""
checkpoint_indices = get_checkpoint_indices(worker, actor_id)
if len(checkpoint_indices) == 0:
return -1, None
else:
actor_key = b"Actor:" + actor_id
checkpoint_index = max(checkpoint_indices)
checkpoint = worker.redis_client.hget(
actor_key, "checkpoint_{}".format(checkpoint_index))
return checkpoint_index, checkpoint
actor_key = b"Actor:" + actor_id
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
checkpoint_index = int(checkpoint_index)
return checkpoint_index, checkpoint, frontier
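For reference, the checkpoint lives in a single Redis hash keyed by b"Actor:" + actor_id. Below is a minimal standalone sketch of the same round trip outside a Ray worker; the redis package, the local server, and the placeholder values are assumptions for illustration, not part of this change.

import redis

r = redis.StrictRedis()        # assumption: a Redis server on the default port
actor_id = b"\x00" * 20        # hypothetical 20-byte actor ID
actor_key = b"Actor:" + actor_id

# Mirrors set_actor_checkpoint: one hash holds the index, state, and frontier.
r.hmset(actor_key, {
    "checkpoint_index": 5,
    "checkpoint": b"<pickled actor state>",
    "frontier": b"<ActorFrontier flatbuffer bytes>",
})

# Mirrors get_actor_checkpoint: a single hmget fetches all three fields, each
# of which is None if no checkpoint has been saved yet.
checkpoint_index, checkpoint, frontier = r.hmget(
    actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
    checkpoint_index = int(checkpoint_index)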
def make_actor_method_executor(worker, method_name, method):
def save_and_log_checkpoint(worker, actor):
"""Save a checkpoint on the actor and log any errors.
Args:
worker: The worker to use to log errors.
actor: The actor to checkpoint.
"""
try:
actor.__ray_checkpoint__()
except Exception:
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
"checkpoint",
traceback_str,
driver_id=worker.task_driver_id.id(),
data={"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__})
def restore_and_log_checkpoint(worker, actor):
"""Restore an actor from a checkpoint and log any errors.
Args:
worker: The worker to use to log errors.
actor: The actor to restore.
"""
checkpoint_resumed = False
try:
checkpoint_resumed = actor.__ray_checkpoint_restore__()
except Exception:
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
"checkpoint",
traceback_str,
driver_id=worker.task_driver_id.id(),
data={
"actor_class": actor.__class__.__name__,
"function_name":
actor.__ray_checkpoint_restore__.__name__})
return checkpoint_resumed
def make_actor_method_executor(worker, method_name, method, actor_imported):
"""Make an executor that wraps a user-defined actor method.
The executor wraps the method to update the worker's internal state. If the
task is a success, the dummy object returned is added to the object store,
to signal that the following task can run, and the worker's task counter is
updated to match the executed task. Else, the executor reports failure to
the local scheduler so that the task counter does not get updated.
The wrapped method updates the worker's internal state and performs any
necessary checkpointing operations.
Args:
worker (Worker): The worker that is executing the actor.
@ -124,6 +168,8 @@ def make_actor_method_executor(worker, method_name, method):
method (instancemethod): The actor method to wrap. This should be a
method defined on the actor class and should therefore take an
instance of the actor as the first argument.
actor_imported (bool): Whether the actor has been imported.
Checkpointing operations will not be run if this is set to False.
Returns:
A function that executes the given actor method on the worker's stored
@ -131,35 +177,48 @@ def make_actor_method_executor(worker, method_name, method):
internal state to record the executed method.
"""
def actor_method_executor(dummy_return_id, task_counter, actor,
*args):
if method_name == "__ray_checkpoint__":
# Execute the checkpoint task.
actor_checkpoint_failed, error = method(actor, *args)
# If the checkpoint was successfully loaded, update the actor's
# task counter and set a flag to notify the local scheduler, so
# that the task following the checkpoint can run.
if not actor_checkpoint_failed:
worker.actor_task_counter = task_counter + 1
# Once the actor has resumed from a checkpoint, it counts as
# loaded.
worker.actor_loaded = True
# Report to the local scheduler whether this task succeeded in
# loading the checkpoint.
worker.actor_checkpoint_failed = actor_checkpoint_failed
# If there was an exception during the checkpoint method, re-raise
# it after updating the actor's internal state.
if error is not None:
raise error
return None
def actor_method_executor(dummy_return_id, actor, *args):
# Update the actor's task counter to reflect the task we're about to
# execute.
worker.actor_task_counter += 1
# If this is the first task to execute on the actor, try to resume from
# a checkpoint.
if actor_imported and worker.actor_task_counter == 1:
checkpoint_resumed = restore_and_log_checkpoint(worker, actor)
if checkpoint_resumed:
# NOTE(swang): Since we did not actually execute the __init__
# method, this will put None as the return value. If the
# __init__ method is supposed to return multiple values, an
# exception will be logged.
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported and
worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and the
# method we're about to execute is not a checkpoint.
save_checkpoint = (checkpointing_on and
(worker.actor_task_counter %
worker.actor_checkpoint_interval == 0 and
method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
method_returns = method(actor, *args)
except Exception:
# Save the checkpoint before allowing the method exception to be
# thrown.
if save_checkpoint:
save_and_log_checkpoint(worker, actor)
raise
else:
# Update the worker's internal state before executing the method in
# case the method throws an exception.
worker.actor_task_counter = task_counter + 1
# Once the actor executes a task, it counts as loaded.
worker.actor_loaded = True
# Execute the actor method.
return method(actor, *args)
# Save the checkpoint before returning the method's return values.
if save_checkpoint:
save_and_log_checkpoint(worker, actor)
return method_returns
return actor_method_executor
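The control flow above boils down to a small wrap-and-checkpoint pattern: bump a task counter, run the method, and save a checkpoint every checkpoint_interval tasks whether the method returns or raises. A condensed sketch with hypothetical names (make_executor, save_checkpoint_fn), not the actual Ray executor:

def make_executor(method, checkpoint_interval, save_checkpoint_fn):
    # Hypothetical, simplified stand-in for make_actor_method_executor.
    state = {"task_counter": 0}

    def executor(actor, *args):
        state["task_counter"] += 1
        should_checkpoint = (checkpoint_interval > 0 and
                             state["task_counter"] % checkpoint_interval == 0)
        try:
            # Run the user's method; a checkpoint is taken below even if it
            # raises, before the exception or return value propagates.
            return method(actor, *args)
        finally:
            if should_checkpoint:
                save_checkpoint_fn(actor)

    return executor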
@ -207,7 +266,8 @@ def fetch_and_register_actor(actor_class_key, worker):
actor_method_name).id()
temporary_executor = make_actor_method_executor(worker,
actor_method_name,
temporary_actor_method)
temporary_actor_method,
actor_imported=False)
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_executor)
worker.num_task_executions[driver_id][function_id] = 0
@ -218,7 +278,7 @@ def fetch_and_register_actor(actor_class_key, worker):
except Exception:
# If an exception was thrown when the actor was imported, we record the
# traceback and notify the scheduler of the failure.
traceback_str = ray.worker.format_error_message(traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(worker.redis_client, "register_actor_signatures",
traceback_str, driver_id,
@ -238,7 +298,8 @@ def fetch_and_register_actor(actor_class_key, worker):
function_id = compute_actor_method_function_id(
class_name, actor_method_name).id()
executor = make_actor_method_executor(worker, actor_method_name,
actor_method)
actor_method,
actor_imported=True)
worker.functions[driver_id][function_id] = (actor_method_name,
executor)
# We do not set worker.function_properties[driver_id][function_id]
@ -412,18 +473,6 @@ class ActorMethod(object):
dependency=self._actor._ray_actor_cursor)
# Checkpoint methods do not take in the state of the previous actor method
# as an explicit data dependency.
class CheckpointMethod(ActorMethod):
def remote(self):
# A checkpoint's arguments are the current task counter and the
# object ID of the preceding task. The latter is an implicit data
# dependency, since the checkpoint method can run at any time.
args = [self._actor._ray_actor_counter,
[self._actor._ray_actor_cursor]]
return self._actor._actor_method_call(self._method_name, args=args)
class ActorHandleWrapper(object):
"""A wrapper for the contents of an ActorHandle.
@ -455,9 +504,6 @@ def wrap_actor_handle(actor_handle):
Returns:
An ActorHandleWrapper instance that stores the ActorHandle's fields.
"""
if actor_handle._ray_checkpoint_interval > 0:
raise Exception("Checkpointing not yet supported for distributed "
"actor handles.")
wrapper = ActorHandleWrapper(
actor_handle._ray_actor_id,
compute_actor_handle_id(actor_handle._ray_actor_handle_id,
@ -600,12 +646,6 @@ def make_actor_handle_class(class_name):
self._ray_actor_counter += 1
self._ray_actor_cursor = object_ids.pop()
# Submit a checkpoint task if it is time to do so.
if (self._ray_checkpoint_interval > 1 and
self._ray_actor_counter % self._ray_checkpoint_interval ==
0):
self.__ray_checkpoint__.remote()
# The last object returned is the dummy object that should be
# passed in to the next actor method. Do not return it to the user.
if len(object_ids) == 1:
@ -629,10 +669,7 @@ def make_actor_handle_class(class_name):
# this was causing cyclic references which were preventing
# object deallocation from behaving in a predictable
# manner.
if attr == "__ray_checkpoint__":
actor_method_cls = CheckpointMethod
else:
actor_method_cls = ActorMethod
actor_method_cls = ActorMethod
return actor_method_cls(self, attr)
except AttributeError:
pass
@ -755,9 +792,6 @@ def make_actor(cls, resources, checkpoint_interval):
"actor placement.")
if checkpoint_interval == 0:
raise Exception("checkpoint_interval must be greater than 0.")
# Add one to the checkpoint interval since we will insert a mock task for
# every checkpoint.
checkpoint_interval += 1
# Modify the class to have an additional method that will be used for
# terminating the worker.
@ -802,97 +836,58 @@ def make_actor(cls, resources, checkpoint_interval):
actor_object = checkpoint
return actor_object
def __ray_checkpoint__(self, task_counter, previous_object_id):
"""Save or resume a stored checkpoint.
def __ray_checkpoint__(self):
"""Save a checkpoint.
This task checkpoints the current state of the actor. If the actor
has not yet executed to `task_counter`, then the task instead
attempts to resume from a saved checkpoint that matches
`task_counter`. If the most recently saved checkpoint is earlier
than `task_counter`, the task requests reconstruction of the tasks
that executed since the previous checkpoint and before
`task_counter`.
Args:
self: An instance of the actor class.
task_counter: The index assigned to this checkpoint method.
previous_object_id: The dummy object returned by the task that
immediately precedes this checkpoint.
Returns:
A bool representing whether the checkpoint was successfully
loaded (whether the actor can safely execute the next task)
and an Exception instance, if one was thrown.
This task saves the current state of the actor, the current task
frontier according to the local scheduler, and the checkpoint index
(number of tasks executed so far).
"""
worker = ray.worker.global_worker
previous_object_id = previous_object_id[0]
plasma_id = plasma.ObjectID(previous_object_id.id())
checkpoint_index = worker.actor_task_counter
# Get the state to save.
checkpoint = self.__ray_save_checkpoint__()
# Get the current task frontier, per actor handle.
# NOTE(swang): This only includes actor handles that the local
# scheduler has seen. Handle IDs for which no task has yet reached
# the local scheduler will not be included, and may not be runnable
# on checkpoint resumption.
actor_id = ray.local_scheduler.ObjectID(worker.actor_id)
frontier = worker.local_scheduler_client.get_actor_frontier(
actor_id)
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
set_actor_checkpoint(worker, worker.actor_id, checkpoint_index,
checkpoint, frontier)
# Initialize the return values. `actor_checkpoint_failed` will be
# set to True if we fail to load the checkpoint. `error` will be
# set to the Exception, if one is thrown.
actor_checkpoint_failed = False
error_to_return = None
def __ray_checkpoint_restore__(self):
"""Restore a checkpoint.
# Save or resume the checkpoint.
if worker.actor_loaded:
# The actor has loaded, so we are running the normal execution.
# Save the checkpoint.
print("Saving actor checkpoint. actor_counter = {}."
.format(task_counter))
actor_key = b"Actor:" + worker.actor_id
This task looks for a saved checkpoint and if found, restores the
state of the actor, the task frontier in the local scheduler, and
the checkpoint index (number of tasks executed so far).
try:
checkpoint = worker.actors[
worker.actor_id].__ray_save_checkpoint__()
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
worker.redis_client.hset(
actor_key,
"checkpoint_{}".format(task_counter),
checkpoint)
# Remove the previous checkpoints if there is one.
checkpoint_indices = get_checkpoint_indices(
worker, worker.actor_id)
for index in checkpoint_indices:
if index < task_counter:
worker.redis_client.hdel(
actor_key, "checkpoint_{}".format(index))
# An exception was thrown. Save the error.
except Exception as error:
# Checkpoint saves should not block execution on the actor,
# so we still consider the task successful.
error_to_return = error
else:
# The actor has not yet loaded. Try loading it from the most
# recent checkpoint.
checkpoint_index, checkpoint = get_actor_checkpoint(
worker, worker.actor_id)
if checkpoint_index == task_counter:
# The checkpoint matches ours. Resume the actor instance.
try:
actor = (worker.actor_class.
__ray_restore_from_checkpoint__(checkpoint))
worker.actors[worker.actor_id] = actor
# An exception was thrown. Save the error.
except Exception as error:
# We could not resume the checkpoint, so count the task
# as failed.
actor_checkpoint_failed = True
error_to_return = error
else:
# We cannot resume a mismatching checkpoint, so count the
# task as failed.
actor_checkpoint_failed = True
Returns:
A bool indicating whether a checkpoint was resumed.
"""
worker = ray.worker.global_worker
# Get the most recent checkpoint stored, if any.
checkpoint_index, checkpoint, frontier = get_actor_checkpoint(
worker, worker.actor_id)
# Try to resume from the checkpoint.
checkpoint_resumed = False
if checkpoint_index is not None:
# Load the actor state from the checkpoint.
worker.actors[worker.actor_id] = (
worker.actor_class.__ray_restore_from_checkpoint__(
checkpoint))
# Set the number of tasks executed so far.
worker.actor_task_counter = checkpoint_index
# Set the actor frontier in the local scheduler.
worker.local_scheduler_client.set_actor_frontier(frontier)
checkpoint_resumed = True
# Fall back to lineage reconstruction if we were unable to load the
# checkpoint.
if actor_checkpoint_failed:
worker.local_scheduler_client.reconstruct_object(
plasma_id.binary())
worker.local_scheduler_client.notify_unblocked()
return actor_checkpoint_failed, error_to_return
return checkpoint_resumed
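Putting the pieces together, a checkpointed actor is declared by the user and the hooks above run automatically. A hedged usage sketch; the checkpoint_interval decorator keyword is assumed to be the argument that ultimately reaches make_actor above, and the driver code is illustrative only:

import ray

ray.init()

@ray.remote(checkpoint_interval=10)     # assumed decorator keyword
class Counter(object):
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return self.value

counter = Counter.remote()
for _ in range(25):
    counter.increment.remote()
# With the default hooks, __ray_checkpoint__ saves the pickled actor state,
# the checkpoint index, and the task frontier to Redis every 10th task. If the
# actor process is later lost, __ray_checkpoint_restore__ reloads that state
# and resets the frontier instead of replaying every task.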
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__


@ -20,6 +20,28 @@ def _random_string():
return np.random.bytes(20)
def format_error_message(exception_message, task_exception=False):
"""Improve the formatting of an exception thrown by a remote function.
This method takes a traceback from an exception and makes it nicer by
removing a few uninformative lines and adding some space to indent the
remaining lines nicely.
Args:
exception_message (str): A message generated by traceback.format_exc().
Returns:
A string of the formatted exception message.
"""
lines = exception_message.split("\n")
if task_exception:
# For errors that occur inside of tasks, remove lines 1, 2, 3, and 4,
# which are always the same; they just contain information about the
# main loop.
lines = lines[0:1] + lines[5:]
return "\n".join(lines)
def push_error_to_driver(redis_client, error_type, message, driver_id=None,
data=None):
"""Push an error message to the driver to be printed in the background.


@ -222,14 +222,6 @@ class Worker(object):
self.make_actor = None
self.actors = {}
self.actor_task_counter = 0
# Whether an actor instance has been loaded yet. The actor counts as
# loaded once it has either executed its first task or successfully
# resumed from a checkpoint.
self.actor_loaded = False
# This field is used to report actor checkpoint failure for the last
# task assigned. Workers are not assigned a task on startup, so we
# initialize to False.
self.actor_checkpoint_failed = False
# The number of threads Plasma should use when putting an object in the
# object store.
self.memcopy_threads = 12
@ -755,7 +747,7 @@ class Worker(object):
except Exception as e:
self._handle_process_task_failure(
function_id, return_object_ids, e,
format_error_message(traceback.format_exc()))
ray.utils.format_error_message(traceback.format_exc()))
return
# Execute the task.
@ -765,15 +757,15 @@ class Worker(object):
outputs = function_executor.executor(arguments)
else:
outputs = function_executor(
dummy_return_id, task.actor_counter(),
dummy_return_id,
self.actors[task.actor_id().id()],
*arguments)
except Exception as e:
# Determine whether the exception occurred during a task, not an
# actor method.
task_exception = task.actor_id().id() == NIL_ACTOR_ID
traceback_str = format_error_message(traceback.format_exc(),
task_exception=task_exception)
traceback_str = ray.utils.format_error_message(
traceback.format_exc(), task_exception=task_exception)
self._handle_process_task_failure(function_id, return_object_ids,
e, traceback_str)
return
@ -791,7 +783,7 @@ class Worker(object):
except Exception as e:
self._handle_process_task_failure(
function_id, return_object_ids, e,
format_error_message(traceback.format_exc()))
ray.utils.format_error_message(traceback.format_exc()))
def _handle_process_task_failure(self, function_id, return_object_ids,
error, backtrace):
@ -863,12 +855,7 @@ class Worker(object):
A task from the local scheduler.
"""
with log_span("ray:get_task", worker=self):
task = self.local_scheduler_client.get_task(
self.actor_checkpoint_failed)
# We assume that the task is not a checkpoint, or that if it is,
# that the task will succeed. The checkpoint task executor is
# responsible for reporting task failure to the local scheduler.
self.actor_checkpoint_failed = False
task = self.local_scheduler_client.get_task()
# Automatically restrict the GPUs available to this task.
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
@ -1613,7 +1600,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
except Exception:
# If an exception was thrown when the remote function was imported, we
# record the traceback and notify the scheduler of the failure.
traceback_str = format_error_message(traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(worker.redis_client,
"register_remote_function",
@ -2351,28 +2338,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
return ready_ids, remaining_ids
def format_error_message(exception_message, task_exception=False):
"""Improve the formatting of an exception thrown by a remote function.
This method takes a traceback from an exception and makes it nicer by
removing a few uninformative lines and adding some space to indent the
remaining lines nicely.
Args:
exception_message (str): A message generated by traceback.format_exc().
Returns:
A string of the formatted exception message.
"""
lines = exception_message.split("\n")
if task_exception:
# For errors that occur inside of tasks, remove lines 1, 2, 3, and 4,
# which are always the same, they just contain information about the
# main loop.
lines = lines[0:1] + lines[5:]
return "\n".join(lines)
def _submit_task(function_id, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.


@ -33,7 +33,21 @@ enum MessageType:int {
// scheduler.
NotifyUnblocked,
// Add a result table entry for an object put.
PutObject
PutObject,
// A request to get the task frontier for an actor, called by the actor when
// saving a checkpoint.
GetActorFrontierRequest,
// The ActorFrontier response to a GetActorFrontierRequest. The local
// scheduler returns the actor's per-handle task counts and execution
// dependencies, which can later be used as the argument to SetActorFrontier
// when resuming from the checkpoint.
GetActorFrontierReply,
// A request to set the task frontier for an actor, called when resuming from
// a checkpoint. The local scheduler will update the actor's per-handle task
// counts and execution dependencies, discard any tasks that already executed
// before the checkpoint, and make any tasks on the frontier runnable by
// making their execution dependencies available.
SetActorFrontier
}
table SubmitTaskRequest {
@ -41,14 +55,6 @@ table SubmitTaskRequest {
task_spec: string;
}
// This message is sent from a worker to a local scheduler.
table GetTaskRequest {
// Whether the previously assigned task was a checkpoint task that failed.
// If true, then the local scheduler will not update the actor's task
// counter to match the assigned checkpoint index.
actor_checkpoint_failed: bool;
}
// This message is sent from the local scheduler to a worker.
table GetTaskReply {
// A string of bytes representing the task specification.
@ -97,3 +103,29 @@ table PutObject {
// Object ID of the object that is being put.
object_id: string;
}
// The ActorFrontier is used to represent the current frontier of tasks that
// the local scheduler has marked as runnable for a particular actor. It is
// used to save the point in an actor's lifetime at which a checkpoint was
// taken, so that the same frontier of tasks can be made runnable again if the
// actor is resumed from that checkpoint.
table ActorFrontier {
// Actor ID of the actor whose frontier is described.
actor_id: string;
// A list of handle IDs, representing the callers of the actor that have
// submitted a runnable task to the local scheduler. A nil ID represents the
// creator of the actor.
handle_ids: [string];
// A list representing the number of tasks executed so far, per handle. Each
// count in task_counters corresponds to the handle at the same index in
// handle_ids.
task_counters: [long];
// A list representing the execution dependency for the next runnable task,
// per handle. Each execution dependency in frontier_dependencies corresponds
// to the handle at the same index in handle_ids.
frontier_dependencies: [string];
}
table GetActorFrontierRequest {
actor_id: string;
}
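From the Python side, the two new messages are exercised through the worker's local scheduler client, as in __ray_checkpoint__ and __ray_checkpoint_restore__ above. A hedged sketch of that round trip; it only makes sense inside an actor's worker process, and the frontier bytes are treated as an opaque ActorFrontier flatbuffer:

import ray
import ray.local_scheduler

worker = ray.worker.global_worker       # must be an actor's worker process
actor_id = ray.local_scheduler.ObjectID(worker.actor_id)

# Checkpoint save: ask the local scheduler for the per-handle task counts and
# execution dependencies that make up the current frontier.
frontier = worker.local_scheduler_client.get_actor_frontier(actor_id)

# Checkpoint restore: hand the same bytes back so the local scheduler can
# drop already-executed tasks and make the frontier tasks runnable again.
worker.local_scheduler_client.set_actor_frontier(frontier)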


@ -592,9 +592,7 @@ void assign_task_to_worker(LocalSchedulerState *state,
}
}
void finish_task(LocalSchedulerState *state,
LocalSchedulerClient *worker,
bool actor_checkpoint_failed) {
void finish_task(LocalSchedulerState *state, LocalSchedulerClient *worker) {
if (worker->task_in_progress != NULL) {
TaskSpec *spec = Task_task_execution_spec(worker->task_in_progress)->Spec();
/* Return dynamic resources back for the task in progress. */
@ -613,23 +611,10 @@ void finish_task(LocalSchedulerState *state,
worker->resources_in_use;
release_resources(state, worker, cpu_resources);
}
/* For successful actor tasks, mark returned dummy objects as locally
* available. This is not added to the object table, so the update will be
* invisible to other nodes. */
/* NOTE(swang): These objects are never cleaned up. We should consider
* removing the objects, e.g., when an actor is terminated. */
if (TaskSpec_is_actor_task(spec)) {
if (!actor_checkpoint_failed) {
handle_object_available(state, state->algorithm_state,
TaskSpec_actor_dummy_object(spec));
}
}
/* If we're connected to Redis, update tables. */
if (state->db != NULL) {
/* Update control state tables. If there was an error while executing a *
* checkpoint task, report the task as lost. Else, the task succeeded. */
int task_state =
actor_checkpoint_failed ? TASK_STATUS_LOST : TASK_STATUS_DONE;
/* Update control state tables. */
int task_state = TASK_STATUS_DONE;
Task_set_state(worker->task_in_progress, task_state);
#if !RAY_USE_NEW_GCS
task_table_update(state->db, worker->task_in_progress, NULL, NULL, NULL);
@ -903,10 +888,13 @@ void reconstruct_object_lookup_callback(
void reconstruct_object(LocalSchedulerState *state,
ObjectID reconstruct_object_id) {
LOG_DEBUG("Starting reconstruction");
/* TODO(swang): Track task lineage for puts. */
CHECK(state->db != NULL);
/* If the object is locally available, no need to reconstruct. */
if (object_locally_available(state->algorithm_state, reconstruct_object_id)) {
return;
}
/* Determine if reconstruction is necessary by checking if the object exists
* on a node. */
CHECK(state->db != NULL);
object_table_lookup(state->db, reconstruct_object_id, NULL,
reconstruct_object_lookup_callback, (void *) state);
}
@ -1066,6 +1054,63 @@ void handle_client_disconnect(LocalSchedulerState *state,
kill_worker(state, worker, false, worker->disconnected);
}
void handle_get_actor_frontier(LocalSchedulerState *state,
LocalSchedulerClient *worker,
ActorID actor_id) {
auto task_counters =
get_actor_task_counters(state->algorithm_state, actor_id);
auto frontier = get_actor_frontier(state->algorithm_state, actor_id);
/* Build the ActorFrontier flatbuffer. */
std::vector<ActorHandleID> handle_vector;
std::vector<int64_t> task_counter_vector;
std::vector<ObjectID> frontier_vector;
for (auto handle : task_counters) {
handle_vector.push_back(handle.first);
task_counter_vector.push_back(handle.second);
frontier_vector.push_back(frontier[handle.first]);
}
flatbuffers::FlatBufferBuilder fbb;
auto reply = CreateActorFrontier(
fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, handle_vector),
fbb.CreateVector(task_counter_vector), to_flatbuf(fbb, frontier_vector));
fbb.Finish(reply);
/* Respond with the built ActorFrontier. */
if (write_message(worker->sock, MessageType_GetActorFrontierReply,
fbb.GetSize(), (uint8_t *) fbb.GetBufferPointer()) < 0) {
if (errno == EPIPE || errno == EBADF) {
/* Something went wrong, so kill the worker. */
kill_worker(state, worker, false, false);
LOG_WARN(
"Failed to return actor frontier to worker on fd %d. The client may "
"have hung "
"up.",
worker->sock);
} else {
LOG_FATAL("Failed to give task to client on fd %d.", worker->sock);
}
}
}
void handle_set_actor_frontier(LocalSchedulerState *state,
LocalSchedulerClient *worker,
ActorFrontier const &frontier) {
/* Parse the ActorFrontier flatbuffer. */
ActorID actor_id = from_flatbuf(*frontier.actor_id());
std::unordered_map<ActorID, int64_t, UniqueIDHasher> task_counters;
std::unordered_map<ActorID, ObjectID, UniqueIDHasher> frontier_dependencies;
for (size_t i = 0; i < frontier.handle_ids()->size(); ++i) {
ActorID handle_id = from_flatbuf(*frontier.handle_ids()->Get(i));
task_counters[handle_id] = frontier.task_counters()->Get(i);
frontier_dependencies[handle_id] =
from_flatbuf(*frontier.frontier_dependencies()->Get(i));
}
/* Set the actor's frontier. */
set_actor_task_counters(state->algorithm_state, actor_id, task_counters);
set_actor_frontier(state, state->algorithm_state, actor_id,
frontier_dependencies);
}
void process_message(event_loop *loop,
int client_sock,
void *context,
@ -1113,7 +1158,7 @@ void process_message(event_loop *loop,
case MessageType_TaskDone: {
} break;
case MessageType_DisconnectClient: {
finish_task(state, worker, false);
finish_task(state, worker);
CHECK(!worker->disconnected);
worker->disconnected = true;
/* If the disconnected worker was not an actor, start a new worker to make
@ -1139,16 +1184,13 @@ void process_message(event_loop *loop,
} break;
case MessageType_GetTask: {
/* If this worker reports a completed task, account for resources. */
auto message = flatbuffers::GetRoot<GetTaskRequest>(input);
bool actor_checkpoint_failed = message->actor_checkpoint_failed();
finish_task(state, worker, actor_checkpoint_failed);
finish_task(state, worker);
/* Let the scheduling algorithm process the fact that there is an available
* worker. */
if (worker->actor_id.is_nil()) {
handle_worker_available(state, state->algorithm_state, worker);
} else {
handle_actor_worker_available(state, state->algorithm_state, worker,
actor_checkpoint_failed);
handle_actor_worker_available(state, state->algorithm_state, worker);
}
} break;
case MessageType_ReconstructObject: {
@ -1211,6 +1253,15 @@ void process_message(event_loop *loop,
result_table_add(state->db, from_flatbuf(*message->object_id()),
from_flatbuf(*message->task_id()), true, NULL, NULL, NULL);
} break;
case MessageType_GetActorFrontierRequest: {
auto message = flatbuffers::GetRoot<GetActorFrontierRequest>(input);
ActorID actor_id = from_flatbuf(*message->actor_id());
handle_get_actor_frontier(state, worker, actor_id);
} break;
case MessageType_SetActorFrontier: {
auto message = flatbuffers::GetRoot<ActorFrontier>(input);
handle_set_actor_frontier(state, worker, *message);
} break;
default:
/* This code should be unreachable. */
CHECK(0);


@ -46,13 +46,9 @@ void assign_task_to_worker(LocalSchedulerState *state,
*
* @param state The local scheduler state.
* @param worker The worker that finished the task.
* @param actor_checkpoint_failed If the last task assigned was a checkpoint
* task that failed.
* @return Void.
*/
void finish_task(LocalSchedulerState *state,
LocalSchedulerClient *worker,
bool actor_checkpoint_failed);
void finish_task(LocalSchedulerState *state, LocalSchedulerClient *worker);
/**
* This is the callback that is used to process a notification from the Plasma


@ -24,6 +24,9 @@ void give_task_to_local_scheduler(LocalSchedulerState *state,
TaskExecutionSpec &execution_spec,
DBClientID local_scheduler_id);
void clear_missing_dependencies(SchedulingAlgorithmState *algorithm_state,
std::list<TaskExecutionSpec>::iterator it);
/** A data structure used to track which objects are available locally and
* which objects are being actively fetched. Objects of this type are used for
* both the scheduling algorithm state's local_objects and remote_objects
@ -48,28 +51,18 @@ typedef struct {
* handle. This is used to guarantee execution of tasks on actors in the
* order that the tasks were submitted, per handle. Tasks from different
* handles to the same actor may be interleaved. */
std::unordered_map<ActorID, int64_t, UniqueIDHasher> task_counters;
std::unordered_map<ActorHandleID, int64_t, UniqueIDHasher> task_counters;
/** These are the execution dependencies that make up the frontier of the
* actor's runnable tasks. For each actor handle, we store the object ID
* that represents the execution dependency for the next runnable task
* submitted by that handle. */
std::unordered_map<ActorHandleID, ObjectID, UniqueIDHasher>
frontier_dependencies;
/** The return value of the most recently executed task. The next task to
* execute should take this as an execution dependency at dispatch time. Set
* to nil if there are no execution dependencies (e.g., this is the first
* task to execute). */
ObjectID execution_dependency;
/** The index of the task assigned to this actor. Set to -1 if no task is
* currently assigned. If the actor process reports back success for the
* assigned task execution, then the corresponding task_counter should be
* updated to this value. */
int64_t assigned_task_counter;
/** The handle that the currently assigned task was submitted by. This field
* is only valid if assigned_task_counter is set. If the actor process
* reports back success for the assigned task execution, then the
* task_counter corresponding to this handle should be updated. */
ActorID assigned_task_handle_id;
/** Whether the actor process has loaded yet. The actor counts as loaded once
* it has either executed its first task or successfully resumed from a
* checkpoint. Before the actor has loaded, we may dispatch the first task
* or any checkpoint tasks. After it has loaded, we may only dispatch tasks
* in order. */
bool loaded;
/** A queue of tasks to be executed on this actor. The tasks will be sorted by
* the order of their actor counters. */
std::list<TaskExecutionSpec> *task_queue;
@ -223,16 +216,14 @@ void create_actor(SchedulingAlgorithmState *algorithm_state,
ActorID actor_id,
LocalSchedulerClient *worker) {
LocalActorInfo entry;
entry.task_counters[ActorID::nil()] = 0;
entry.task_counters[ActorHandleID::nil()] = 0;
entry.frontier_dependencies[ActorHandleID::nil()] = ObjectID::nil();
/* The actor has not yet executed any tasks, so there are no execution
* dependencies for the next task to be scheduled. */
entry.execution_dependency = ObjectID::nil();
entry.assigned_task_counter = -1;
entry.assigned_task_handle_id = ActorID::nil();
entry.task_queue = new std::list<TaskExecutionSpec>();
entry.worker = worker;
entry.worker_available = false;
entry.loaded = false;
CHECK(algorithm_state->local_actor_infos.count(actor_id) == 0);
algorithm_state->local_actor_infos[actor_id] = entry;
@ -305,21 +296,11 @@ bool dispatch_actor_task(LocalSchedulerState *state,
/* Check whether we can execute the first task in the queue. */
auto task = entry.task_queue->begin();
TaskSpec *spec = task->Spec();
int64_t next_task_counter = TaskSpec_actor_counter(spec);
ActorID next_task_handle_id = TaskSpec_actor_handle_id(spec);
if (entry.loaded) {
/* Once the actor has loaded, we can only execute tasks in order of
* task_counter. */
if (next_task_counter != entry.task_counters[next_task_handle_id]) {
return false;
}
} else {
/* If the actor has not yet loaded, we can only execute the task that
* matches task_counter (the first task), or a checkpoint task. */
if (next_task_counter != entry.task_counters[next_task_handle_id]) {
/* No other task should be first in the queue. */
CHECK(TaskSpec_is_actor_checkpoint_method(spec));
}
ActorHandleID next_task_handle_id = TaskSpec_actor_handle_id(spec);
/* We can only execute tasks in order of task_counter. */
if (TaskSpec_actor_counter(spec) !=
entry.task_counters[next_task_handle_id]) {
return false;
}
/* If there are not enough resources available, we cannot assign the task. */
@ -339,13 +320,7 @@ bool dispatch_actor_task(LocalSchedulerState *state,
/* Only overwrite execution dependencies for tasks that have a
* submission-time dependency (meaning it is not the initial task). */
if (!entry.execution_dependency.is_nil()) {
/* A checkpoint resumption should be able to run at any time, so only add
* execution dependencies for non-checkpoint tasks. */
if (!TaskSpec_is_actor_checkpoint_method(spec)) {
/* All other tasks have a dependency on the task that executed most
* recently on the actor. */
ordered_execution_dependencies.push_back(entry.execution_dependency);
}
ordered_execution_dependencies.push_back(entry.execution_dependency);
}
task->SetExecutionDependencies(ordered_execution_dependencies);
@ -353,12 +328,13 @@ bool dispatch_actor_task(LocalSchedulerState *state,
* as unavailable. */
assign_task_to_worker(state, *task, entry.worker);
entry.execution_dependency = TaskSpec_actor_dummy_object(spec);
entry.assigned_task_counter = next_task_counter;
entry.assigned_task_handle_id = next_task_handle_id;
entry.worker_available = false;
/* Extend the frontier to include the assigned task. */
entry.task_counters[next_task_handle_id] += 1;
entry.frontier_dependencies[next_task_handle_id] = entry.execution_dependency;
/* Remove the task from the actor's task queue. */
entry.task_queue->erase(task);
/* If there are no more tasks in the queue, then indicate that the actor has
* no tasks. */
if (entry.task_queue->empty()) {
@ -437,7 +413,7 @@ void insert_actor_task_queue(LocalSchedulerState *state,
TaskSpec *spec = task_entry.Spec();
/* Get the local actor entry for this actor. */
ActorID actor_id = TaskSpec_actor_id(spec);
ActorID task_handle_id = TaskSpec_actor_handle_id(spec);
ActorHandleID task_handle_id = TaskSpec_actor_handle_id(spec);
int64_t task_counter = TaskSpec_actor_counter(spec);
/* Fail the task immediately; it's destined for a dead actor. */
@ -459,6 +435,12 @@ void insert_actor_task_queue(LocalSchedulerState *state,
if (entry.task_counters.count(task_handle_id) == 0) {
entry.task_counters[task_handle_id] = 0;
}
/* Extend the frontier to include the new handle. */
if (entry.frontier_dependencies.count(task_handle_id) == 0) {
CHECK(task_entry.ExecutionDependencies().size() == 1);
entry.frontier_dependencies[task_handle_id] =
task_entry.ExecutionDependencies()[0];
}
/* As a sanity check, the counter of the new task should be greater than the
* number of tasks that have executed on this actor so far (since we are
@ -638,6 +620,47 @@ void fetch_missing_dependencies(
CHECK(num_missing_dependencies > 0);
}
/**
* Clear a queued task's missing object dependencies. This is the inverse of
* fetch_missing_dependencies.
* TODO(swang): Test this function.
*
* @param algorithm_state The scheduling algorithm state.
* @param task_entry_it A reference to the task entry in the waiting queue.
* @returns Void.
*/
void clear_missing_dependencies(
SchedulingAlgorithmState *algorithm_state,
std::list<TaskExecutionSpec>::iterator task_entry_it) {
int64_t num_dependencies = task_entry_it->NumDependencies();
for (int64_t i = 0; i < num_dependencies; ++i) {
int count = task_entry_it->DependencyIdCount(i);
for (int j = 0; j < count; ++j) {
ObjectID obj_id = task_entry_it->DependencyId(i, j);
/* If this object dependency is missing, remove this task from the
* object's list of dependent tasks. */
auto entry = algorithm_state->remote_objects.find(obj_id);
if (entry != algorithm_state->remote_objects.end()) {
/* Find and remove the given task. */
auto &dependent_tasks = entry->second.dependent_tasks;
for (auto dependent_task_it = dependent_tasks.begin();
dependent_task_it != dependent_tasks.end();) {
if (*dependent_task_it == task_entry_it) {
dependent_task_it = dependent_tasks.erase(dependent_task_it);
} else {
dependent_task_it++;
}
}
/* If the missing object dependency has no more dependent tasks, then
* remove it. */
if (dependent_tasks.empty()) {
algorithm_state->remote_objects.erase(entry);
}
}
}
}
}
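The bookkeeping here is a reverse index from each missing object to the tasks waiting on it; clearing a task walks its dependencies, removes the task from each entry, and drops entries that become empty. A minimal Python analogue of that structure (hypothetical names, plain dicts and sets in place of the C++ containers):

remote_objects = {}            # missing object id -> set of waiting task ids

def add_missing_dependency(obj_id, task_id):
    remote_objects.setdefault(obj_id, set()).add(task_id)

def clear_missing_dependencies_py(task_id, dependency_ids):
    for obj_id in dependency_ids:
        waiting = remote_objects.get(obj_id)
        if waiting is None:
            continue                       # dependency was already local
        waiting.discard(task_id)
        if not waiting:                    # nobody waits on this object now
            del remote_objects[obj_id]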
/**
* Check if all of the remote object arguments for a task are available in the
* local object store.
@ -665,6 +688,11 @@ bool can_run(SchedulingAlgorithmState *algorithm_state,
return true;
}
bool object_locally_available(SchedulingAlgorithmState *algorithm_state,
ObjectID object_id) {
return algorithm_state->local_objects.count(object_id) == 1;
}
/* TODO(swang): This method is not covered by any valgrind tests. */
int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context) {
int64_t start_time = current_time_ms();
@ -943,6 +971,27 @@ void queue_waiting_task(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
TaskExecutionSpec &execution_spec,
bool from_global_scheduler) {
/* For actor tasks, do not queue tasks that have already been executed. */
auto spec = execution_spec.Spec();
if (!TaskSpec_actor_id(spec).is_nil()) {
auto entry =
algorithm_state->local_actor_infos.find(TaskSpec_actor_id(spec));
if (entry != algorithm_state->local_actor_infos.end()) {
/* Find the highest task counter with the same handle ID as the task to
* queue. */
auto &task_counters = entry->second.task_counters;
auto task_counter = task_counters.find(TaskSpec_actor_handle_id(spec));
if (task_counter != task_counters.end() &&
TaskSpec_actor_counter(spec) < task_counter->second) {
/* If the task to queue has a lower task counter, do not queue it. */
LOG_INFO(
"A task that has already been executed has been resubmitted, so we "
"are ignoring it. This should only happen during reconstruction.");
return;
}
}
}
LOG_DEBUG("Queueing task in waiting queue");
auto it = queue_task(state, algorithm_state->waiting_task_queue,
execution_spec, from_global_scheduler);
@ -1349,33 +1398,29 @@ void handle_actor_worker_disconnect(LocalSchedulerState *state,
dispatch_all_tasks(state, algorithm_state);
}
/* NOTE(swang): For tasks that saved a checkpoint, we should consider
* overwriting the result table entries for the current task frontier to
* avoid duplicate task submissions during reconstruction. */
void handle_actor_worker_available(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
LocalSchedulerClient *worker,
bool actor_checkpoint_failed) {
LocalSchedulerClient *worker) {
ActorID actor_id = worker->actor_id;
CHECK(!actor_id.is_nil());
/* Get the actor info for this worker. */
CHECK(algorithm_state->local_actor_infos.count(actor_id) == 1);
LocalActorInfo &entry =
algorithm_state->local_actor_infos.find(actor_id)->second;
CHECK(worker == entry.worker);
CHECK(!entry.worker_available);
/* If the assigned task was not a checkpoint task, or if it was but it
* loaded the checkpoint successfully, then we update the actor's counter
* to the assigned counter. */
if (!actor_checkpoint_failed) {
entry.task_counters[entry.assigned_task_handle_id] =
entry.assigned_task_counter + 1;
/* If a task was assigned to this actor and there was no checkpoint
* failure, then it is now loaded. */
if (entry.assigned_task_counter > -1) {
entry.loaded = true;
}
/* If an actor task was assigned, mark returned dummy object as locally
* available. This is not added to the object table, so the update will be
* invisible to other nodes. */
/* NOTE(swang): These objects are never cleaned up. We should consider
* removing the objects, e.g., when an actor is terminated. */
if (!entry.execution_dependency.is_nil()) {
handle_object_available(state, algorithm_state, entry.execution_dependency);
}
entry.assigned_task_counter = -1;
entry.assigned_task_handle_id = ActorID::nil();
/* Unset the fields indicating an assigned task. */
entry.worker_available = true;
/* Assign new tasks if possible. */
dispatch_all_tasks(state, algorithm_state);
@ -1611,3 +1656,82 @@ void print_worker_info(const char *message,
algorithm_state->executing_workers.size(),
algorithm_state->blocked_workers.size());
}
std::unordered_map<ActorHandleID, int64_t, UniqueIDHasher>
get_actor_task_counters(SchedulingAlgorithmState *algorithm_state,
ActorID actor_id) {
CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0);
return algorithm_state->local_actor_infos[actor_id].task_counters;
}
void set_actor_task_counters(
SchedulingAlgorithmState *algorithm_state,
ActorID actor_id,
const std::unordered_map<ActorHandleID, int64_t, UniqueIDHasher>
&task_counters) {
CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0);
/* Overwrite the current task counters for the actor. This is necessary
* during reconstruction when resuming from a checkpoint so that we can
* resume the task frontier at the time that the checkpoint was saved. */
auto &entry = algorithm_state->local_actor_infos[actor_id];
entry.task_counters = task_counters;
/* Filter out tasks for the actor that were submitted earlier than the new
* task counter. These represent tasks that executed before the actor's
* resumed checkpoint, and therefore should not be re-executed. */
for (auto it = entry.task_queue->begin(); it != entry.task_queue->end();) {
/* Filter out duplicate tasks for the actor that are runnable. */
TaskSpec *pending_task_spec = it->Spec();
ActorHandleID handle_id = TaskSpec_actor_handle_id(pending_task_spec);
auto task_counter = entry.task_counters.find(handle_id);
if (task_counter != entry.task_counters.end() &&
TaskSpec_actor_counter(pending_task_spec) < task_counter->second) {
/* If the task's counter is less than the highest count for that handle,
* then remove it from the actor's runnable queue. */
it = entry.task_queue->erase(it);
} else {
it++;
}
}
for (auto it = algorithm_state->waiting_task_queue->begin();
it != algorithm_state->waiting_task_queue->end();) {
/* Filter out duplicate tasks for the actor that are waiting on a missing
* dependency. */
TaskSpec *spec = it->Spec();
if (TaskSpec_actor_id(spec) == actor_id &&
TaskSpec_actor_counter(spec) <
entry.task_counters[TaskSpec_actor_handle_id(spec)]) {
/* If the waiting task is for the same actor and its task counter is less
* than the highest count for that handle, then clear its object
* dependencies and remove it from the queue. */
clear_missing_dependencies(algorithm_state, it);
it = algorithm_state->waiting_task_queue->erase(it);
} else {
it++;
}
}
}
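The filtering rule used in both loops (and in queue_waiting_task above) is the same per-handle comparison: a resubmitted task is stale if its counter is below the count already recorded for that handle. A small sketch of the rule with made-up handle IDs:

def is_stale(task_counters, handle_id, task_counter):
    executed = task_counters.get(handle_id)
    return executed is not None and task_counter < executed

task_counters = {b"handle-A": 7}   # 7 tasks from handle A have already run
assert is_stale(task_counters, b"handle-A", 5)           # duplicate: drop it
assert not is_stale(task_counters, b"handle-A", 7)       # next task: keep it
assert not is_stale(task_counters, b"unseen-handle", 0)  # new handle: keep it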
std::unordered_map<ActorHandleID, ObjectID, UniqueIDHasher> get_actor_frontier(
SchedulingAlgorithmState *algorithm_state,
ActorID actor_id) {
CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0);
return algorithm_state->local_actor_infos[actor_id].frontier_dependencies;
}
void set_actor_frontier(
LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
ActorID actor_id,
const std::unordered_map<ActorHandleID, ObjectID, UniqueIDHasher>
&frontier_dependencies) {
CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0);
auto &entry = algorithm_state->local_actor_infos[actor_id];
entry.frontier_dependencies = frontier_dependencies;
for (auto frontier_dependency : entry.frontier_dependencies) {
if (algorithm_state->local_objects.count(frontier_dependency.second) == 0) {
handle_object_available(state, algorithm_state,
frontier_dependency.second);
}
}
}


@ -165,14 +165,11 @@ void handle_worker_removed(LocalSchedulerState *state,
* @param state The state of the local scheduler.
* @param algorithm_state State maintained by the scheduling algorithm.
* @param worker The worker that is available.
* @param actor_checkpoint_failed If the last task assigned was a checkpoint
* task that failed.
* @return Void.
*/
void handle_actor_worker_available(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
LocalSchedulerClient *worker,
bool actor_checkpoint_failed);
LocalSchedulerClient *worker);
/**
* Handle the fact that a new worker is available for running an actor.
@ -295,6 +292,16 @@ int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context);
int reconstruct_object_timeout_handler(event_loop *loop,
timer_id id,
void *context);
/**
* Check whether an object, including actor dummy objects, is locally
* available.
*
* @param algorithm_state State maintained by the scheduling algorithm.
* @param object_id The ID of the object to check for.
* @return A bool representing whether the object is locally available.
*/
bool object_locally_available(SchedulingAlgorithmState *algorithm_state,
ObjectID object_id);
/**
* A helper function to print debug information about the current state and
@ -307,6 +314,87 @@ int reconstruct_object_timeout_handler(event_loop *loop,
void print_worker_info(const char *message,
SchedulingAlgorithmState *algorithm_state);
/*
* The actor frontier consists of the number of tasks executed so far and the
* execution dependencies required by the current runnable tasks, according to
* the actor's local scheduler. Since an actor may have multiple handles, the
* tasks submitted to the actor form a DAG, where nodes are tasks and edges are
* execution dependencies. The frontier is a cut across this DAG. The number of
* tasks so far is the number of nodes included in the DAG root's partition.
*
* The actor gets the current frontier of tasks from the local scheduler during
* a checkpoint save, so that it can save the point in the actor's lifetime at
* which the checkpoint was taken. If the actor later resumes from that
* checkpoint, the actor can set the current frontier of tasks in the local
* scheduler so that the same frontier of tasks can be made runnable again
* during reconstruction, and so that we do not duplicate execution of tasks
* that already executed before the checkpoint.
*/
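As a concrete illustration of the frontier described above, the saved state for an actor with two known handles is just three parallel lists, mirroring the ActorFrontier table (the IDs and counts below are made up):

frontier = {
    "actor_id": b"<actor id>",
    # One entry per handle the local scheduler has seen; the nil handle is
    # the actor's creator.
    "handle_ids": [b"<nil handle>", b"<handle B>"],
    # Tasks executed so far, per handle.
    "task_counters": [12, 4],
    # Execution dependency (dummy object) of the next runnable task, per
    # handle.
    "frontier_dependencies": [b"<dummy object 12>", b"<dummy object 4>"],
}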
/**
* Get the number of tasks, per actor handle, that have been executed on an
* actor so far.
*
* @param algorithm_state State maintained by the scheduling algorithm.
* @param actor_id The ID of the actor whose task counters are returned.
* @return A map from handle ID to the number of tasks submitted by that handle
* that have executed so far.
*/
std::unordered_map<ActorHandleID, int64_t, UniqueIDHasher>
get_actor_task_counters(SchedulingAlgorithmState *algorithm_state,
ActorID actor_id);
/**
* Set the number of tasks, per actor handle, that have been executed on an
* actor so far. All previous counts will be overwritten. Tasks that are
* waiting or runnable on the local scheduler that have a lower task count will
* be discarded, so that we don't duplicate execution.
*
* @param algorithm_state State maintained by the scheduling algorithm.
* @param actor_id The ID of the actor whose task counters are set.
* @param task_counters A map from handle ID to the number of tasks submitted
* by that handle that have executed so far.
* @return Void.
*/
void set_actor_task_counters(
SchedulingAlgorithmState *algorithm_state,
ActorID actor_id,
const std::unordered_map<ActorHandleID, int64_t, UniqueIDHasher>
&task_counters);
/**
* Get the actor's frontier of task dependencies.
* NOTE(swang): The returned frontier only includes handles known by the local
* scheduler. It does not include handles for which the local scheduler has not
* seen a runnable task yet.
*
* @param algorithm_state State maintained by the scheduling algorithm.
* @param actor_id The ID of the actor whose frontier is returned.
* @return A map from handle ID to execution dependency for the earliest
* runnable task submitted through that handle.
*/
std::unordered_map<ActorHandleID, ObjectID, UniqueIDHasher> get_actor_frontier(
SchedulingAlgorithmState *algorithm_state,
ActorID actor_id);
/**
* Set the actor's frontier of task dependencies. The previous frontier will be
* overwritten. Any tasks that have an execution dependency on the new frontier
* (and that have all other dependencies fulfilled) will become runnable.
*
* @param algorithm_state State maintained by the scheduling algorithm.
* @param actor_id The ID of the actor whose frontier is set.
* @param frontier_dependencies A map from handle ID to execution dependency
* for the earliest runnable task submitted through that handle.
* @return Void.
*/
void set_actor_frontier(
LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
ActorID actor_id,
const std::unordered_map<ActorHandleID, ObjectID, UniqueIDHasher>
&frontier_dependencies);
/** The following methods are for testing purposes only. */
#ifdef LOCAL_SCHEDULER_TEST
/**


@ -102,13 +102,8 @@ void local_scheduler_submit(LocalSchedulerConnection *conn,
}
TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
int64_t *task_size,
bool actor_checkpoint_failed) {
flatbuffers::FlatBufferBuilder fbb;
auto message = CreateGetTaskRequest(fbb, actor_checkpoint_failed);
fbb.Finish(message);
write_message(conn->conn, MessageType_GetTask, fbb.GetSize(),
fbb.GetBufferPointer());
int64_t *task_size) {
write_message(conn->conn, MessageType_GetTask, 0, NULL);
int64_t type;
int64_t reply_size;
uint8_t *reply;
@ -177,3 +172,29 @@ void local_scheduler_put_object(LocalSchedulerConnection *conn,
write_message(conn->conn, MessageType_PutObject, fbb.GetSize(),
fbb.GetBufferPointer());
}
const std::vector<uint8_t> local_scheduler_get_actor_frontier(
LocalSchedulerConnection *conn,
ActorID actor_id) {
flatbuffers::FlatBufferBuilder fbb;
auto message = CreateGetActorFrontierRequest(fbb, to_flatbuf(fbb, actor_id));
fbb.Finish(message);
write_message(conn->conn, MessageType_GetActorFrontierRequest, fbb.GetSize(),
fbb.GetBufferPointer());
int64_t type;
std::vector<uint8_t> reply;
read_vector(conn->conn, &type, reply);
if (type == DISCONNECT_CLIENT) {
LOG_DEBUG("Exiting because local scheduler closed connection.");
exit(1);
}
CHECK(type == MessageType_GetActorFrontierReply);
return reply;
}
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
const std::vector<uint8_t> &frontier) {
write_message(conn->conn, MessageType_SetActorFrontier, frontier.size(),
const_cast<uint8_t *>(frontier.data()));
}


@ -94,13 +94,10 @@ void local_scheduler_log_event(LocalSchedulerConnection *conn,
*
* @param conn The connection information.
* @param task_size A pointer to fill out with the task size.
* @param actor_checkpoint_failed If the last task assigned was a checkpoint
* task that failed.
* @return The address of the assigned task.
*/
TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
int64_t *task_size,
bool actor_checkpoint_failed);
int64_t *task_size);
/**
* Tell the local scheduler that the client has finished executing a task.
@ -148,4 +145,25 @@ void local_scheduler_put_object(LocalSchedulerConnection *conn,
TaskID task_id,
ObjectID object_id);
/**
* Get an actor's current task frontier.
*
* @param conn The connection information.
* @param actor_id The ID of the actor whose frontier is returned.
* @return A byte vector that can be traversed as an ActorFrontier flatbuffer.
*/
const std::vector<uint8_t> local_scheduler_get_actor_frontier(
LocalSchedulerConnection *conn,
ActorID actor_id);
/**
* Set an actor's current task frontier.
*
* @param conn The connection information.
* @param frontier An ActorFrontier flatbuffer to set the frontier to.
* @return Void.
*/
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
const std::vector<uint8_t> &frontier);
#endif


@ -63,25 +63,15 @@ static PyObject *PyLocalSchedulerClient_submit(PyObject *self, PyObject *args) {
}
// clang-format off
static PyObject *PyLocalSchedulerClient_get_task(PyObject *self, PyObject *args) {
PyObject *py_actor_checkpoint_failed = NULL;
if (!PyArg_ParseTuple(args, "|O", &py_actor_checkpoint_failed)) {
return NULL;
}
static PyObject *PyLocalSchedulerClient_get_task(PyObject *self) {
TaskSpec *task_spec;
int64_t task_size;
/* If no argument for actor_checkpoint_failed was provided, default to false,
* since we assume that there was no previous task. */
bool actor_checkpoint_failed = false;
if (py_actor_checkpoint_failed != NULL) {
actor_checkpoint_failed = (bool) PyObject_IsTrue(py_actor_checkpoint_failed);
}
/* Drop the global interpreter lock while we get a task because
* local_scheduler_get_task may block for a long time. */
Py_BEGIN_ALLOW_THREADS
task_spec = local_scheduler_get_task(
((PyLocalSchedulerClient *) self)->local_scheduler_connection,
&task_size, actor_checkpoint_failed);
&task_size);
Py_END_ALLOW_THREADS
return PyTask_make(task_spec, task_size);
}
@@ -148,12 +138,41 @@ static PyObject *PyLocalSchedulerClient_gpu_ids(PyObject *self) {
return gpu_ids_list;
}
static PyObject *PyLocalSchedulerClient_get_actor_frontier(PyObject *self,
PyObject *args) {
ActorID actor_id;
if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, &actor_id)) {
return NULL;
}
auto frontier = local_scheduler_get_actor_frontier(
((PyLocalSchedulerClient *) self)->local_scheduler_connection, actor_id);
return PyBytes_FromStringAndSize(
reinterpret_cast<const char *>(frontier.data()), frontier.size());
}
static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self,
PyObject *args) {
PyObject *py_frontier;
if (!PyArg_ParseTuple(args, "O", &py_frontier)) {
return NULL;
}
std::vector<uint8_t> frontier;
Py_ssize_t length = PyBytes_Size(py_frontier);
char *frontier_data = PyBytes_AsString(py_frontier);
frontier.assign(frontier_data, frontier_data + length);
local_scheduler_set_actor_frontier(
((PyLocalSchedulerClient *) self)->local_scheduler_connection, frontier);
Py_RETURN_NONE;
}
static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
{"submit", (PyCFunction) PyLocalSchedulerClient_submit, METH_VARARGS,
"Submit a task to the local scheduler."},
{"get_task", (PyCFunction) PyLocalSchedulerClient_get_task, METH_VARARGS,
{"get_task", (PyCFunction) PyLocalSchedulerClient_get_task, METH_NOARGS,
"Get a task from the local scheduler."},
{"reconstruct_object",
(PyCFunction) PyLocalSchedulerClient_reconstruct_object, METH_VARARGS,
@@ -166,6 +185,10 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
METH_VARARGS, "Return the object ID for a put call within a task."},
{"gpu_ids", (PyCFunction) PyLocalSchedulerClient_gpu_ids, METH_NOARGS,
"Get the IDs of the GPUs that are reserved for this client."},
{"get_actor_frontier",
(PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""},
{"set_actor_frontier",
(PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
{NULL} /* Sentinel */
};
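
Together with the declarations in the header above, these bindings expose the actor frontier to Python as plain bytes. The sketch below is a rough, non-authoritative illustration of how a worker might drive them; the attribute name local_scheduler_client and the helper function are assumptions for illustration and are not part of this diff.

import ray.local_scheduler

def frontier_roundtrip(worker, actor_id_bytes):
    # get_actor_frontier takes an actor ID (anything PyObjectToUniqueID
    # accepts, e.g. a ray.local_scheduler.ObjectID) and returns raw bytes
    # encoding an ActorFrontier flatbuffer.
    client = worker.local_scheduler_client  # assumed attribute name
    frontier = client.get_actor_frontier(
        ray.local_scheduler.ObjectID(actor_id_bytes))
    assert isinstance(frontier, bytes)
    # A checkpoint would persist these bytes next to the actor's saved state;
    # handing them back after a restore lets the local scheduler filter out
    # tasks that already ran before the checkpoint was taken.
    client.set_actor_frontier(frontier)
    return frontier

The checkpoint tests later in this diff exercise the same path indirectly through __ray_checkpoint__.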


@@ -210,12 +210,12 @@ TEST object_reconstruction_test(void) {
int64_t task_assigned_size;
local_scheduler_submit(worker, execution_spec);
TaskSpec *task_assigned =
local_scheduler_get_task(worker, &task_assigned_size, true);
local_scheduler_get_task(worker, &task_assigned_size);
ASSERT_EQ(memcmp(task_assigned, spec, task_size), 0);
ASSERT_EQ(task_assigned_size, task_size);
int64_t reconstruct_task_size;
TaskSpec *reconstruct_task =
local_scheduler_get_task(worker, &reconstruct_task_size, true);
local_scheduler_get_task(worker, &reconstruct_task_size);
ASSERT_EQ(memcmp(reconstruct_task, spec, task_size), 0);
ASSERT_EQ(reconstruct_task_size, task_size);
/* Clean up. */
@@ -278,16 +278,10 @@ TEST object_reconstruction_recursive_test(void) {
specs.push_back(example_task_execution_spec(0, 1));
for (int i = 1; i < NUM_TASKS; ++i) {
ObjectID arg_id = TaskSpec_return(specs[i - 1].Spec(), 0);
handle_object_available(
local_scheduler->local_scheduler_state,
local_scheduler->local_scheduler_state->algorithm_state, arg_id);
specs.push_back(example_task_execution_spec_with_args(1, 1, &arg_id));
}
/* Add an empty object table entry for each object we want to reconstruct, to
* simulate their having been created and evicted. */
const char *client_id = "clientid";
/* Lookup the shard locations for the object table. */
const char *client_id = "clientid";
std::vector<std::string> db_shards_addresses;
std::vector<int> db_shards_ports;
redisContext *context = redisConnect("127.0.0.1", 6379);
@@ -319,8 +313,7 @@ TEST object_reconstruction_recursive_test(void) {
/* Make sure we receive each task from the initial submission. */
for (int i = 0; i < NUM_TASKS; ++i) {
int64_t task_size;
TaskSpec *task_assigned =
local_scheduler_get_task(worker, &task_size, true);
TaskSpec *task_assigned = local_scheduler_get_task(worker, &task_size);
ASSERT_EQ(memcmp(task_assigned, specs[i].Spec(), specs[i].SpecSize()), 0);
ASSERT_EQ(task_size, specs[i].SpecSize());
free(task_assigned);
@@ -330,7 +323,7 @@ TEST object_reconstruction_recursive_test(void) {
for (int i = 0; i < NUM_TASKS; ++i) {
int64_t task_assigned_size;
TaskSpec *task_assigned =
local_scheduler_get_task(worker, &task_assigned_size, true);
local_scheduler_get_task(worker, &task_assigned_size);
for (auto it = specs.begin(); it != specs.end(); it++) {
if (memcmp(task_assigned, it->Spec(), task_assigned_size) == 0) {
specs.erase(it);
@@ -343,8 +336,17 @@ TEST object_reconstruction_recursive_test(void) {
LocalSchedulerMock_free(local_scheduler);
exit(0);
} else {
/* Run the event loop. NOTE: OSX appears to require the parent process to
* listen for events on the open file descriptors. */
/* Simulate each task putting its return values in the object store so that
* the next task can run. */
for (int i = 0; i < NUM_TASKS; ++i) {
ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
handle_object_available(
local_scheduler->local_scheduler_state,
local_scheduler->local_scheduler_state->algorithm_state, return_id);
}
/* Run the event loop. All tasks should now be dispatched. NOTE: OSX
* appears to require the parent process to listen for events on the open
* file descriptors. */
event_loop_add_timer(local_scheduler->loop, 500,
(event_loop_timer_handler) timeout_handler, NULL);
event_loop_run(local_scheduler->loop);
@@ -361,10 +363,27 @@ TEST object_reconstruction_recursive_test(void) {
&local_scheduler->local_scheduler_state->gcs_client, last_task));
Task_free(last_task);
#endif
/* Trigger reconstruction for the last object, and run the event loop
* again. */
/* Simulate eviction of the objects, so that reconstruction is required. */
for (int i = 0; i < NUM_TASKS; ++i) {
ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
handle_object_removed(local_scheduler->local_scheduler_state, return_id);
}
/* Trigger reconstruction for the last object. */
ObjectID return_id = TaskSpec_return(specs[NUM_TASKS - 1].Spec(), 0);
local_scheduler_reconstruct_object(worker, return_id);
/* Run the event loop again. All tasks should be resubmitted. */
event_loop_add_timer(local_scheduler->loop, 500,
(event_loop_timer_handler) timeout_handler, NULL);
event_loop_run(local_scheduler->loop);
/* Simulate each task putting its return values in the object store so that
* the next task can run. */
for (int i = 0; i < NUM_TASKS; ++i) {
ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
handle_object_available(
local_scheduler->local_scheduler_state,
local_scheduler->local_scheduler_state->algorithm_state, return_id);
}
/* Run the event loop again. All tasks should be dispatched again. */
event_loop_add_timer(local_scheduler->loop, 500,
(event_loop_timer_handler) timeout_handler, NULL);
event_loop_run(local_scheduler->loop);
@@ -412,7 +431,7 @@ TEST object_reconstruction_suppression_test(void) {
* object_table_add callback completes. */
int64_t task_assigned_size;
TaskSpec *task_assigned =
local_scheduler_get_task(worker, &task_assigned_size, true);
local_scheduler_get_task(worker, &task_assigned_size);
ASSERT_EQ(
memcmp(task_assigned, object_reconstruction_suppression_spec->Spec(),
object_reconstruction_suppression_spec->SpecSize()),


@@ -1291,50 +1291,53 @@ class ActorReconstruction(unittest.TestCase):
self.assertEqual(ray.get(result_id_list),
list(range(1, len(result_id_list) + 1)))
def setup_test_checkpointing(self, save_exception=False,
resume_exception=False):
def setup_counter_actor(self, test_checkpoint=False, save_exception=False,
resume_exception=False):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_workers=0, redirect_output=True)
@ray.remote(checkpoint_interval=5)
# Only set the checkpoint interval if we're testing with checkpointing.
checkpoint_interval = -1
if test_checkpoint:
checkpoint_interval = 5
@ray.remote(checkpoint_interval=checkpoint_interval)
class Counter(object):
_resume_exception = resume_exception
def __init__(self, save_exception):
self.x = 0
# The number of times that inc has been called. We won't bother
# restoring this in the checkpoint.
self.num_inc_calls = 0
self.save_exception = save_exception
self.restored = False
def local_plasma(self):
return ray.worker.global_worker.plasma_client.store_socket_name
def inc(self, *xs):
self.num_inc_calls += 1
self.x += 1
self.num_inc_calls += 1
return self.x
def get_num_inc_calls(self):
return self.num_inc_calls
def test_restore(self):
# This method will only work if __ray_restore__ has been run.
return self.y
# This method will only return True if __ray_restore__ has been
# called.
return self.restored
def __ray_save__(self):
if self.save_exception:
raise Exception("Exception raised in checkpoint save")
return self.x, -1
return self.x, self.save_exception
def __ray_restore__(self, checkpoint):
if self._resume_exception:
raise Exception("Exception raised in checkpoint resume")
self.x, val = checkpoint
self.x, self.save_exception = checkpoint
self.num_inc_calls = 0
# Test that __ray_save__ has been run.
assert val == -1
self.y = self.x
self.restored = True
local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
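
The Counter class above mixes the checkpointing hooks with test instrumentation (exception injection and call counting). Stripped down to only the hooks these tests rely on, a checkpointable actor looks roughly like the sketch below; the class name and its state are illustrative, while checkpoint_interval, __ray_save__, __ray_restore__, and the remotely callable __ray_checkpoint__ are taken from the tests in this file.

import ray

ray.init()

# Ask Ray to checkpoint the actor automatically every 5 tasks, the same
# interval these tests use.
@ray.remote(checkpoint_interval=5)
class Tally(object):  # illustrative actor, not part of this diff
    def __init__(self):
        self.total = 0

    def add(self, amount):
        self.total += amount
        return self.total

    def __ray_save__(self):
        # Return whatever state should survive a restart.
        return self.total

    def __ray_restore__(self, checkpoint):
        # Rebuild the actor's state from the most recent saved checkpoint.
        self.total = checkpoint

tally = Tally.remote()
tally.add.remote(1)
# A checkpoint can also be forced through the handle, as testRemoteCheckpoint
# does below.
ray.get(tally.__ray_checkpoint__.remote())
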
@@ -1352,7 +1355,7 @@ class ActorReconstruction(unittest.TestCase):
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
def testCheckpointing(self):
actor, ids = self.setup_test_checkpointing()
actor, ids = self.setup_counter_actor(test_checkpoint=True)
# Wait for the last task to finish running.
ray.get(ids[-1])
@@ -1362,50 +1365,78 @@ class ActorReconstruction(unittest.TestCase):
process.kill()
process.wait()
# Get all of the results. TODO(rkn): This currently doesn't work.
# results = ray.get(ids)
# self.assertEqual(results, list(range(1, 1 + len(results))))
# Check that the actor restored from a checkpoint.
self.assertTrue(ray.get(actor.test_restore.remote()))
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(actor.inc.remote())
self.assertEqual(x, 101)
# Check that the number of inc calls since actor initialization is less
# than the counter value, since the actor initialized from a
# checkpoint.
num_inc_calls = ray.get(actor.get_num_inc_calls.remote())
self.assertLess(num_inc_calls, x)
self.assertEqual(ray.get(actor.test_restore.remote()), 99)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
def testRemoteCheckpoint(self):
actor, ids = self.setup_counter_actor(test_checkpoint=True)
# The inc method should only have executed once on the new actor (for
# the one method call since the most recent checkpoint).
self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 1)
# Do a remote checkpoint call and wait for it to finish.
ray.get(actor.__ray_checkpoint__.remote())
# Kill the corresponding plasma store to get rid of the cached objects.
process = ray.services.all_processes[
ray.services.PROCESS_TYPE_PLASMA_STORE][1]
process.kill()
process.wait()
# Check that the actor restored from a checkpoint.
self.assertTrue(ray.get(actor.test_restore.remote()))
# Check that the number of inc calls since actor initialization is
# exactly zero, since there could not have been another inc call since
# the remote checkpoint.
num_inc_calls = ray.get(actor.get_num_inc_calls.remote())
self.assertEqual(num_inc_calls, 0)
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(actor.inc.remote())
self.assertEqual(x, 101)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
def testLostCheckpoint(self):
actor, ids = self.setup_test_checkpointing()
actor, ids = self.setup_counter_actor(test_checkpoint=True)
# Wait for the first fraction of tasks to finish running.
ray.get(ids[len(ids) // 10])
actor_key = b"Actor:" + actor._ray_actor_id.id()
for index in ray.actor.get_checkpoint_indices(
ray.worker.global_worker, actor._ray_actor_id.id()):
ray.worker.global_worker.redis_client.hdel(
actor_key, "checkpoint_{}".format(index))
# Kill the corresponding plasma store to get rid of the cached objects.
process = ray.services.all_processes[
ray.services.PROCESS_TYPE_PLASMA_STORE][1]
process.kill()
process.wait()
self.assertEqual(ray.get(actor.inc.remote()), 101)
# Each inc method has been reexecuted once on the new actor.
self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 101)
# Get all of the results that were previously lost. Because the
# checkpoints were lost, all methods should be reconstructed.
results = ray.get(ids)
self.assertEqual(results, list(range(1, 1 + len(results))))
# Check that the actor restored from a checkpoint.
self.assertTrue(ray.get(actor.test_restore.remote()))
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(actor.inc.remote())
self.assertEqual(x, 101)
# Check that the number of inc calls since actor initialization is less
# than the counter value, since the actor initialized from a
# checkpoint.
num_inc_calls = ray.get(actor.get_num_inc_calls.remote())
self.assertLess(num_inc_calls, x)
self.assertLess(5, num_inc_calls)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
def testCheckpointException(self):
actor, ids = self.setup_test_checkpointing(save_exception=True)
actor, ids = self.setup_counter_actor(test_checkpoint=True,
save_exception=True)
# Wait for the last task to finish running.
ray.get(ids[-1])
@@ -1415,28 +1446,27 @@ class ActorReconstruction(unittest.TestCase):
process.kill()
process.wait()
self.assertEqual(ray.get(actor.inc.remote()), 101)
# Each inc method has been reexecuted once on the new actor, since all
# checkpoint saves failed.
self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 101)
# Get all of the results that were previously lost. Because the
# checkpoints were lost, all methods should be reconstructed.
results = ray.get(ids)
self.assertEqual(results, list(range(1, 1 + len(results))))
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(actor.inc.remote())
self.assertEqual(x, 101)
# Check that the number of inc calls since actor initialization is
# equal to the counter value, since the actor did not initialize from a
# checkpoint.
num_inc_calls = ray.get(actor.get_num_inc_calls.remote())
self.assertEqual(num_inc_calls, x)
# Check that errors were raised when trying to save the checkpoint.
errors = ray.error_info()
# We submitted 101 tasks with a checkpoint interval of 5.
num_checkpoints = 101 // 5
# Each checkpoint task throws an exception when saving during initial
# execution, and then again during re-execution.
self.assertEqual(len([error for error in errors if error[b"type"] ==
b"task"]), num_checkpoints * 2)
self.assertLess(0, len(errors))
for error in errors:
self.assertEqual(error[b"type"], b"checkpoint")
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
def testCheckpointResumeException(self):
actor, ids = self.setup_test_checkpointing(resume_exception=True)
actor, ids = self.setup_counter_actor(test_checkpoint=True,
resume_exception=True)
# Wait for the last task to finish running.
ray.get(ids[-1])
@@ -1446,166 +1476,42 @@ class ActorReconstruction(unittest.TestCase):
process.kill()
process.wait()
self.assertEqual(ray.get(actor.inc.remote()), 101)
# Each inc method has been reexecuted once on the new actor, since all
# checkpoint resumes failed.
self.assertEqual(ray.get(actor.get_num_inc_calls.remote()), 101)
# Get all of the results that were previously lost. Because the
# checkpoints were lost, all methods should be reconstructed.
results = ray.get(ids)
self.assertEqual(results, list(range(1, 1 + len(results))))
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(actor.inc.remote())
self.assertEqual(x, 101)
# Check that the number of inc calls since actor initialization is
# equal to the counter value, since the actor did not initialize from a
# checkpoint.
num_inc_calls = ray.get(actor.get_num_inc_calls.remote())
self.assertEqual(num_inc_calls, x)
# Check that an error was raised when trying to resume from the
# checkpoint.
errors = ray.error_info()
# The most recently executed checkpoint task should throw an exception
# when trying to resume. All other checkpoint tasks should reconstruct
# the previous task but throw no errors.
self.assertTrue(len([error for error in errors if error[b"type"] ==
b"task"]) > 0)
self.assertEqual(len(errors), 1)
for error in errors:
self.assertEqual(error[b"type"], b"checkpoint")
class DistributedActorHandles(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
def make_counter_actor(self, checkpoint_interval=-1):
ray.init()
@ray.remote(checkpoint_interval=checkpoint_interval)
class Counter(object):
def __init__(self):
self.value = 0
def increase(self):
self.value += 1
return self.value
return Counter.remote()
def testFork(self):
counter = self.make_counter_actor()
num_calls = 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
@ray.remote
def fork(counter):
return ray.get(counter.increase.remote())
# Fork once.
num_calls += 1
self.assertEqual(ray.get(fork.remote(counter)), num_calls)
num_calls += 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
# Fork num_iters times.
num_iters = 100
num_calls += num_iters
ray.get([fork.remote(counter) for _ in range(num_iters)])
num_calls += 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
def testForkConsistency(self):
counter = self.make_counter_actor()
@unittest.skip("Fork/join consistency not yet implemented.")
def testDistributedHandle(self):
counter, ids = self.setup_counter_actor(test_checkpoint=False)
@ray.remote
def fork_many_incs(counter, num_incs):
x = None
for _ in range(num_incs):
x = counter.increase.remote()
x = counter.inc.remote()
# Only call ray.get() on the last task submitted.
return ray.get(x)
num_incs = 100
# Fork once.
num_calls = num_incs
self.assertEqual(ray.get(fork_many_incs.remote(counter, num_incs)),
num_calls)
num_calls += 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
# Fork num_iters times.
count = ray.get(ids[-1])
num_incs = 100
num_iters = 10
num_calls += num_iters * num_incs
ray.get([fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)])
# Check that we ensured per-handle serialization.
num_calls += 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
@unittest.skip("Garbage collection for distributed actor handles not "
"implemented.")
def testGarbageCollection(self):
counter = self.make_counter_actor()
@ray.remote
def fork(counter):
for _ in range(10):
x = counter.increase.remote()
time.sleep(0.1)
return ray.get(x)
x = fork.remote(counter)
ray.get(counter.increase.remote())
del counter
print(ray.get(x))
def testCheckpoint(self):
counter = self.make_counter_actor(checkpoint_interval=1)
num_calls = 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
@ray.remote
def fork(counter):
return ray.get(counter.increase.remote())
# Passing an actor handle with checkpointing enabled shouldn't be
# allowed yet.
with self.assertRaises(Exception):
fork.remote(counter)
num_calls += 1
self.assertEqual(ray.get(counter.increase.remote()), num_calls)
@unittest.skip("Fork/join consistency not yet implemented.")
def testLocalSchedulerDying(self):
ray.worker._init(start_ray_local=True, num_local_schedulers=2,
num_workers=0, redirect_output=False)
@ray.remote
class Counter(object):
def __init__(self):
self.x = 0
def local_plasma(self):
return ray.worker.global_worker.plasma_client.store_socket_name
def inc(self):
self.x += 1
return self.x
@ray.remote
def foo(counter):
for _ in range(100):
x = counter.inc.remote()
return ray.get(x)
local_plasma = ray.worker.global_worker.plasma_client.store_socket_name
# Create an actor that is not on the local scheduler.
actor = Counter.remote()
while ray.get(actor.local_plasma.remote()) == local_plasma:
actor = Counter.remote()
# Concurrently, submit many tasks to the actor through the original
# handle and the forked handle.
x = foo.remote(actor)
ids = [actor.inc.remote() for _ in range(100)]
# Wait for the last task to finish running.
ray.get(ids[-1])
y = ray.get(x)
forks = [fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)]
ray.wait(forks, num_returns=len(forks))
count += num_incs * num_iters
# Kill the second plasma store to get rid of the cached objects and
# trigger the corresponding local scheduler to exit.
@@ -1614,37 +1520,90 @@ class DistributedActorHandles(unittest.TestCase):
process.kill()
process.wait()
# Submit a new task. Its results should reflect the tasks submitted
# through both the original handle and the forked handle.
self.assertEqual(ray.get(actor.inc.remote()), y + 1)
# Check that the actor did not restore from a checkpoint.
self.assertFalse(ray.get(counter.test_restore.remote()))
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(counter.inc.remote())
self.assertEqual(x, count + 1)
def testCallingPutOnActorHandle(self):
ray.worker.init(num_workers=1)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
"Hanging with new GCS API.")
def testRemoteCheckpointDistributedHandle(self):
counter, ids = self.setup_counter_actor(test_checkpoint=True)
@ray.remote
class Counter(object):
pass
def fork_many_incs(counter, num_incs):
x = None
for _ in range(num_incs):
x = counter.inc.remote()
# Only call ray.get() on the last task submitted.
return ray.get(x)
# Fork num_iters times.
count = ray.get(ids[-1])
num_incs = 100
num_iters = 10
forks = [fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)]
ray.wait(forks, num_returns=len(forks))
ray.wait([counter.__ray_checkpoint__.remote()])
count += num_incs * num_iters
# Kill the second plasma store to get rid of the cached objects and
# trigger the corresponding local scheduler to exit.
process = ray.services.all_processes[
ray.services.PROCESS_TYPE_PLASMA_STORE][1]
process.kill()
process.wait()
# Check that the actor restored from a checkpoint.
self.assertTrue(ray.get(counter.test_restore.remote()))
# Check that the number of inc calls since actor initialization is
# exactly zero, since there could not have been another inc call since
# the remote checkpoint.
num_inc_calls = ray.get(counter.get_num_inc_calls.remote())
self.assertEqual(num_inc_calls, 0)
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(counter.inc.remote())
self.assertEqual(x, count + 1)
@unittest.skip("Fork/join consistency not yet implemented.")
def testCheckpointDistributedHandle(self):
counter, ids = self.setup_counter_actor(test_checkpoint=True)
@ray.remote
def f():
return Counter.remote()
def fork_many_incs(counter, num_incs):
x = None
for _ in range(num_incs):
x = counter.inc.remote()
# Only call ray.get() on the last task submitted.
return ray.get(x)
@ray.remote
def g():
return [Counter.remote()]
# Fork num_iters times.
count = ray.get(ids[-1])
num_incs = 100
num_iters = 10
forks = [fork_many_incs.remote(counter, num_incs) for _ in
range(num_iters)]
ray.wait(forks, num_returns=len(forks))
count += num_incs * num_iters
with self.assertRaises(Exception):
ray.put(Counter.remote())
# Kill the second plasma store to get rid of the cached objects and
# trigger the corresponding local scheduler to exit.
process = ray.services.all_processes[
ray.services.PROCESS_TYPE_PLASMA_STORE][1]
process.kill()
process.wait()
with self.assertRaises(Exception):
ray.get(f.remote())
# The below test is commented out because it currently does not behave
# properly. The call to g.remote() does not raise an exception because
# even though the actor handle cannot be pickled, pyarrow attempts to
# serialize it as a dictionary of its fields which kind of works.
# self.assertRaises(Exception):
# ray.get(g.remote())
# Check that the actor restored from a checkpoint.
self.assertTrue(ray.get(counter.test_restore.remote()))
# Check that we can submit another call on the actor and get the
# correct counter result.
x = ray.get(counter.inc.remote())
self.assertEqual(x, count + 1)
def _testNondeterministicReconstruction(self, num_forks,
num_items_per_fork,
@@ -1730,6 +1689,109 @@ class DistributedActorHandles(unittest.TestCase):
self._testNondeterministicReconstruction(10, 100, 1)
class DistributedActorHandles(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
def setup_queue_actor(self):
ray.init()
@ray.remote
class Queue(object):
def __init__(self):
self.queue = []
def enqueue(self, key, item):
self.queue.append((key, item))
def read(self):
return self.queue
return Queue.remote()
def testFork(self):
queue = self.setup_queue_actor()
@ray.remote
def fork(queue, key, item):
return ray.get(queue.enqueue.remote(key, item))
# Fork num_iters times.
num_iters = 100
ray.get([fork.remote(queue, i, 0) for i in range(num_iters)])
items = ray.get(queue.read.remote())
for i in range(num_iters):
filtered_items = [item[1] for item in items if item[0] == i]
self.assertEqual(filtered_items, list(range(1)))
def testForkConsistency(self):
queue = self.setup_queue_actor()
@ray.remote
def fork(queue, key, num_items):
x = None
for item in range(num_items):
x = queue.enqueue.remote(key, item)
return ray.get(x)
# Fork num_iters times.
num_forks = 10
num_items_per_fork = 100
ray.get([fork.remote(queue, i, num_items_per_fork) for i in
range(num_forks)])
items = ray.get(queue.read.remote())
for i in range(num_forks):
filtered_items = [item[1] for item in items if item[0] == i]
self.assertEqual(filtered_items, list(range(num_items_per_fork)))
@unittest.skip("Garbage collection for distributed actor handles not "
"implemented.")
def testGarbageCollection(self):
queue = self.setup_queue_actor()
@ray.remote
def fork(queue):
for i in range(10):
x = queue.enqueue.remote(0, i)
time.sleep(0.1)
return ray.get(x)
x = fork.remote(queue)
ray.get(queue.read.remote())
del queue
print(ray.get(x))
def testCallingPutOnActorHandle(self):
ray.worker.init(num_workers=1)
@ray.remote
class Counter(object):
pass
@ray.remote
def f():
return Counter.remote()
@ray.remote
def g():
return [Counter.remote()]
with self.assertRaises(Exception):
ray.put(Counter.remote())
with self.assertRaises(Exception):
ray.get(f.remote())
# The below test is commented out because it currently does not behave
# properly. The call to g.remote() does not raise an exception because
# even though the actor handle cannot be pickled, pyarrow attempts to
# serialize it as a dictionary of its fields which kind of works.
# self.assertRaises(Exception):
# ray.get(g.remote())
@unittest.skip("Actor placement currently does not use custom resources.")
class ActorPlacement(unittest.TestCase):