Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
Retry application-level errors (#18176)
* Retry application-level errors
* Retry application-level errors
* Push retry message to the driver
This commit is contained in:
parent 673bf35c1f
commit fbb3ac6a86
26 changed files with 296 additions and 59 deletions
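For context, the new retry_exceptions option added by this commit is exercised from the Python API exactly as in the test_retry_application_level_error test included in the diff below. A minimal usage sketch, assuming a Ray build that contains this change; the function name "flaky" and the error message are illustrative only:

    import ray

    ray.init()


    @ray.remote
    class Counter:
        def __init__(self):
            self.value = 0

        def increment(self):
            self.value += 1
            return self.value


    # With retry_exceptions=True, an exception raised by the task body is
    # treated as retryable and the task is resubmitted up to max_retries
    # times instead of surfacing the error to the caller immediately.
    @ray.remote(max_retries=1, retry_exceptions=True)
    def flaky(counter):
        if ray.get(counter.increment.remote()) == 1:
            raise ValueError("transient application-level failure")  # first attempt
        return 2


    counter = Counter.remote()
    assert ray.get(flaky.remote(counter)) == 2  # the retried attempt succeeds

With the default (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False) or with retry_exceptions=False passed via .options(), the same ValueError propagates to ray.get on the first failure, as the new test also asserts.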
@@ -42,8 +42,8 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation,
                            invocation.args, options, &return_ids);
   } else {
     core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, options,
-                           &return_ids, 1, std::make_pair(PlacementGroupID::Nil(), -1),
-                           true, "");
+                           &return_ids, 1, false,
+                           std::make_pair(PlacementGroupID::Nil(), -1), true, "");
   }
   return return_ids[0];
 }
@@ -126,7 +126,8 @@ Status TaskExecutor::ExecuteTask(
     const std::vector<ObjectID> &arg_reference_ids,
     const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
     std::vector<std::shared_ptr<ray::RayObject>> *results,
-    std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes) {
+    std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
+    bool *is_application_level_error) {
   RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type);
   RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP);
   auto function_descriptor = ray_function.GetFunctionDescriptor();
@@ -79,7 +79,8 @@ class TaskExecutor {
       const std::vector<ObjectID> &arg_reference_ids,
       const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
       std::vector<std::shared_ptr<ray::RayObject>> *results,
-      std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes);
+      std::shared_ptr<ray::LocalMemoryBuffer> &creation_task_exception_pb_bytes,
+      bool *is_application_level_error);

   virtual ~TaskExecutor(){};

@@ -406,7 +406,10 @@ cdef execute_task(
         const c_vector[CObjectID] &c_arg_reference_ids,
         const c_vector[CObjectID] &c_return_ids,
         const c_string debugger_breakpoint,
-        c_vector[shared_ptr[CRayObject]] *returns):
+        c_vector[shared_ptr[CRayObject]] *returns,
+        c_bool *is_application_level_error):

+    is_application_level_error[0] = False
+
     worker = ray.worker.global_worker
     manager = worker.function_actor_manager

@@ -579,6 +582,9 @@ cdef execute_task(
                 except KeyboardInterrupt as e:
                     raise TaskCancelledError(
                             core_worker.get_current_task_id())
+                except Exception as e:
+                    is_application_level_error[0] = True
+                    raise e
                 if c_return_ids.size() == 1:
                     outputs = (outputs,)
             # Check for a cancellation that was called when the function

@@ -656,7 +662,8 @@ cdef CRayStatus task_execution_handler(
         const c_vector[CObjectID] &c_return_ids,
         const c_string debugger_breakpoint,
         c_vector[shared_ptr[CRayObject]] *returns,
-        shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil:
+        shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes,
+        c_bool *is_application_level_error) nogil:
     with gil, disable_client_hook():
         try:
             try:

@@ -664,7 +671,8 @@ cdef CRayStatus task_execution_handler(
                 # it does, that indicates that there was an internal error.
                 execute_task(task_type, task_name, ray_function, c_resources,
                              c_args, c_arg_reference_ids, c_return_ids,
-                             debugger_breakpoint, returns)
+                             debugger_breakpoint, returns,
+                             is_application_level_error)
             except Exception as e:
                 sys_exit = SystemExit()
                 if isinstance(e, RayActorError) and \

@@ -1318,6 +1326,7 @@ cdef class CoreWorker:
                     int num_returns,
                     resources,
                     int max_retries,
+                    c_bool retry_exceptions,
                     PlacementGroupID placement_group_id,
                     int64_t placement_group_bundle_index,
                     c_bool placement_group_capture_child_tasks,

@@ -1354,7 +1363,7 @@ cdef class CoreWorker:
                     b"",
                     c_serialized_runtime_env,
                     c_override_environment_variables),
-                &return_ids, max_retries,
+                &return_ids, max_retries, retry_exceptions,
                 c_pair[CPlacementGroupID, int64_t](
                     c_placement_group_id, placement_group_bundle_index),
                 placement_group_capture_child_tasks,
@@ -77,7 +77,8 @@ def java_function(class_name, function_name):
                           None,  # accelerator_type,
                           None,  # num_returns,
                           None,  # max_calls,
-                          None,  # max_retries
+                          None,  # max_retries,
+                          None,  # retry_exceptions,
                           None)  # runtime_env


@@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
             const c_vector[unique_ptr[CTaskArg]] &args,
             const CTaskOptions &options, c_vector[CObjectID] *return_ids,
             int max_retries,
+            c_bool retry_exceptions,
             c_pair[CPlacementGroupID, int64_t] placement_options,
             c_bool placement_group_capture_child_tasks,
             c_string debugger_breakpoint)

@@ -280,7 +281,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
                              const c_string debugger_breakpoint,
                              c_vector[shared_ptr[CRayObject]] *returns,
                              shared_ptr[LocalMemoryBuffer]
-                             &creation_task_exception_pb_bytes) nogil
+                             &creation_task_exception_pb_bytes,
+                             c_bool *is_application_level_error) nogil
                          ) task_execution_callback
                          (void(const CWorkerID &) nogil) on_worker_shutdown
                          (CRayStatus() nogil) check_signals
@@ -24,6 +24,7 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
 # Normal tasks may be retried on failure this many times.
 # TODO(swang): Allow this to be set globally for an application.
 DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3
+DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False

 logger = logging.getLogger(__name__)


@@ -53,6 +54,9 @@ class RemoteFunction:
             of this remote function.
         _max_calls: The number of times a worker can execute this function
             before exiting.
+        _max_retries: The number of times this task may be retried
+            on worker failure.
+        _retry_exceptions: Whether application-level errors should be retried.
         _runtime_env: The runtime environment for this task.
         _decorator: An optional decorator that should be applied to the remote
             function invocation (as opposed to the function execution) before

@@ -73,7 +77,7 @@ class RemoteFunction:
     def __init__(self, language, function, function_descriptor, num_cpus,
                  num_gpus, memory, object_store_memory, resources,
                  accelerator_type, num_returns, max_calls, max_retries,
-                 runtime_env):
+                 retry_exceptions, runtime_env):
         if inspect.iscoroutinefunction(function):
             raise ValueError("'async def' should not be used for remote "
                              "tasks. You can wrap the async function with "

@@ -100,6 +104,9 @@ class RemoteFunction:
                            if max_calls is None else max_calls)
         self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES
                              if max_retries is None else max_retries)
+        self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS
+                                  if retry_exceptions is None else
+                                  retry_exceptions)
         self._runtime_env = runtime_env
         self._decorator = getattr(function, "__ray_invocation_decorator__",
                                   None)

@@ -131,6 +138,7 @@ class RemoteFunction:
                 accelerator_type=None,
                 resources=None,
                 max_retries=None,
+                retry_exceptions=None,
                 placement_group="default",
                 placement_group_bundle_index=-1,
                 placement_group_capture_child_tasks=None,

@@ -168,6 +176,7 @@ class RemoteFunction:
                 accelerator_type=accelerator_type,
                 resources=resources,
                 max_retries=max_retries,
+                retry_exceptions=retry_exceptions,
                 placement_group=placement_group,
                 placement_group_bundle_index=placement_group_bundle_index,
                 placement_group_capture_child_tasks=(

@@ -191,6 +200,7 @@ class RemoteFunction:
                 accelerator_type=None,
                 resources=None,
                 max_retries=None,
+                retry_exceptions=None,
                 placement_group="default",
                 placement_group_bundle_index=-1,
                 placement_group_capture_child_tasks=None,

@@ -211,6 +221,7 @@ class RemoteFunction:
                 accelerator_type=accelerator_type,
                 resources=resources,
                 max_retries=max_retries,
+                retry_exceptions=retry_exceptions,
                 placement_group=placement_group,
                 placement_group_bundle_index=placement_group_bundle_index,
                 placement_group_capture_child_tasks=(

@@ -251,6 +262,8 @@ class RemoteFunction:
             num_returns = self._num_returns
         if max_retries is None:
             max_retries = self._max_retries
+        if retry_exceptions is None:
+            retry_exceptions = self._retry_exceptions

         if placement_group_capture_child_tasks is None:
             placement_group_capture_child_tasks = (

@@ -307,6 +320,7 @@ class RemoteFunction:
                 num_returns,
                 resources,
                 max_retries,
+                retry_exceptions,
                 placement_group.id,
                 placement_group_bundle_index,
                 placement_group_capture_child_tasks,
@@ -10,6 +10,68 @@ from ray._private.test_utils import (init_error_pubsub, get_error_message,
                                      run_string_as_driver)


+def test_retry_system_level_error(ray_start_regular):
+    @ray.remote
+    class Counter:
+        def __init__(self):
+            self.value = 0
+
+        def increment(self):
+            self.value += 1
+            return self.value
+
+    @ray.remote(max_retries=1)
+    def func(counter):
+        count = counter.increment.remote()
+        if ray.get(count) == 1:
+            import os
+            os._exit(0)
+        else:
+            return 1
+
+    counter1 = Counter.remote()
+    r1 = func.remote(counter1)
+    assert ray.get(r1) == 1
+
+    counter2 = Counter.remote()
+    r2 = func.options(max_retries=0).remote(counter2)
+    with pytest.raises(ray.exceptions.WorkerCrashedError):
+        ray.get(r2)
+
+
+def test_retry_application_level_error(ray_start_regular):
+    @ray.remote
+    class Counter:
+        def __init__(self):
+            self.value = 0
+
+        def increment(self):
+            self.value += 1
+            return self.value
+
+    @ray.remote(max_retries=1, retry_exceptions=True)
+    def func(counter):
+        count = counter.increment.remote()
+        if ray.get(count) == 1:
+            raise ValueError()
+        else:
+            return 2
+
+    counter1 = Counter.remote()
+    r1 = func.remote(counter1)
+    assert ray.get(r1) == 2
+
+    counter2 = Counter.remote()
+    r2 = func.options(max_retries=0).remote(counter2)
+    with pytest.raises(ValueError):
+        ray.get(r2)
+
+    counter3 = Counter.remote()
+    r3 = func.options(retry_exceptions=False).remote(counter3)
+    with pytest.raises(ValueError):
+        ray.get(r3)
+
+
 def test_connect_with_disconnected_node(shutdown_only):
     config = {
         "num_heartbeats_timeout": 50,
@@ -22,6 +22,7 @@ options = {
     "max_retries": (int, lambda x: x >= -1,
                     "The keyword 'max_retries' only accepts 0, -1 "
                     "or a positive integer"),
+    "retry_exceptions": (),
     "max_concurrency": (),
     "name": (),
     "namespace": (),
@@ -651,7 +651,7 @@ def init(
             a raylet, a plasma store, a plasma manager, and some workers.
             It will also kill these processes when Python exits. If the driver
             is running on a node in a Ray cluster, using `auto` as the value
-            tells the driver to detect the the cluster, removing the need to
+            tells the driver to detect the cluster, removing the need to
             specify a specific node address. If the environment variable
             `RAY_ADDRESS` is defined and the address is None or "auto", Ray
             will set `address` to `RAY_ADDRESS`.

@@ -1924,7 +1924,8 @@ def make_decorator(num_returns=None,
                    max_restarts=None,
                    max_task_retries=None,
                    runtime_env=None,
-                   worker=None):
+                   worker=None,
+                   retry_exceptions=None):
     def decorator(function_or_class):
         if (inspect.isfunction(function_or_class)
                 or is_cython(function_or_class)):

@@ -1953,12 +1954,19 @@ def make_decorator(num_returns=None,
             return ray.remote_function.RemoteFunction(
                 Language.PYTHON, function_or_class, None, num_cpus, num_gpus,
                 memory, object_store_memory, resources, accelerator_type,
-                num_returns, max_calls, max_retries, runtime_env)
+                num_returns, max_calls, max_retries, retry_exceptions,
+                runtime_env)

         if inspect.isclass(function_or_class):
             if num_returns is not None:
                 raise TypeError("The keyword 'num_returns' is not "
                                 "allowed for actors.")
             if max_retries is not None:
                 raise TypeError("The keyword 'max_retries' is not "
                                 "allowed for actors.")
+            if retry_exceptions is not None:
+                raise TypeError("The keyword 'retry_exceptions' is not "
+                                "allowed for actors.")
             if max_calls is not None:
                 raise TypeError("The keyword 'max_calls' is not "
                                 "allowed for actors.")

@@ -2082,6 +2090,9 @@ def remote(*args, **kwargs):
             this actor or task and its children. See
             :ref:`runtime-environments` for detailed documentation. This API is
             in beta and may change before becoming stable.
+        retry_exceptions (bool): Only for *remote functions*. This specifies
+            whether application-level errors should be retried
+            up to max_retries times.
         override_environment_variables (Dict[str, str]): (Deprecated in Ray
             1.4.0, will be removed in Ray 1.6--please use the ``env_vars``
             field of :ref:`runtime-environments` instead.) This specifies

@@ -2102,7 +2113,7 @@ def remote(*args, **kwargs):
     valid_kwargs = [
         "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory",
         "resources", "accelerator_type", "max_calls", "max_restarts",
-        "max_task_retries", "max_retries", "runtime_env"
+        "max_task_retries", "max_retries", "runtime_env", "retry_exceptions"
     ]
     error_string = ("The @ray.remote decorator must be applied either "
                     "with no arguments and no parentheses, for example "

@@ -2135,6 +2146,7 @@ def remote(*args, **kwargs):
     object_store_memory = kwargs.get("object_store_memory")
     max_retries = kwargs.get("max_retries")
     runtime_env = kwargs.get("runtime_env")
+    retry_exceptions = kwargs.get("retry_exceptions")

     return make_decorator(
         num_returns=num_returns,

@@ -2149,4 +2161,5 @@ def remote(*args, **kwargs):
         max_task_retries=max_task_retries,
         max_retries=max_retries,
         runtime_env=runtime_env,
-        worker=worker)
+        worker=worker,
+        retry_exceptions=retry_exceptions)
@@ -134,6 +134,12 @@ class TaskSpecBuilder {
     return *this;
   }

+  TaskSpecBuilder &SetNormalTaskSpec(int max_retries, bool retry_exceptions) {
+    message_->set_max_retries(max_retries);
+    message_->set_retry_exceptions(retry_exceptions);
+    return *this;
+  }
+
   /// Set the driver attributes of the task spec.
   /// See `common.proto` for meaning of the arguments.
   ///
@@ -385,9 +385,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
   // Initialize task receivers.
   if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) {
     RAY_CHECK(options_.task_execution_callback != nullptr);
-    auto execute_task =
-        std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
-                  std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
+    auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
+                                  std::placeholders::_2, std::placeholders::_3,
+                                  std::placeholders::_4, std::placeholders::_5);
     direct_task_receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>(
         worker_context_, task_execution_service_, execute_task,
         [this] { return local_raylet_client_->TaskDone(); });

@@ -565,6 +565,10 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
         },
         "CoreWorker.ReconstructObject");
   };
+  auto push_error_callback = [this](const JobID &job_id, const std::string &type,
+                                    const std::string &error_message, double timestamp) {
+    return PushError(job_id, type, error_message, timestamp);
+  };
   task_manager_.reset(new TaskManager(
       memory_store_, reference_counter_,
       /* retry_task_callback= */

@@ -589,7 +593,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
           }
         }
       },
-      check_node_alive_fn, reconstruct_object_callback));
+      check_node_alive_fn, reconstruct_object_callback, push_error_callback));

   // Create an entry for the driver task in the task table. This task is
   // added immediately with status RUNNING. This allows us to push errors

@@ -1643,7 +1647,7 @@ void CoreWorker::SubmitTask(const RayFunction &function,
                             const std::vector<std::unique_ptr<TaskArg>> &args,
                             const TaskOptions &task_options,
                             std::vector<ObjectID> *return_ids, int max_retries,
-                            BundleID placement_options,
+                            bool retry_exceptions, BundleID placement_options,
                             bool placement_group_capture_child_tasks,
                             const std::string &debugger_breakpoint) {
   TaskSpecBuilder builder;

@@ -1674,6 +1678,7 @@ void CoreWorker::SubmitTask(const RayFunction &function,
                     placement_options, placement_group_capture_child_tasks,
                     debugger_breakpoint, task_options.serialized_runtime_env,
                     override_environment_variables);
+  builder.SetNormalTaskSpec(max_retries, retry_exceptions);
   TaskSpecification task_spec = builder.Build();
   RAY_LOG(DEBUG) << "Submit task " << task_spec.DebugString();
   if (options_.is_local_mode) {

@@ -2146,7 +2151,8 @@ Status CoreWorker::AllocateReturnObject(const ObjectID &object_id,
 Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
                                const std::shared_ptr<ResourceMappingType> &resource_ids,
                                std::vector<std::shared_ptr<RayObject>> *return_objects,
-                               ReferenceCounter::ReferenceTableProto *borrowed_refs) {
+                               ReferenceCounter::ReferenceTableProto *borrowed_refs,
+                               bool *is_application_level_error) {
   RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString();
   task_queue_length_ -= 1;
   num_executed_tasks_ += 1;

@@ -2212,7 +2218,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
       task_type, task_spec.GetName(), func,
       task_spec.GetRequiredResources().GetResourceMap(), args, arg_reference_ids,
       return_ids, task_spec.GetDebuggerBreakpoint(), return_objects,
-      creation_task_exception_pb_bytes);
+      creation_task_exception_pb_bytes, is_application_level_error);

   // Get the reference counts for any IDs that we borrowed during this task and
   // return them to the caller. This will notify the caller of any IDs that we

@@ -2301,7 +2307,9 @@ void CoreWorker::ExecuteTaskLocalMode(const TaskSpecification &task_spec,
   }
   auto old_id = GetActorId();
   SetActorId(actor_id);
-  RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs));
+  bool is_application_level_error;
+  RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs,
+                         &is_application_level_error));
   SetActorId(old_id);
 }

@@ -70,7 +70,8 @@ struct CoreWorkerOptions {
       const std::vector<ObjectID> &arg_reference_ids,
       const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
       std::vector<std::shared_ptr<RayObject>> *results,
-      std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes)>;
+      std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes,
+      bool *is_application_level_error)>;

   CoreWorkerOptions()
       : store_socket(""),

@@ -712,7 +713,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   void SubmitTask(const RayFunction &function,
                   const std::vector<std::unique_ptr<TaskArg>> &args,
                   const TaskOptions &task_options, std::vector<ObjectID> *return_ids,
-                  int max_retries, BundleID placement_options,
+                  int max_retries, bool retry_exceptions, BundleID placement_options,
                   bool placement_group_capture_child_tasks,
                   const std::string &debugger_breakpoint);

@@ -1122,7 +1123,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
   Status ExecuteTask(const TaskSpecification &task_spec,
                      const std::shared_ptr<ResourceMappingType> &resource_ids,
                      std::vector<std::shared_ptr<RayObject>> *return_objects,
-                     ReferenceCounter::ReferenceTableProto *borrowed_refs);
+                     ReferenceCounter::ReferenceTableProto *borrowed_refs,
+                     bool *is_application_level_error);

   /// Execute a local mode task (runs normal ExecuteTask)
   ///
@@ -100,7 +100,11 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
           const std::vector<ObjectID> &arg_reference_ids,
           const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint,
           std::vector<std::shared_ptr<RayObject>> *results,
-          std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb) {
+          std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb,
+          bool *is_application_level_error) {
+        // TODO(jjyao): Support retrying application-level errors for Java
+        *is_application_level_error = false;
+
         JNIEnv *env = GetJNIEnv();
         RAY_CHECK(java_task_executor);

@@ -308,6 +308,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub
     CoreWorkerProcess::GetCoreWorker().SubmitTask(
         ray_function, task_args, task_options, &return_ids,
         /*max_retries=*/0,
+        /*retry_exceptions=*/false,
        /*placement_options=*/placement_group_options,
        /*placement_group_capture_child_tasks=*/true,
        /*debugger_breakpoint*/ "");
@@ -266,7 +266,7 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
   it->second.pending = false;
   num_pending_tasks_--;

-  // A finished task can be only be re-executed if it has some number of
+  // A finished task can only be re-executed if it has some number of
   // retries left and returned at least one object that is still in use and
   // stored in plasma.
   bool task_retryable = it->second.num_retries_left != 0 &&

@@ -284,6 +284,45 @@ void TaskManager::CompletePendingTask(const TaskID &task_id,
   ShutdownIfNeeded();
 }

+bool TaskManager::RetryTaskIfPossible(const TaskID &task_id) {
+  int num_retries_left = 0;
+  TaskSpecification spec;
+  {
+    absl::MutexLock lock(&mu_);
+    auto it = submissible_tasks_.find(task_id);
+    RAY_CHECK(it != submissible_tasks_.end())
+        << "Tried to retry task that was not pending " << task_id;
+    RAY_CHECK(it->second.pending)
+        << "Tried to retry task that was not pending " << task_id;
+    spec = it->second.spec;
+    num_retries_left = it->second.num_retries_left;
+    if (num_retries_left > 0) {
+      it->second.num_retries_left--;
+    } else {
+      RAY_CHECK(num_retries_left == 0 || num_retries_left == -1);
+    }
+  }
+
+  // We should not hold the lock during these calls because they may trigger
+  // callbacks in this or other classes.
+  if (num_retries_left != 0) {
+    auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
+                         std::chrono::system_clock::now().time_since_epoch())
+                         .count();
+    std::ostringstream stream;
+    auto num_retries_left_str =
+        num_retries_left == -1 ? "infinite" : std::to_string(num_retries_left);
+    stream << num_retries_left_str << " retries left for task " << spec.TaskId()
+           << ", attempting to resubmit.";
+    RAY_CHECK_OK(
+        push_error_callback_(spec.JobId(), "retry_task", stream.str(), timestamp));
+    retry_task_callback_(spec, /*delay=*/true);
+    return true;
+  } else {
+    return false;
+  }
+}
+
 bool TaskManager::PendingTaskFailed(
     const TaskID &task_id, rpc::ErrorType error_type, Status *status,
     const std::shared_ptr<rpc::RayException> &creation_task_exception,

@@ -292,9 +331,9 @@ bool TaskManager::PendingTaskFailed(
   // loudly with ERROR here.
   RAY_LOG(DEBUG) << "Task " << task_id << " failed with error "
                  << rpc::ErrorType_Name(error_type);
-  int num_retries_left = 0;
+  const bool will_retry = RetryTaskIfPossible(task_id);
+  const bool release_lineage = !will_retry;
   TaskSpecification spec;
-  bool release_lineage = true;
   {
     absl::MutexLock lock(&mu_);
     auto it = submissible_tasks_.find(task_id);

@@ -303,30 +342,13 @@ bool TaskManager::PendingTaskFailed(
     RAY_CHECK(it->second.pending)
         << "Tried to complete task that was not pending " << task_id;
     spec = it->second.spec;
-    num_retries_left = it->second.num_retries_left;
-    if (num_retries_left == 0) {
+    if (!will_retry) {
       submissible_tasks_.erase(it);
       num_pending_tasks_--;
-    } else if (num_retries_left == -1) {
-      release_lineage = false;
-    } else {
-      RAY_CHECK(num_retries_left > 0);
-      it->second.num_retries_left--;
-      release_lineage = false;
-    }
     }
   }

-  bool will_retry = false;
-  // We should not hold the lock during these calls because they may trigger
-  // callbacks in this or other classes.
-  if (num_retries_left != 0) {
-    auto retries_str =
-        num_retries_left == -1 ? "infinite" : std::to_string(num_retries_left);
-    RAY_LOG(INFO) << retries_str << " retries left for task " << spec.TaskId()
-                  << ", attempting to resubmit.";
-    retry_task_callback_(spec, /*delay=*/true);
-    will_retry = true;
-  } else {
+  if (!will_retry) {
     // Throttled logging of task failure errors.
     {
       absl::MutexLock lock(&mu_);
@@ -31,6 +31,8 @@ class TaskFinisherInterface {
   virtual void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply,
                                    const rpc::Address &actor_addr) = 0;

+  virtual bool RetryTaskIfPossible(const TaskID &task_id) = 0;
+
   virtual bool PendingTaskFailed(
       const TaskID &task_id, rpc::ErrorType error_type, Status *status,
       const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr,

@@ -61,6 +63,9 @@ class TaskResubmissionInterface {

 using RetryTaskCallback = std::function<void(TaskSpecification &spec, bool delay)>;
 using ReconstructObjectCallback = std::function<void(const ObjectID &object_id)>;
+using PushErrorCallback =
+    std::function<Status(const JobID &job_id, const std::string &type,
+                         const std::string &error_message, double timestamp)>;

 class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface {
  public:

@@ -68,12 +73,14 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
               std::shared_ptr<ReferenceCounter> reference_counter,
               RetryTaskCallback retry_task_callback,
               const std::function<bool(const NodeID &node_id)> &check_node_alive,
-              ReconstructObjectCallback reconstruct_object_callback)
+              ReconstructObjectCallback reconstruct_object_callback,
+              PushErrorCallback push_error_callback)
       : in_memory_store_(in_memory_store),
         reference_counter_(reference_counter),
         retry_task_callback_(retry_task_callback),
         check_node_alive_(check_node_alive),
-        reconstruct_object_callback_(reconstruct_object_callback) {
+        reconstruct_object_callback_(reconstruct_object_callback),
+        push_error_callback_(push_error_callback) {
     reference_counter_->SetReleaseLineageCallback(
         [this](const ObjectID &object_id, std::vector<ObjectID> *ids_to_release) {
           RemoveLineageReference(object_id, ids_to_release);

@@ -118,6 +125,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply,
                            const rpc::Address &worker_addr) override;

+  bool RetryTaskIfPossible(const TaskID &task_id) override;
+
   /// A pending task failed. This will either retry the task or mark the task
   /// as failed if there are no retries left.
   ///

@@ -266,6 +275,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa
   /// recoverable).
   const ReconstructObjectCallback reconstruct_object_callback_;

+  // Called to push an error to the relevant driver.
+  const PushErrorCallback push_error_callback_;
+
   // The number of task failures we have logged total.
   int64_t num_failure_logs_ GUARDED_BY(mu_) = 0;

@@ -249,6 +249,7 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &res
   TaskOptions options;
   std::vector<ObjectID> return_ids;
   driver.SubmitTask(func, args, options, &return_ids, /*max_retries=*/0,
+                    /*retry_exceptions=*/false,
                     std::make_pair(PlacementGroupID::Nil(), -1), true,
                     /*debugger_breakpoint=*/"");

@@ -855,12 +856,14 @@ TEST_F(SingleNodeTest, TestCancelTasks) {

   // Submit func1. The function should start looping forever.
   driver.SubmitTask(func1, args, options, &return_ids1, /*max_retries=*/0,
+                    /*retry_exceptions=*/false,
                     std::make_pair(PlacementGroupID::Nil(), -1), true,
                     /*debugger_breakpoint=*/"");
   ASSERT_EQ(return_ids1.size(), 1);

   // Submit func2. The function should be queued at the worker indefinitely.
   driver.SubmitTask(func2, args, options, &return_ids2, /*max_retries=*/0,
+                    /*retry_exceptions=*/false,
                     std::make_pair(PlacementGroupID::Nil(), -1), true,
                     /*debugger_breakpoint=*/"");
   ASSERT_EQ(return_ids2.size(), 1);
@@ -91,6 +91,9 @@ class MockTaskFinisher : public TaskFinisherInterface {

   MOCK_METHOD3(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &,
                                          const rpc::Address &addr));
+
+  MOCK_METHOD1(RetryTaskIfPossible, bool(const TaskID &task_id));
+
   MOCK_METHOD5(PendingTaskFailed,
                bool(const TaskID &task_id, rpc::ErrorType error_type, Status *status,
                     const std::shared_ptr<rpc::RayException> &creation_task_exception,
@@ -64,8 +64,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
     return true;
   }

-  bool ReplyPushTask(Status status = Status::OK(), bool exit = false,
-                     bool stolen = false) {
+  bool ReplyPushTask(Status status = Status::OK(), bool exit = false, bool stolen = false,
+                     bool is_application_level_error = false) {
     if (callbacks.size() == 0) {
       return false;
     }

@@ -77,6 +77,9 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
     if (stolen) {
       reply.set_task_stolen(true);
     }
+    if (is_application_level_error) {
+      reply.set_is_application_level_error(true);
+    }
     callback(status, reply);
     callbacks.pop_front();
     return true;

@@ -101,6 +104,11 @@ class MockTaskFinisher : public TaskFinisherInterface {
     num_tasks_complete++;
   }

+  bool RetryTaskIfPossible(const TaskID &task_id) override {
+    num_task_retries_attempted++;
+    return false;
+  }
+
   bool PendingTaskFailed(
       const TaskID &task_id, rpc::ErrorType error_type, Status *status,
       const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr,

@@ -131,6 +139,7 @@ class MockTaskFinisher : public TaskFinisherInterface {
   int num_tasks_failed = 0;
   int num_inlined_dependencies = 0;
   int num_contained_ids = 0;
+  int num_task_retries_attempted = 0;
 };

 class MockRayletClient : public WorkerLeaseInterface {

@@ -441,6 +450,54 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
   ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
   ASSERT_EQ(task_finisher->num_tasks_complete, 1);
   ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+  ASSERT_EQ(task_finisher->num_task_retries_attempted, 0);
+  ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+  ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+  // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+  // would otherwise cause a memory leak.
+  ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic());
+}
+
+TEST(DirectTaskTransportTest, TestRetryTaskApplicationLevelError) {
+  rpc::Address address;
+  auto raylet_client = std::make_shared<MockRayletClient>();
+  auto worker_client = std::make_shared<MockWorkerClient>();
+  auto store = std::make_shared<CoreWorkerMemoryStore>();
+  auto client_pool = std::make_shared<rpc::CoreWorkerClientPool>(
+      [&](const rpc::Address &addr) { return worker_client; });
+  auto task_finisher = std::make_shared<MockTaskFinisher>();
+  auto actor_creator = std::make_shared<MockActorCreator>();
+  auto lease_policy = std::make_shared<MockLeasePolicy>();
+  CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr,
+                                          lease_policy, store, task_finisher,
+                                          NodeID::Nil(), kLongTimeout, actor_creator);
+  TaskSpecification task = BuildEmptyTaskSpec();
+  task.GetMutableMessage().set_retry_exceptions(true);
+
+  ASSERT_TRUE(submitter.SubmitTask(task).ok());
+  ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
+  // Simulate an application-level error.
+  ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, false, true));
+  ASSERT_EQ(raylet_client->num_workers_returned, 1);
+  ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+  ASSERT_EQ(task_finisher->num_tasks_complete, 1);
+  ASSERT_EQ(task_finisher->num_task_retries_attempted, 1);
+  ASSERT_EQ(task_finisher->num_tasks_failed, 0);
+  ASSERT_EQ(raylet_client->num_leases_canceled, 0);
+  ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());
+
+  task.GetMutableMessage().set_retry_exceptions(false);
+
+  ASSERT_TRUE(submitter.SubmitTask(task).ok());
+  ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil()));
+  // Simulate an application-level error.
+  ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, false, true));
+  ASSERT_EQ(raylet_client->num_workers_returned, 2);
+  ASSERT_EQ(raylet_client->num_workers_disconnected, 0);
+  ASSERT_EQ(task_finisher->num_tasks_complete, 2);
+  ASSERT_EQ(task_finisher->num_task_retries_attempted, 1);
+  ASSERT_EQ(task_finisher->num_tasks_failed, 0);
   ASSERT_EQ(raylet_client->num_leases_canceled, 0);
   ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease());

@@ -55,7 +55,10 @@ class TaskManagerTest : public ::testing::Test {
             [this](const NodeID &node_id) { return all_nodes_alive_; },
             [this](const ObjectID &object_id) {
               objects_to_recover_.push_back(object_id);
-            }) {}
+            },
+            [](const JobID &job_id, const std::string &type,
+               const std::string &error_message,
+               double timestamp) { return Status::OK(); }) {}

   std::shared_ptr<CoreWorkerMemoryStore> store_;
   std::shared_ptr<mock_pubsub::MockPublisher> publisher_;
@@ -487,8 +487,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
     RAY_CHECK(num_returns >= 0);

     std::vector<std::shared_ptr<RayObject>> return_objects;
-    auto status = task_handler_(task_spec, resource_ids, &return_objects,
-                                reply->mutable_borrowed_refs());
+    bool is_application_level_error = false;
+    auto status =
+        task_handler_(task_spec, resource_ids, &return_objects,
+                      reply->mutable_borrowed_refs(), &is_application_level_error);
+    reply->set_is_application_level_error(is_application_level_error);

     bool objects_valid = return_objects.size() == num_returns;
     if (objects_valid) {
@@ -781,7 +781,8 @@ class CoreWorkerDirectTaskReceiver {
       std::function<Status(const TaskSpecification &task_spec,
                            const std::shared_ptr<ResourceMappingType> resource_ids,
                            std::vector<std::shared_ptr<RayObject>> *return_objects,
-                           ReferenceCounter::ReferenceTableProto *borrower_refs)>;
+                           ReferenceCounter::ReferenceTableProto *borrower_refs,
+                           bool *is_application_level_error)>;

   using OnTaskDone = std::function<Status()>;

@@ -638,7 +638,11 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
                 is_actor ? rpc::ErrorType::ACTOR_DIED : rpc::ErrorType::WORKER_DIED,
                 &status));
           } else {
-            task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto());
+            if (!task_spec.GetMessage().retry_exceptions() ||
+                !reply.is_application_level_error() ||
+                !task_finisher_->RetryTaskIfPossible(task_id)) {
+              task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto());
+            }
           }
         });
 }
@@ -213,6 +213,8 @@ message TaskSpec {
   string serialized_runtime_env = 24;
   // The concurrency group name in which this task will be performed.
   string concurrency_group_name = 25;
+  // Whether application-level errors (exceptions) should be retried.
+  bool retry_exceptions = 26;
 }

 message Bundle {
@@ -124,6 +124,8 @@ message PushTaskReply {
   // may now be borrowing. The reference counts also include any new borrowers
   // that the worker created by passing a borrowed ID into a nested task.
   repeated ObjectReferenceCount borrowed_refs = 4;
+  // Whether the result contains an application-level error (exception).
+  bool is_application_level_error = 5;
 }

 message DirectActorCallArgWaitCompleteRequest {