Retry application-level errors (#18176)

* Retry application-level errors

* Retry application-level errors

* Push retry message to the driver
Jiajun Yao 2021-09-01 10:53:06 -07:00 committed by GitHub
parent 673bf35c1f
commit fbb3ac6a86
26 changed files with 296 additions and 59 deletions
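
For context, the user-facing surface of this change is a new retry_exceptions option on remote functions: when enabled, an ordinary exception raised inside a task is retried up to max_retries times before being surfaced to the caller. A minimal usage sketch, based on the Python API exercised by the tests in this PR (retries of exceptions stay off by default, per DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False):

import ray

ray.init()

@ray.remote(max_retries=3, retry_exceptions=True)
def sometimes_fails(should_fail):
    # An ordinary Python exception is an "application-level error";
    # with retry_exceptions=True the task is resubmitted up to
    # max_retries times before the error reaches the caller.
    if should_fail:
        raise ValueError("transient failure")
    return "ok"

print(ray.get(sometimes_fails.remote(False)))  # "ok"

# retry_exceptions can also be overridden per call via .options():
try:
    ray.get(sometimes_fails.options(retry_exceptions=False).remote(True))
except ValueError:
    print("not retried; exception surfaced immediately")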


@ -42,8 +42,8 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation,
invocation.args, options, &return_ids);
} else {
core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, options,
&return_ids, 1, std::make_pair(PlacementGroupID::Nil(), -1),
true, "");
&return_ids, 1, false,
std::make_pair(PlacementGroupID::Nil(), -1), true, "");
}
return return_ids[0];
}


@ -126,7 +126,8 @@ Status TaskExecutor::ExecuteTask(
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<ray::RayObject>> *results,
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes) {
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error) {
RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
auto function_descriptor = ray_function.GetFunctionDescriptor();


@ -79,7 +79,8 @@ class TaskExecutor {
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<ray::RayObject>> *results,
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes);
std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error);
virtual ~TaskExecutor(){};


@ -406,7 +406,10 @@ cdef execute_task(
const c_vector[CObjectID] &c_arg_reference_ids,
const c_vector[CObjectID] &c_return_ids,
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns):
c_vector[shared_ptr[CRayObject]] *returns,
c_bool *is_application_level_error):
is_application_level_error[0] = False
worker = ray.worker.global_worker
manager = worker.function_actor_manager
@ -579,6 +582,9 @@ cdef execute_task(
except KeyboardInterrupt as e:
raise TaskCancelledError(
core_worker.get_current_task_id())
except Exception as e:
is_application_level_error[0] = True
raise e
if c_return_ids.size() == 1:
outputs = (outputs,)
# Check for a cancellation that was called when the function
@ -656,7 +662,8 @@ cdef CRayStatus task_execution_handler(
const c_vector[CObjectID] &c_return_ids,
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns,
shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil:
shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
c_bool *is_application_level_error) nogil:
with gil, disable_client_hook():
try:
try:
@ -664,7 +671,8 @@ cdef CRayStatus task_execution_handler(
# it does, that indicates that there was an internal error.
execute_task(task_type, task_name, ray_function, c_resources,
c_args, c_arg_reference_ids, c_return_ids,
debugger_breakpoint, returns)
debugger_breakpoint, returns,
is_application_level_error)
except Exception as e:
sys_exit = SystemExit()
if isinstance(e, RayActorError) and \
@ -1318,6 +1326,7 @@ cdef class CoreWorker:
int num_returns,
resources,
int max_retries,
c_bool retry_exceptions,
PlacementGroupID placement_group_id,
int64_t placement_group_bundle_index,
c_bool placement_group_capture_child_tasks,
@ -1354,7 +1363,7 @@ cdef class CoreWorker:
b"",
c_serialized_runtime_env,
c_override_environment_variables),
&return_ids, max_retries,
&return_ids, max_retries, retry_exceptions,
c_pair[CPlacementGroupID, int64_t](
c_placement_group_id, placement_group_bundle_index),
placement_group_capture_child_tasks,
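
On the worker side, the Cython changes above do not attempt to classify errors: execute_task initializes is_application_level_error to False and flips it only when an exception escapes user code (cancellation stays a separate, non-retryable path). A rough Python sketch of that control flow, with illustrative names rather than the actual API:

def run_user_code(fn, args, reply):
    # Flag starts cleared; it is only set when user code itself raises.
    reply.is_application_level_error = False
    try:
        return fn(*args)
    except KeyboardInterrupt:
        # Cancellation is handled separately and is not retried here.
        raise
    except Exception:
        # Anything else escaping user code counts as an application-level
        # error; record it on the reply and re-raise so the usual error
        # objects are still stored for the caller.
        reply.is_application_level_error = True
        raise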


@ -77,7 +77,8 @@ def java_function(class_name, function_name):
None, # accelerator_type,
None, # num_returns,
None, # max_calls,
None, # max_retries
None, # max_retries,
None, # retry_exceptions,
None) # runtime_env


@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const c_vector[unique_ptr[CTaskArg]] &args,
const CTaskOptions &options, c_vector[CObjectID] *return_ids,
int max_retries,
c_bool retry_exceptions,
c_pair[CPlacementGroupID, int64_t] placement_options,
c_bool placement_group_capture_child_tasks,
c_string debugger_breakpoint)
@ -280,7 +281,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns,
shared_ptr[LocalMemoryBuffer]
&creation_task_exception_pb_bytes) nogil
&creation_task_exception_pb_bytes,
c_bool *is_application_level_error) nogil
) task_execution_callback
(void(const CWorkerID &) nogil) on_worker_shutdown
(CRayStatus() nogil) check_signals


@ -24,6 +24,7 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
# Normal tasks may be retried on failure this many times.
# TODO(swang): Allow this to be set globally for an application.
DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3
DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False
logger = logging.getLogger(__name__)
@ -53,6 +54,9 @@ class RemoteFunction:
of this remote function.
_max_calls: The number of times a worker can execute this function
before exiting.
_max_retries: The number of times this task may be retried
on worker failure.
_retry_exceptions: Whether application-level errors should be retried.
_runtime_env: The runtime environment for this task.
_decorator: An optional decorator that should be applied to the remote
function invocation (as opposed to the function execution) before
@ -73,7 +77,7 @@ class RemoteFunction:
def __init__(self, language, function, function_descriptor, num_cpus,
num_gpus, memory, object_store_memory, resources,
accelerator_type, num_returns, max_calls, max_retries,
runtime_env):
retry_exceptions, runtime_env):
if inspect.iscoroutinefunction(function):
raise ValueError("'async def' should not be used for remote "
"tasks. You can wrap the async function with "
@ -100,6 +104,9 @@ class RemoteFunction:
if max_calls is None else max_calls)
self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
if max_retries is None else max_retries)
self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS
if retry_exceptions is None else
retry_exceptions)
self._runtime_env = runtime_env
self._decorator = getattr(function, "__ray_invocation_decorator__",
None)
@ -131,6 +138,7 @@ class RemoteFunction:
accelerator_type=None,
resources=None,
max_retries=None,
retry_exceptions=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
@ -168,6 +176,7 @@ class RemoteFunction:
accelerator_type=accelerator_type,
resources=resources,
max_retries=max_retries,
retry_exceptions=retry_exceptions,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(
@ -191,6 +200,7 @@ class RemoteFunction:
accelerator_type=None,
resources=None,
max_retries=None,
retry_exceptions=None,
placement_group="default",
placement_group_bundle_index=-1,
placement_group_capture_child_tasks=None,
@ -211,6 +221,7 @@ class RemoteFunction:
accelerator_type=accelerator_type,
resources=resources,
max_retries=max_retries,
retry_exceptions=retry_exceptions,
placement_group=placement_group,
placement_group_bundle_index=placement_group_bundle_index,
placement_group_capture_child_tasks=(
@ -251,6 +262,8 @@ class RemoteFunction:
num_returns = self._num_returns
if max_retries is None:
max_retries = self._max_retries
if retry_exceptions is None:
retry_exceptions = self._retry_exceptions
if placement_group_capture_child_tasks is None:
placement_group_capture_child_tasks = (
@ -307,6 +320,7 @@ class RemoteFunction:
num_returns,
resources,
max_retries,
retry_exceptions,
placement_group.id,
placement_group_bundle_index,
placement_group_capture_child_tasks,


@ -10,6 +10,68 @@ from ray._private.test_utils import (init_error_pubsub, get_error_message,
run_string_as_driver)
def test_retry_system_level_error(ray_start_regular):
@ray.remote
class Counter:
def __init__(self):
self.value = 0
def increment(self):
self.value += 1
return self.value
@ray.remote(max_retries=1)
def func(counter):
count = counter.increment.remote()
if ray.get(count) == 1:
import os
os._exit(0)
else:
return 1
counter1 = Counter.remote()
r1 = func.remote(counter1)
assert ray.get(r1) == 1
counter2 = Counter.remote()
r2 = func.options(max_retries=0).remote(counter2)
with pytest.raises(ray.exceptions.WorkerCrashedError):
ray.get(r2)
def test_retry_application_level_error(ray_start_regular):
@ray.remote
class Counter:
def __init__(self):
self.value = 0
def increment(self):
self.value += 1
return self.value
@ray.remote(max_retries=1, retry_exceptions=True)
def func(counter):
count = counter.increment.remote()
if ray.get(count) == 1:
raise ValueError()
else:
return 2
counter1 = Counter.remote()
r1 = func.remote(counter1)
assert ray.get(r1) == 2
counter2 = Counter.remote()
r2 = func.options(max_retries=0).remote(counter2)
with pytest.raises(ValueError):
ray.get(r2)
counter3 = Counter.remote()
r3 = func.options(retry_exceptions=False).remote(counter3)
with pytest.raises(ValueError):
ray.get(r3)
def test_connect_with_disconnected_node(shutdown_only):
config = {
"num_heartbeats_timeout": 50,


@ -22,6 +22,7 @@ options = {
"max_retries": (int, lambda x: x >= -1,
"The keyword 'max_retries' only accepts 0, -1 "
"or a positive integer"),
"retry_exceptions": (),
"max_concurrency": (),
"name": (),
"namespace": (),


@ -651,7 +651,7 @@ def init(
a raylet, a plasma store, a plasma manager, and some workers.
It will also kill these processes when Python exits. If the driver
is running on a node in a Ray cluster, using `auto` as the value
tells the driver to detect the the cluster, removing the need to
tells the driver to detect the cluster, removing the need to
specify a specific node address. If the environment variable
`RAY_ADDRESS` is defined and the address is None or "auto", Ray
will set `address` to `RAY_ADDRESS`.
@ -1924,7 +1924,8 @@ def make_decorator(num_returns=None,
max_restarts=None,
max_task_retries=None,
runtime_env=None,
worker=None):
worker=None,
retry_exceptions=None):
def decorator(function_or_class):
if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)):
@ -1953,12 +1954,19 @@ def make_decorator(num_returns=None,
return ray.remote_function.RemoteFunction(
Language.PYTHON, function_or_class, None, num_cpus, num_gpus,
memory, object_store_memory, resources, accelerator_type,
num_returns, max_calls, max_retries, runtime_env)
num_returns, max_calls, max_retries, retry_exceptions,
runtime_env)
if inspect.isclass(function_or_class):
if num_returns is not None:
raise TypeError("The keyword 'num_returns' is not "
"allowed for actors.")
if max_retries is not None:
raise TypeError("The keyword 'max_retries' is not "
"allowed for actors.")
if retry_exceptions is not None:
raise TypeError("The keyword 'retry_exceptions' is not "
"allowed for actors.")
if max_calls is not None:
raise TypeError("The keyword 'max_calls' is not "
"allowed for actors.")
@ -2082,6 +2090,9 @@ def remote(*args, **kwargs):
this actor or task and its children. See
:ref:`runtime-environments` for detailed documentation. This API is
in beta and may change before becoming stable.
retry_exceptions (bool): Only for *remote functions*. This specifies
whether application-level errors should be retried
up to max_retries times.
override_environment_variables (Dict[str, str]): (Deprecated in Ray
1.4.0, will be removed in Ray 1.6--please use the ``env_vars``
field of :ref:`runtime-environments` instead.) This specifies
@ -2102,7 +2113,7 @@ def remote(*args, **kwargs):
valid_kwargs = [
"num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory",
"resources", "accelerator_type", "max_calls", "max_restarts",
"max_task_retries", "max_retries", "runtime_env"
"max_task_retries", "max_retries", "runtime_env", "retry_exceptions"
]
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
@ -2135,6 +2146,7 @@ def remote(*args, **kwargs):
object_store_memory = kwargs.get("object_store_memory")
max_retries = kwargs.get("max_retries")
runtime_env = kwargs.get("runtime_env")
retry_exceptions = kwargs.get("retry_exceptions")
return make_decorator(
num_returns=num_returns,
@ -2149,4 +2161,5 @@ def remote(*args, **kwargs):
max_task_retries=max_task_retries,
max_retries=max_retries,
runtime_env=runtime_env,
worker=worker)
worker=worker,
retry_exceptions=retry_exceptions)


@ -134,6 +134,12 @@ class TaskSpecBuilder {
return *this;
}
TaskSpecBuilder &SetNormalTaskSpec(int max_retries, bool retry_exceptions) {
message_->set_max_retries(max_retries);
message_->set_retry_exceptions(retry_exceptions);
return *this;
}
/// Set the driver attributes of the task spec.
/// See `common.proto` for meaning of the arguments.
///


@ -385,9 +385,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
// Initialize task receivers.
if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) {
RAY_CHECK(options_.task_execution_callback != nullptr);
auto execute_task =
std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3,
std::placeholders::_4, std::placeholders::_5);
direct_task_receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>(
worker_context_, task_execution_service_, execute_task,
[this] { return local_raylet_client_->TaskDone(); });
@ -565,6 +565,10 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
},
"CoreWorker.ReconstructObject");
};
auto push_error_callback = [this](const JobID &job_id, const std::string &type,
const std::string &error_message, double timestamp) {
return PushError(job_id, type, error_message, timestamp);
};
task_manager_.reset(new TaskManager(
memory_store_, reference_counter_,
/* retry_task_callback= */
@ -589,7 +593,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
}
}
},
check_node_alive_fn, reconstruct_object_callback));
check_node_alive_fn, reconstruct_object_callback, push_error_callback));
// Create an entry for the driver task in the task table. This task is
// added immediately with status RUNNING. This allows us to push errors
@ -1643,7 +1647,7 @@ void CoreWorker::SubmitTask(const RayFunction &function,
const std::vector<std::unique_ptr<TaskArg>> &args,
const TaskOptions &task_options,
std::vector<ObjectID> *return_ids, int max_retries,
BundleID placement_options,
bool retry_exceptions, BundleID placement_options,
bool placement_group_capture_child_tasks,
const std::string &debugger_breakpoint) {
TaskSpecBuilder builder;
@ -1674,6 +1678,7 @@ void CoreWorker::SubmitTask(const RayFunction &function,
placement_options, placement_group_capture_child_tasks,
debugger_breakpoint, task_options.serialized_runtime_env,
override_environment_variables);
builder.SetNormalTaskSpec(max_retries, retry_exceptions);
TaskSpecification task_spec = builder.Build();
RAY_LOG(DEBUG) << "Submit task " << task_spec.DebugString();
if (options_.is_local_mode) {
@ -2146,7 +2151,8 @@ Status CoreWorker::AllocateReturnObject(const ObjectID &object_id,
Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
const std::shared_ptr<ResourceMappingType> &resource_ids,
std::vector<std::shared_ptr<RayObject>> *return_objects,
ReferenceCounter::ReferenceTableProto *borrowed_refs) {
ReferenceCounter::ReferenceTableProto *borrowed_refs,
bool *is_application_level_error) {
RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString();
task_queue_length_ -= 1;
num_executed_tasks_ += 1;
@ -2212,7 +2218,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
task_type, task_spec.GetName(), func,
task_spec.GetRequiredResources().GetResourceMap(), args, arg_reference_ids,
return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
creation_task_exception_pb_bytes);
creation_task_exception_pb_bytes, is_application_level_error);
// Get the reference counts for any IDs that we borrowed during this task and
// return them to the caller. This will notify the caller of any IDs that we
@ -2301,7 +2307,9 @@ void CoreWorker::ExecuteTaskLocalMode(const TaskSpecification &task_spec,
}
auto old_id = GetActorId();
SetActorId(actor_id);
RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs));
bool is_application_level_error;
RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs,
&is_application_level_error));
SetActorId(old_id);
}


@ -70,7 +70,8 @@ struct CoreWorkerOptions {
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<RayObject>> *results,
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes)>;
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
bool *is_application_level_error)>;
CoreWorkerOptions()
: store_socket(""),
@ -712,7 +713,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
void SubmitTask(const RayFunction &function,
const std::vector<std::unique_ptr<TaskArg>> &args,
const TaskOptions &task_options, std::vector<ObjectID> *return_ids,
int max_retries, BundleID placement_options,
int max_retries, bool retry_exceptions, BundleID placement_options,
bool placement_group_capture_child_tasks,
const std::string &debugger_breakpoint);
@ -1122,7 +1123,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
Status ExecuteTask(const TaskSpecification &task_spec,
const std::shared_ptr<ResourceMappingType> &resource_ids,
std::vector<std::shared_ptr<RayObject>> *return_objects,
ReferenceCounter::ReferenceTableProto *borrowed_refs);
ReferenceCounter::ReferenceTableProto *borrowed_refs,
bool *is_application_level_error);
/// Execute a local mode task (runs normal ExecuteTask)
///


@ -100,7 +100,11 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
std::vector<std::shared_ptr<RayObject>> *results,
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb) {
std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb,
bool *is_application_level_error) {
// TODO(jjyao): Support retrying application-level errors for Java
*is_application_level_error = false;
JNIEnv *env = GetJNIEnv();
RAY_CHECK(java_task_executor);


@ -308,6 +308,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub
CoreWorkerProcess::GetCoreWorker().SubmitTask(
ray_function, task_args, task_options, &return_ids,
/*max_retries=*/0,
/*retry_exceptions=*/false,
/*placement_options=*/placement_group_options,
/*placement_group_capture_child_tasks=*/true,
/*debugger_breakpoint*/ "");


@ -266,7 +266,7 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
it->second.pending = false;
num_pending_tasks_--;
// A finished task can be only be re-executed if it has some number of
// A finished task can only be re-executed if it has some number of
// retries left and returned at least one object that is still in use and
// stored in plasma.
bool task_retryable = it->second.num_retries_left != 0 &&
@ -284,6 +284,45 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
ShutdownIfNeeded();
}
bool TaskManager::RetryTaskIfPossible(const TaskID &task_id) {
int num_retries_left = 0;
TaskSpecification spec;
{
absl::MutexLock lock(&mu_);
auto it = submissible_tasks_.find(task_id);
RAY_CHECK(it != submissible_tasks_.end())
<< "Tried to retry task that was not pending " << task_id;
RAY_CHECK(it->second.pending)
<< "Tried to retry task that was not pending " << task_id;
spec = it->second.spec;
num_retries_left = it->second.num_retries_left;
if (num_retries_left > 0) {
it->second.num_retries_left--;
} else {
RAY_CHECK(num_retries_left == 0 || num_retries_left == -1);
}
}
// We should not hold the lock during these calls because they may trigger
// callbacks in this or other classes.
if (num_retries_left != 0) {
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
std::ostringstream stream;
auto num_retries_left_str =
num_retries_left == -1 ? "infinite" : std::to_string(num_retries_left);
stream << num_retries_left_str << " retries left for task " << spec.TaskId()
<< ", attempting to resubmit.";
RAY_CHECK_OK(
push_error_callback_(spec.JobId(), "retry_task", stream.str(), timestamp));
retry_task_callback_(spec, /*delay=*/true);
return true;
} else {
return false;
}
}
bool TaskManager::PendingTaskFailed(
const TaskID &task_id, rpc::ErrorType error_type, Status *status,
const std::shared_ptr<rpc::RayException> &creation_task_exception,
@ -292,9 +331,9 @@ bool TaskManager::PendingTaskFailed(
// loudly with ERROR here.
RAY_LOG(DEBUG) << "Task " << task_id << " failed with error "
<< rpc::ErrorType_Name(error_type);
int num_retries_left = 0;
const bool will_retry = RetryTaskIfPossible(task_id);
const bool release_lineage = !will_retry;
TaskSpecification spec;
bool release_lineage = true;
{
absl::MutexLock lock(&mu_);
auto it = submissible_tasks_.find(task_id);
@ -303,30 +342,13 @@ bool TaskManager::PendingTaskFailed(
RAY_CHECK(it->second.pending)
<< "Tried to complete task that was not pending " << task_id;
spec = it->second.spec;
num_retries_left = it->second.num_retries_left;
if (num_retries_left == 0) {
if (!will_retry) {
submissible_tasks_.erase(it);
num_pending_tasks_--;
} else if (num_retries_left == -1) {
release_lineage = false;
} else {
RAY_CHECK(num_retries_left > 0);
it->second.num_retries_left--;
release_lineage = false;
}
}
bool will_retry = false;
// We should not hold the lock during these calls because they may trigger
// callbacks in this or other classes.
if (num_retries_left != 0) {
auto retries_str =
num_retries_left == -1 ? "infinite" : std::to_string(num_retries_left);
RAY_LOG(INFO) << retries_str << " retries left for task " << spec.TaskId()
<< ", attempting to resubmit.";
retry_task_callback_(spec, /*delay=*/true);
will_retry = true;
} else {
if (!will_retry) {
// Throttled logging of task failure errors.
{
absl::MutexLock lock(&mu_);
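
The retry bookkeeping formerly inlined in PendingTaskFailed now lives in RetryTaskIfPossible, so system-level failures and application-level errors share one path. A rough Python sketch of the counting rule it implements (num_retries_left == -1 means retry forever; names here are illustrative, not the actual API):

def retry_if_possible(entry, push_error, resubmit):
    # entry.num_retries_left: -1 = unlimited, 0 = exhausted, N > 0 = N left.
    n = entry.num_retries_left
    if n > 0:
        entry.num_retries_left -= 1
    if n == 0:
        return False
    remaining = "infinite" if n == -1 else str(n)
    # Tell the driver a retry is happening, then resubmit after a delay.
    push_error(entry.spec.job_id, "retry_task",
               f"{remaining} retries left for task {entry.spec.task_id}, "
               "attempting to resubmit.")
    resubmit(entry.spec, delay=True)
    return True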


@ -31,6 +31,8 @@ class TaskFinisherInterface {
virtual void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply,
const rpc::Address &actor_addr) = 0;
virtual bool RetryTaskIfPossible(const TaskID &task_id) = 0;
virtual bool PendingTaskFailed(
const TaskID &task_id, rpc::ErrorType error_type, Status *status,
const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr,
@ -61,6 +63,9 @@ class TaskResubmissionInterface {
using RetryTaskCallback = std::function<void(TaskSpecification &spec, bool delay)>;
using ReconstructObjectCallback = std::function<void(const ObjectID &object_id)>;
using PushErrorCallback =
std::function<Status(const JobID &job_id, const std::string &type,
const std::string &error_message, double timestamp)>;
class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface {
public:
@ -68,12 +73,14 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
std::shared_ptr<ReferenceCounter> reference_counter,
RetryTaskCallback retry_task_callback,
const std::function<bool(const NodeID &node_id)> &check_node_alive,
ReconstructObjectCallback reconstruct_object_callback)
ReconstructObjectCallback reconstruct_object_callback,
PushErrorCallback push_error_callback)
: in_memory_store_(in_memory_store),
reference_counter_(reference_counter),
retry_task_callback_(retry_task_callback),
check_node_alive_(check_node_alive),
reconstruct_object_callback_(reconstruct_object_callback) {
reconstruct_object_callback_(reconstruct_object_callback),
push_error_callback_(push_error_callback) {
reference_counter_->SetReleaseLineageCallback(
[this](const ObjectID &object_id, std::vector<ObjectID> *ids_to_release) {
RemoveLineageReference(object_id, ids_to_release);
@ -118,6 +125,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply,
const rpc::Address &worker_addr) override;
bool RetryTaskIfPossible(const TaskID &task_id) override;
/// A pending task failed. This will either retry the task or mark the task
/// as failed if there are no retries left.
///
@ -266,6 +275,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
/// recoverable).
const ReconstructObjectCallback reconstruct_object_callback_;
// Called to push an error to the relevant driver.
const PushErrorCallback push_error_callback_;
// The number of task failures we have logged total.
int64_t num_failure_logs_ GUARDED_BY(mu_) = 0;


@ -249,6 +249,7 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &res
TaskOptions options;
std::vector<ObjectID> return_ids;
driver.SubmitTask(func, args, options, &return_ids, /*max_retries=*/0,
/*retry_exceptions=*/false,
std::make_pair(PlacementGroupID::Nil(), -1), true,
/*debugger_breakpoint=*/"");
@ -855,12 +856,14 @@ TEST_F(SingleNodeTest, TestCancelTasks) {
// Submit func1. The function should start looping forever.
driver.SubmitTask(func1, args, options, &return_ids1, /*max_retries=*/0,
/*retry_exceptions=*/false,
std::make_pair(PlacementGroupID::Nil(), -1), true,
/*debugger_breakpoint=*/"");
ASSERT_EQ(return_ids1.size(), 1);
// Submit func2. The function should be queued at the worker indefinitely.
driver.SubmitTask(func2, args, options, &return_ids2, /*max_retries=*/0,
/*retry_exceptions=*/false,
std::make_pair(PlacementGroupID::Nil(), -1), true,
/*debugger_breakpoint=*/"");
ASSERT_EQ(return_ids2.size(), 1);


@ -91,6 +91,9 @@ class MockTaskFinisher : public TaskFinisherInterface {
MOCK_METHOD3(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &,
const rpc::Address &addr));
MOCK_METHOD1(RetryTaskIfPossible, bool(const TaskID &task_id));
MOCK_METHOD5(PendingTaskFailed,
bool(const TaskID &task_id, rpc::ErrorType error_type, Status *status,
const std::shared_ptr<rpc::RayException> &creation_task_exception,


@ -64,8 +64,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
return true;
}
bool ReplyPushTask(Status status = Status::OK(), bool exit = false,
bool stolen = false) {
bool ReplyPushTask(Status status = Status::OK(), bool exit = false, bool stolen = false,
bool is_application_level_error = false) {
if (callbacks.size() == 0) {
return false;
}
@ -77,6 +77,9 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
if (stolen) {
reply.set_task_stolen(true);
}
if (is_application_level_error) {
reply.set_is_application_level_error(true);
}
callback(status, reply);
callbacks.pop_front();
return true;
@ -101,6 +104,11 @@ class MockTaskFinisher : public TaskFinisherInterface {
num_tasks_complete++;
}
bool RetryTaskIfPossible(const TaskID &task_id) override {
num_task_retries_attempted++;
return false;
}
bool PendingTaskFailed(
const TaskID &task_id, rpc::ErrorType error_type, Status *status,
const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr,
@ -131,6 +139,7 @@ class MockTaskFinisher : public TaskFinisherInterface {
int num_tasks_failed = 0;
int num_inlined_dependencies = 0;
int num_contained_ids = 0;
int num_task_retries_attempted = 0;
};
class MockRayletClient : public WorkerLeaseInterface {
@ -441,6 +450,54 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(task_finisher->num_task_retries_attempted, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
// Check that there are no entries left in the scheduling_key_entries_ hashmap. These
// would otherwise cause a memory leak.
ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
}
TEST(DirectTaskTransportTest, TestRetryTaskApplicationLevelError) {
rpc::Address address;
auto raylet_client = std::make_shared<MockRayletClient>();
auto worker_client = std::make_shared<MockWorkerClient>();
auto store = std::make_shared<CoreWorkerMemoryStore>();
auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
[&](const rpc::Address &addr) { return worker_client; });
auto task_finisher = std::make_shared<MockTaskFinisher>();
auto actor_creator = std::make_shared<MockActorCreator>();
auto lease_policy = std::make_shared<MockLeasePolicy>();
CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
lease_policy, store, task_finisher,
NodeID::Nil(), kLongTimeout, actor_creator);
TaskSpecification task = BuildEmptyTaskSpec();
task.GetMutableMessage().set_retry_exceptions(true);
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
// Simulate an application-level error.
ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, false, true));
ASSERT_EQ(raylet_client->num_workers_returned, 1);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 1);
ASSERT_EQ(task_finisher->num_task_retries_attempted, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
task.GetMutableMessage().set_retry_exceptions(false);
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
// Simulate an application-level error.
ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, false, true));
ASSERT_EQ(raylet_client->num_workers_returned, 2);
ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
ASSERT_EQ(task_finisher->num_tasks_complete, 2);
ASSERT_EQ(task_finisher->num_task_retries_attempted, 1);
ASSERT_EQ(task_finisher->num_tasks_failed, 0);
ASSERT_EQ(raylet_client->num_leases_canceled, 0);
ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());


@ -55,7 +55,10 @@ class TaskManagerTest : public ::testing::Test {
[this](const NodeID &node_id) { return all_nodes_alive_; },
[this](const ObjectID &object_id) {
objects_to_recover_.push_back(object_id);
}) {}
},
[](const JobID &job_id, const std::string &type,
const std::string &error_message,
double timestamp) { return Status::OK(); }) {}
std::shared_ptr<CoreWorkerMemoryStore> store_;
std::shared_ptr<mock_pubsub::MockPublisher> publisher_;


@ -487,8 +487,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
RAY_CHECK(num_returns >= 0);
std::vector<std::shared_ptr<RayObject>> return_objects;
auto status = task_handler_(task_spec, resource_ids, &return_objects,
reply->mutable_borrowed_refs());
bool is_application_level_error = false;
auto status =
task_handler_(task_spec, resource_ids, &return_objects,
reply->mutable_borrowed_refs(), &is_application_level_error);
reply->set_is_application_level_error(is_application_level_error);
bool objects_valid = return_objects.size() == num_returns;
if (objects_valid) {


@ -781,7 +781,8 @@ class CoreWorkerDirectTaskReceiver {
std::function<Status(const TaskSpecification &task_spec,
const std::shared_ptr<ResourceMappingType> resource_ids,
std::vector<std::shared_ptr<RayObject>> *return_objects,
ReferenceCounter::ReferenceTableProto *borrower_refs)>;
ReferenceCounter::ReferenceTableProto *borrower_refs,
bool *is_application_level_error)>;
using OnTaskDone = std::function<Status()>;


@ -638,7 +638,11 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
is_actor ? rpc::ErrorType::ACTOR_DIED : rpc::ErrorType::WORKER_DIED,
&status));
} else {
task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto());
if (!task_spec.GetMessage().retry_exceptions() ||
!reply.is_application_level_error() ||
!task_finisher_->RetryTaskIfPossible(task_id)) {
task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto());
}
}
});
}
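
The new completion condition in PushNormalTask is a short-circuit chain of negations: a successful reply is only retried when retry_exceptions is set on the task, the reply reports an application-level error, and RetryTaskIfPossible still has retries to spend; otherwise the task is completed as before. The equivalent positive form, as a sketch with illustrative names:

def handle_successful_reply(task_spec, reply, addr, task_finisher, task_id):
    will_retry = (task_spec.retry_exceptions
                  and reply.is_application_level_error
                  and task_finisher.retry_task_if_possible(task_id))
    if not will_retry:
        # Either exception retries are disabled, the task succeeded, or
        # retries are exhausted: hand the reply (including any stored
        # exception object) back to the caller.
        task_finisher.complete_pending_task(task_id, reply, addr)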


@ -213,6 +213,8 @@ message TaskSpec {
string serialized_runtime_env = 24;
// The concurrency group name in which this task will be performed.
string concurrency_group_name = 25;
// Whether application-level errors (exceptions) should be retried.
bool retry_exceptions = 26;
}
message Bundle {


@ -124,6 +124,8 @@ message PushTaskReply {
// may now be borrowing. The reference counts also include any new borrowers
// that the worker created by passing a borrowed ID into a nested task.
repeated ObjectReferenceCount borrowed_refs = 4;
// Whether the result contains an application-level error (exception).
bool is_application_level_error = 5;
}
message DirectActorCallArgWaitCompleteRequest {