From fbb3ac6a868cc4d4cd4dcb5e141dd203d1bbeb86 Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Wed, 1 Sep 2021 10:53:06 -0700 Subject: [PATCH] Retry application-level errors (#18176) * Retry application-level errors * Retry application-level errors * Push retry message to the driver --- .../ray/runtime/task/native_task_submitter.cc | 4 +- cpp/src/ray/runtime/task/task_executor.cc | 3 +- cpp/src/ray/runtime/task/task_executor.h | 3 +- python/ray/_raylet.pyx | 17 +++-- python/ray/cross_language.py | 3 +- python/ray/includes/libcoreworker.pxd | 4 +- python/ray/remote_function.py | 16 ++++- python/ray/tests/test_failure_4.py | 62 +++++++++++++++++ python/ray/util/client/options.py | 1 + python/ray/worker.py | 23 +++++-- src/ray/common/task/task_util.h | 6 ++ src/ray/core_worker/core_worker.cc | 24 ++++--- src/ray/core_worker/core_worker.h | 8 ++- .../java/io_ray_runtime_RayNativeRuntime.cc | 6 +- ...io_ray_runtime_task_NativeTaskSubmitter.cc | 1 + src/ray/core_worker/task_manager.cc | 66 ++++++++++++------- src/ray/core_worker/task_manager.h | 16 ++++- src/ray/core_worker/test/core_worker_test.cc | 3 + .../test/direct_actor_transport_test.cc | 3 + .../test/direct_task_transport_test.cc | 61 ++++++++++++++++- src/ray/core_worker/test/task_manager_test.cc | 5 +- .../transport/direct_actor_transport.cc | 7 +- .../transport/direct_actor_transport.h | 3 +- .../transport/direct_task_transport.cc | 6 +- src/ray/protobuf/common.proto | 2 + src/ray/protobuf/core_worker.proto | 2 + 26 files changed, 296 insertions(+), 59 deletions(-) diff --git a/cpp/src/ray/runtime/task/native_task_submitter.cc b/cpp/src/ray/runtime/task/native_task_submitter.cc index 962e112a2..35d4a4475 100644 --- a/cpp/src/ray/runtime/task/native_task_submitter.cc +++ b/cpp/src/ray/runtime/task/native_task_submitter.cc @@ -42,8 +42,8 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation, invocation.args, options, &return_ids); } else { core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, options, - &return_ids, 1, std::make_pair(PlacementGroupID::Nil(), -1), - true, ""); + &return_ids, 1, false, + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); } return return_ids[0]; } diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index 585976ef5..ba9d62f06 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -126,7 +126,8 @@ Status TaskExecutor::ExecuteTask( const std::vector<ObjectID> &arg_reference_ids, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, std::vector<std::shared_ptr<RayObject>> *results, - std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes) { + std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes, + bool *is_application_level_error) { RAY_LOG(INFO) << "Execute task: " << TaskType_Name(task_type); RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP); auto function_descriptor = ray_function.GetFunctionDescriptor(); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index 8e0613ad0..5714ef9cf 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -79,7 +79,8 @@ class TaskExecutor { const std::vector<ObjectID> &arg_reference_ids, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, std::vector<std::shared_ptr<RayObject>> *results, - std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes); + std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes, + bool *is_application_level_error); virtual ~TaskExecutor(){}; diff --git a/python/ray/_raylet.pyx
b/python/ray/_raylet.pyx index 7a906b9a6..dad0d0c78 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -406,7 +406,10 @@ cdef execute_task( const c_vector[CObjectID] &c_arg_reference_ids, const c_vector[CObjectID] &c_return_ids, const c_string debugger_breakpoint, - c_vector[shared_ptr[CRayObject]] *returns): + c_vector[shared_ptr[CRayObject]] *returns, + c_bool *is_application_level_error): + + is_application_level_error[0] = False worker = ray.worker.global_worker manager = worker.function_actor_manager @@ -579,6 +582,9 @@ cdef execute_task( except KeyboardInterrupt as e: raise TaskCancelledError( core_worker.get_current_task_id()) + except Exception as e: + is_application_level_error[0] = True + raise e if c_return_ids.size() == 1: outputs = (outputs,) # Check for a cancellation that was called when the function @@ -656,7 +662,8 @@ cdef CRayStatus task_execution_handler( const c_vector[CObjectID] &c_return_ids, const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns, - shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes) nogil: + shared_ptr[LocalMemoryBuffer] &creation_task_exception_pb_bytes, + c_bool *is_application_level_error) nogil: with gil, disable_client_hook(): try: try: @@ -664,7 +671,8 @@ cdef CRayStatus task_execution_handler( # it does, that indicates that there was an internal error. execute_task(task_type, task_name, ray_function, c_resources, c_args, c_arg_reference_ids, c_return_ids, - debugger_breakpoint, returns) + debugger_breakpoint, returns, + is_application_level_error) except Exception as e: sys_exit = SystemExit() if isinstance(e, RayActorError) and \ @@ -1318,6 +1326,7 @@ cdef class CoreWorker: int num_returns, resources, int max_retries, + c_bool retry_exceptions, PlacementGroupID placement_group_id, int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, @@ -1354,7 +1363,7 @@ cdef class CoreWorker: b"", c_serialized_runtime_env, c_override_environment_variables), - &return_ids, max_retries, + &return_ids, max_retries, retry_exceptions, c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), placement_group_capture_child_tasks, diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py index ebeaa1228..eae3939ff 100644 --- a/python/ray/cross_language.py +++ b/python/ray/cross_language.py @@ -77,7 +77,8 @@ def java_function(class_name, function_name): None, # accelerator_type, None, # num_returns, None, # max_calls, - None, # max_retries + None, # max_retries, + None, # retry_exceptions, None) # runtime_env diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index bf40ff850..90944251c 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_vector[unique_ptr[CTaskArg]] &args, const CTaskOptions &options, c_vector[CObjectID] *return_ids, int max_retries, + c_bool retry_exceptions, c_pair[CPlacementGroupID, int64_t] placement_options, c_bool placement_group_capture_child_tasks, c_string debugger_breakpoint) @@ -280,7 +281,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns, shared_ptr[LocalMemoryBuffer] - &creation_task_exception_pb_bytes) nogil + &creation_task_exception_pb_bytes, + c_bool *is_application_level_error) nogil ) task_execution_callback (void(const CWorkerID &) nogil) 
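The Cython change above is the heart of the worker-side detection: `execute_task` now takes a `c_bool *is_application_level_error` out-parameter that starts `False`, is flipped when the user function raises, and the exception is re-raised so the existing failure handling still runs. A minimal Python sketch of that pattern (the `run_task` name and the one-element list standing in for the C `bool *` are illustrative, not part of the patch):

```python
def run_task(func, args, is_application_level_error):
    # A one-element list plays the role of the C bool* out-parameter.
    is_application_level_error[0] = False
    try:
        return func(*args)
    except KeyboardInterrupt:
        raise  # cancellation is not an application-level error
    except Exception:
        is_application_level_error[0] = True  # raised by user code: retryable
        raise


flag = [False]
try:
    run_task(lambda x: 1 / x, (0,), flag)
except ZeroDivisionError:
    pass
assert flag[0]  # the caller can now decide whether to schedule a retry
```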
on_worker_shutdown (CRayStatus() nogil) check_signals diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index a06c05bb5..81a1b5665 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -24,6 +24,7 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0 # Normal tasks may be retried on failure this many times. # TODO(swang): Allow this to be set globally for an application. DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3 +DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS = False logger = logging.getLogger(__name__) @@ -53,6 +54,9 @@ class RemoteFunction: of this remote function. _max_calls: The number of times a worker can execute this function before exiting. + _max_retries: The number of times this task may be retried + on worker failure. + _retry_exceptions: Whether application-level errors should be retried. _runtime_env: The runtime environment for this task. _decorator: An optional decorator that should be applied to the remote function invocation (as opposed to the function execution) before @@ -73,7 +77,7 @@ class RemoteFunction: def __init__(self, language, function, function_descriptor, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, num_returns, max_calls, max_retries, - runtime_env): + retry_exceptions, runtime_env): if inspect.iscoroutinefunction(function): raise ValueError("'async def' should not be used for remote " "tasks. You can wrap the async function with " @@ -100,6 +104,9 @@ class RemoteFunction: if max_calls is None else max_calls) self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES if max_retries is None else max_retries) + self._retry_exceptions = (DEFAULT_REMOTE_FUNCTION_RETRY_EXCEPTIONS + if retry_exceptions is None else + retry_exceptions) self._runtime_env = runtime_env self._decorator = getattr(function, "__ray_invocation_decorator__", None) @@ -131,6 +138,7 @@ class RemoteFunction: accelerator_type=None, resources=None, max_retries=None, + retry_exceptions=None, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, @@ -168,6 +176,7 @@ class RemoteFunction: accelerator_type=accelerator_type, resources=resources, max_retries=max_retries, + retry_exceptions=retry_exceptions, placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( @@ -191,6 +200,7 @@ class RemoteFunction: accelerator_type=None, resources=None, max_retries=None, + retry_exceptions=None, placement_group="default", placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, @@ -211,6 +221,7 @@ class RemoteFunction: accelerator_type=accelerator_type, resources=resources, max_retries=max_retries, + retry_exceptions=retry_exceptions, placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( @@ -251,6 +262,8 @@ class RemoteFunction: num_returns = self._num_returns if max_retries is None: max_retries = self._max_retries + if retry_exceptions is None: + retry_exceptions = self._retry_exceptions if placement_group_capture_child_tasks is None: placement_group_capture_child_tasks = ( @@ -307,6 +320,7 @@ class RemoteFunction: num_returns, resources, max_retries, + retry_exceptions, placement_group.id, placement_group_bundle_index, placement_group_capture_child_tasks, diff --git a/python/ray/tests/test_failure_4.py b/python/ray/tests/test_failure_4.py index 35c86a480..e6f137ea1 100644 --- a/python/ray/tests/test_failure_4.py +++ 
b/python/ray/tests/test_failure_4.py @@ -10,6 +10,68 @@ from ray._private.test_utils import (init_error_pubsub, get_error_message, run_string_as_driver) +def test_retry_system_level_error(ray_start_regular): + @ray.remote + class Counter: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + return self.value + + @ray.remote(max_retries=1) + def func(counter): + count = counter.increment.remote() + if ray.get(count) == 1: + import os + os._exit(0) + else: + return 1 + + counter1 = Counter.remote() + r1 = func.remote(counter1) + assert ray.get(r1) == 1 + + counter2 = Counter.remote() + r2 = func.options(max_retries=0).remote(counter2) + with pytest.raises(ray.exceptions.WorkerCrashedError): + ray.get(r2) + + +def test_retry_application_level_error(ray_start_regular): + @ray.remote + class Counter: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + return self.value + + @ray.remote(max_retries=1, retry_exceptions=True) + def func(counter): + count = counter.increment.remote() + if ray.get(count) == 1: + raise ValueError() + else: + return 2 + + counter1 = Counter.remote() + r1 = func.remote(counter1) + assert ray.get(r1) == 2 + + counter2 = Counter.remote() + r2 = func.options(max_retries=0).remote(counter2) + with pytest.raises(ValueError): + ray.get(r2) + + counter3 = Counter.remote() + r3 = func.options(retry_exceptions=False).remote(counter3) + with pytest.raises(ValueError): + ray.get(r3) + + def test_connect_with_disconnected_node(shutdown_only): config = { "num_heartbeats_timeout": 50, diff --git a/python/ray/util/client/options.py b/python/ray/util/client/options.py index 9fc5c233f..a0d932352 100644 --- a/python/ray/util/client/options.py +++ b/python/ray/util/client/options.py @@ -22,6 +22,7 @@ options = { "max_retries": (int, lambda x: x >= -1, "The keyword 'max_retries' only accepts 0, -1 " "or a positive integer"), + "retry_exceptions": (), "max_concurrency": (), "name": (), "namespace": (), diff --git a/python/ray/worker.py b/python/ray/worker.py index ad9894492..6887031e7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -651,7 +651,7 @@ def init( a raylet, a plasma store, a plasma manager, and some workers. It will also kill these processes when Python exits. If the driver is running on a node in a Ray cluster, using `auto` as the value - tells the driver to detect the the cluster, removing the need to + tells the driver to detect the cluster, removing the need to specify a specific node address. If the environment variable `RAY_ADDRESS` is defined and the address is None or "auto", Ray will set `address` to `RAY_ADDRESS`.
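Distilled from the tests above, the user-facing contract looks like the following. This is a hand-written usage sketch assuming a Ray build that includes this patch; the `Attempts`/`flaky` names are illustrative:

```python
import ray

ray.init()


@ray.remote
class Attempts:
    def __init__(self):
        self.n = 0

    def bump(self):
        self.n += 1
        return self.n


# retry_exceptions=True makes exceptions raised by the task body count
# against max_retries; the default (False) keeps the old fail-fast behavior.
@ray.remote(max_retries=3, retry_exceptions=True)
def flaky(tracker):
    if ray.get(tracker.bump.remote()) == 1:
        raise RuntimeError("transient application error")
    return "ok"


tracker = Attempts.remote()
assert ray.get(flaky.remote(tracker)) == "ok"  # attempt 1 raises, retry succeeds

# As in the tests, the option can also be overridden per invocation.
tracker2 = Attempts.remote()
try:
    ray.get(flaky.options(retry_exceptions=False).remote(tracker2))
except RuntimeError:
    pass  # surfaced immediately, no retry
```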
@@ -1924,7 +1924,8 @@ def make_decorator(num_returns=None, max_restarts=None, max_task_retries=None, runtime_env=None, - worker=None): + worker=None, + retry_exceptions=None): def decorator(function_or_class): if (inspect.isfunction(function_or_class) or is_cython(function_or_class)): @@ -1953,12 +1954,19 @@ def make_decorator(num_returns=None, return ray.remote_function.RemoteFunction( Language.PYTHON, function_or_class, None, num_cpus, num_gpus, memory, object_store_memory, resources, accelerator_type, - num_returns, max_calls, max_retries, runtime_env) + num_returns, max_calls, max_retries, retry_exceptions, + runtime_env) if inspect.isclass(function_or_class): if num_returns is not None: raise TypeError("The keyword 'num_returns' is not " "allowed for actors.") + if max_retries is not None: + raise TypeError("The keyword 'max_retries' is not " + "allowed for actors.") + if retry_exceptions is not None: + raise TypeError("The keyword 'retry_exceptions' is not " + "allowed for actors.") if max_calls is not None: raise TypeError("The keyword 'max_calls' is not " "allowed for actors.") @@ -2082,6 +2090,9 @@ def remote(*args, **kwargs): this actor or task and its children. See :ref:`runtime-environments` for detailed documentation. This API is in beta and may change before becoming stable. + retry_exceptions (bool): Only for *remote functions*. This specifies + whether application-level errors should be retried + up to max_retries times. override_environment_variables (Dict[str, str]): (Deprecated in Ray 1.4.0, will be removed in Ray 1.6--please use the ``env_vars`` field of :ref:`runtime-environments` instead.) This specifies @@ -2102,7 +2113,7 @@ def remote(*args, **kwargs): valid_kwargs = [ "num_returns", "num_cpus", "num_gpus", "memory", "object_store_memory", "resources", "accelerator_type", "max_calls", "max_restarts", - "max_task_retries", "max_retries", "runtime_env" + "max_task_retries", "max_retries", "runtime_env", "retry_exceptions" ] error_string = ("The @ray.remote decorator must be applied either " "with no arguments and no parentheses, for example " @@ -2135,6 +2146,7 @@ def remote(*args, **kwargs): object_store_memory = kwargs.get("object_store_memory") max_retries = kwargs.get("max_retries") runtime_env = kwargs.get("runtime_env") + retry_exceptions = kwargs.get("retry_exceptions") return make_decorator( num_returns=num_returns, @@ -2149,4 +2161,5 @@ def remote(*args, **kwargs): max_task_retries=max_task_retries, max_retries=max_retries, runtime_env=runtime_env, - worker=worker) + worker=worker, + retry_exceptions=retry_exceptions) diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 422983179..04c33e0f2 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -134,6 +134,12 @@ class TaskSpecBuilder { return *this; } + TaskSpecBuilder &SetNormalTaskSpec(int max_retries, bool retry_exceptions) { + message_->set_max_retries(max_retries); + message_->set_retry_exceptions(retry_exceptions); + return *this; + } + /// Set the driver attributes of the task spec. /// See `common.proto` for meaning of the arguments. /// diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 08ce3fae2..81d4edba5 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -385,9 +385,9 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ // Initialize task receivers. 
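The decorator checks above make `max_retries` and `retry_exceptions` task-only keywords (actors keep `max_restarts`/`max_task_retries`). A short illustration of the resulting behavior, again assuming a build with this patch; the class name is arbitrary:

```python
import ray

try:
    # Rejected at decoration time by the new isclass() checks above.
    @ray.remote(retry_exceptions=True)
    class NotATask:
        pass
except TypeError as e:
    print(e)  # The keyword 'retry_exceptions' is not allowed for actors.
```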
if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) { RAY_CHECK(options_.task_execution_callback != nullptr); - auto execute_task = - std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, + std::placeholders::_4, std::placeholders::_5); direct_task_receiver_ = std::make_unique<CoreWorkerDirectTaskReceiver>( worker_context_, task_execution_service_, execute_task, [this] { return local_raylet_client_->TaskDone(); }); @@ -565,6 +565,10 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ }, "CoreWorker.ReconstructObject"); }; + auto push_error_callback = [this](const JobID &job_id, const std::string &type, + const std::string &error_message, double timestamp) { + return PushError(job_id, type, error_message, timestamp); + }; task_manager_.reset(new TaskManager( memory_store_, reference_counter_, /* retry_task_callback= */ @@ -589,7 +593,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ } } }, - check_node_alive_fn, reconstruct_object_callback)); + check_node_alive_fn, reconstruct_object_callback, push_error_callback)); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. This allows us to push errors @@ -1643,7 +1647,7 @@ void CoreWorker::SubmitTask(const RayFunction &function, const std::vector<std::unique_ptr<TaskArg>> &args, const TaskOptions &task_options, std::vector<ObjectID> *return_ids, int max_retries, - BundleID placement_options, + bool retry_exceptions, BundleID placement_options, bool placement_group_capture_child_tasks, const std::string &debugger_breakpoint) { TaskSpecBuilder builder; @@ -1674,6 +1678,7 @@ void CoreWorker::SubmitTask(const RayFunction &function, placement_options, placement_group_capture_child_tasks, debugger_breakpoint, task_options.serialized_runtime_env, override_environment_variables); + builder.SetNormalTaskSpec(max_retries, retry_exceptions); TaskSpecification task_spec = builder.Build(); RAY_LOG(DEBUG) << "Submit task " << task_spec.DebugString(); if (options_.is_local_mode) { @@ -2146,7 +2151,8 @@ Status CoreWorker::AllocateReturnObject(const ObjectID &object_id, Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, const std::shared_ptr<ResourceMappingType> &resource_ids, std::vector<std::shared_ptr<RayObject>> *return_objects, - ReferenceCounter::ReferenceTableProto *borrowed_refs) { + ReferenceCounter::ReferenceTableProto *borrowed_refs, + bool *is_application_level_error) { RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString(); task_queue_length_ -= 1; num_executed_tasks_ += 1; @@ -2212,7 +2218,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, task_type, task_spec.GetName(), func, task_spec.GetRequiredResources().GetResourceMap(), args, arg_reference_ids, return_ids, task_spec.GetDebuggerBreakpoint(), return_objects, - creation_task_exception_pb_bytes); + creation_task_exception_pb_bytes, is_application_level_error); // Get the reference counts for any IDs that we borrowed during this task and // return them to the caller.
This will notify the caller of any IDs that we @@ -2301,7 +2307,9 @@ void CoreWorker::ExecuteTaskLocalMode(const TaskSpecification &task_spec, } auto old_id = GetActorId(); SetActorId(actor_id); - RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs)); + bool is_application_level_error; + RAY_UNUSED(ExecuteTask(task_spec, resource_ids, &return_objects, &borrowed_refs, + &is_application_level_error)); SetActorId(old_id); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 1e6f53cd0..11137265f 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -70,7 +70,8 @@ struct CoreWorkerOptions { const std::vector<ObjectID> &arg_reference_ids, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, std::vector<std::shared_ptr<RayObject>> *results, - std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes)>; + std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb_bytes, + bool *is_application_level_error)>; CoreWorkerOptions() : store_socket(""), @@ -712,7 +713,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void SubmitTask(const RayFunction &function, const std::vector<std::unique_ptr<TaskArg>> &args, const TaskOptions &task_options, std::vector<ObjectID> *return_ids, - int max_retries, BundleID placement_options, + int max_retries, bool retry_exceptions, BundleID placement_options, bool placement_group_capture_child_tasks, const std::string &debugger_breakpoint); @@ -1122,7 +1123,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { Status ExecuteTask(const TaskSpecification &task_spec, const std::shared_ptr<ResourceMappingType> &resource_ids, std::vector<std::shared_ptr<RayObject>> *return_objects, - ReferenceCounter::ReferenceTableProto *borrowed_refs); + ReferenceCounter::ReferenceTableProto *borrowed_refs, + bool *is_application_level_error); /// Execute a local mode task (runs normal ExecuteTask) /// diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 64e0790ce..21aea5cc7 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -100,7 +100,11 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( const std::vector<ObjectID> &arg_reference_ids, const std::vector<ObjectID> &return_ids, const std::string &debugger_breakpoint, std::vector<std::shared_ptr<RayObject>> *results, - std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb) { + std::shared_ptr<LocalMemoryBuffer> &creation_task_exception_pb, + bool *is_application_level_error) { + // TODO(jjyao): Support retrying application-level errors for Java + *is_application_level_error = false; + JNIEnv *env = GetJNIEnv(); RAY_CHECK(java_task_executor); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index a5aed2616..d3c6b9882 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -308,6 +308,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub CoreWorkerProcess::GetCoreWorker().SubmitTask( ray_function, task_args, task_options, &return_ids, /*max_retries=*/0, + /*retry_exceptions=*/false, /*placement_options=*/placement_group_options, /*placement_group_capture_child_tasks=*/true, /*debugger_breakpoint*/ ""); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index e7e762a80..72acdc87f 100644 ---
a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -266,7 +266,7 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, it->second.pending = false; num_pending_tasks_--; - // A finished task can be only be re-executed if it has some number of + // A finished task can only be re-executed if it has some number of // retries left and returned at least one object that is still in use and // stored in plasma. bool task_retryable = it->second.num_retries_left != 0 && @@ -284,6 +284,45 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, ShutdownIfNeeded(); } +bool TaskManager::RetryTaskIfPossible(const TaskID &task_id) { + int num_retries_left = 0; + TaskSpecification spec; + { + absl::MutexLock lock(&mu_); + auto it = submissible_tasks_.find(task_id); + RAY_CHECK(it != submissible_tasks_.end()) + << "Tried to retry task that was not pending " << task_id; + RAY_CHECK(it->second.pending) + << "Tried to retry task that was not pending " << task_id; + spec = it->second.spec; + num_retries_left = it->second.num_retries_left; + if (num_retries_left > 0) { + it->second.num_retries_left--; + } else { + RAY_CHECK(num_retries_left == 0 || num_retries_left == -1); + } + } + + // We should not hold the lock during these calls because they may trigger + // callbacks in this or other classes. + if (num_retries_left != 0) { + auto timestamp = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + std::ostringstream stream; + auto num_retries_left_str = + num_retries_left == -1 ? "infinite" : std::to_string(num_retries_left); + stream << num_retries_left_str << " retries left for task " << spec.TaskId() + << ", attempting to resubmit."; + RAY_CHECK_OK( + push_error_callback_(spec.JobId(), "retry_task", stream.str(), timestamp)); + retry_task_callback_(spec, /*delay=*/true); + return true; + } else { + return false; + } +} + bool TaskManager::PendingTaskFailed( const TaskID &task_id, rpc::ErrorType error_type, Status *status, const std::shared_ptr<rpc::RayException> &creation_task_exception, @@ -292,9 +331,9 @@ bool TaskManager::PendingTaskFailed( // loudly with ERROR here. RAY_LOG(DEBUG) << "Task " << task_id << " failed with error " << rpc::ErrorType_Name(error_type); - int num_retries_left = 0; + const bool will_retry = RetryTaskIfPossible(task_id); + const bool release_lineage = !will_retry; TaskSpecification spec; - bool release_lineage = true; { absl::MutexLock lock(&mu_); auto it = submissible_tasks_.find(task_id); @@ -303,30 +342,13 @@ bool TaskManager::PendingTaskFailed( RAY_CHECK(it->second.pending) << "Tried to complete task that was not pending " << task_id; spec = it->second.spec; - num_retries_left = it->second.num_retries_left; - if (num_retries_left == 0) { + if (!will_retry) { submissible_tasks_.erase(it); num_pending_tasks_--; - } else if (num_retries_left == -1) { - release_lineage = false; - } else { - RAY_CHECK(num_retries_left > 0); - it->second.num_retries_left--; - release_lineage = false; } } - bool will_retry = false; - // We should not hold the lock during these calls because they may trigger - // callbacks in this or other classes. - if (num_retries_left != 0) { - auto retries_str = - num_retries_left == -1 ? "infinite" : std::to_string(num_retries_left); - RAY_LOG(INFO) << retries_str << " retries left for task " << spec.TaskId() - << ", attempting to resubmit."; - retry_task_callback_(spec, /*delay=*/true); - will_retry = true; - } else { + if (!will_retry) { // Throttled logging of task failure errors.
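`RetryTaskIfPossible` centralizes the retry-budget bookkeeping that `PendingTaskFailed` previously did inline: a positive count is decremented per retry, `-1` means retry forever, `0` means the budget is spent, and a `retry_task` error is pushed to the driver before resubmitting. A Python model of that decision, with the lock and error-push plumbing reduced to plain callables (names are illustrative):

```python
def retry_if_possible(task, push_error, resubmit):
    left = task["num_retries_left"]
    if left > 0:
        task["num_retries_left"] -= 1
    elif left == 0:
        return False  # budget spent; the caller marks the task failed
    # left == -1 falls through: unlimited retries, nothing to decrement
    label = "infinite" if left == -1 else str(left)
    push_error(f"{label} retries left for task {task['id']}, attempting to resubmit.")
    resubmit(task)  # corresponds to retry_task_callback_(spec, /*delay=*/true)
    return True


task = {"id": "t1", "num_retries_left": 1}
assert retry_if_possible(task, print, lambda t: None)      # consumes the budget
assert not retry_if_possible(task, print, lambda t: None)  # budget exhausted
```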
{ absl::MutexLock lock(&mu_); diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index e17c9b7b2..986900352 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -31,6 +31,8 @@ class TaskFinisherInterface { virtual void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply, const rpc::Address &actor_addr) = 0; + virtual bool RetryTaskIfPossible(const TaskID &task_id) = 0; + virtual bool PendingTaskFailed( const TaskID &task_id, rpc::ErrorType error_type, Status *status, const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr, @@ -61,6 +63,9 @@ class TaskResubmissionInterface { using RetryTaskCallback = std::function<void(TaskSpecification &spec, bool delay)>; using ReconstructObjectCallback = std::function<void(const ObjectID &object_id)>; +using PushErrorCallback = + std::function<Status(const JobID &job_id, const std::string &type, + const std::string &error_message, double timestamp)>; class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface { public: @@ -68,12 +73,14 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa std::shared_ptr<ReferenceCounter> reference_counter, RetryTaskCallback retry_task_callback, const std::function<bool(const NodeID &node_id)> &check_node_alive, - ReconstructObjectCallback reconstruct_object_callback) + ReconstructObjectCallback reconstruct_object_callback, + PushErrorCallback push_error_callback) : in_memory_store_(in_memory_store), reference_counter_(reference_counter), retry_task_callback_(retry_task_callback), check_node_alive_(check_node_alive), - reconstruct_object_callback_(reconstruct_object_callback) { + reconstruct_object_callback_(reconstruct_object_callback), + push_error_callback_(push_error_callback) { reference_counter_->SetReleaseLineageCallback( [this](const ObjectID &object_id, std::vector<ObjectID> *ids_to_release) { RemoveLineageReference(object_id, ids_to_release); @@ -118,6 +125,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply, const rpc::Address &worker_addr) override; + bool RetryTaskIfPossible(const TaskID &task_id) override; + /// A pending task failed. This will either retry the task or mark the task /// as failed if there are no retries left. /// @@ -266,6 +275,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// recoverable). const ReconstructObjectCallback reconstruct_object_callback_; + // Called to push an error to the relevant driver. + const PushErrorCallback push_error_callback_; + // The number of task failures we have logged total. int64_t num_failure_logs_ GUARDED_BY(mu_) = 0; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 6330f80d6..c6edc7176 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -249,6 +249,7 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map<std::string, double> &res TaskOptions options; std::vector<ObjectID> return_ids; driver.SubmitTask(func, args, options, &return_ids, /*max_retries=*/0, + /*retry_exceptions=*/false, std::make_pair(PlacementGroupID::Nil(), -1), true, /*debugger_breakpoint=*/""); @@ -855,12 +856,14 @@ TEST_F(SingleNodeTest, TestCancelTasks) { // Submit func1. The function should start looping forever. driver.SubmitTask(func1, args, options, &return_ids1, /*max_retries=*/0, + /*retry_exceptions=*/false, std::make_pair(PlacementGroupID::Nil(), -1), true, /*debugger_breakpoint=*/""); ASSERT_EQ(return_ids1.size(), 1); // Submit func2. The function should be queued at the worker indefinitely.
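Note the wiring style in `task_manager.h` above: `TaskManager` receives the new `PushErrorCallback` the same way as its other side effects, as an injected `std::function`, which is what lets the `task_manager_test.cc` change below stub it with a no-op lambda. Sketched in Python with hypothetical stand-in names:

```python
class TaskManagerModel:
    """Stand-in showing the dependency-injected callback style."""

    def __init__(self, retry_task, check_node_alive, reconstruct_object, push_error):
        self.retry_task = retry_task
        self.check_node_alive = check_node_alive
        self.reconstruct_object = reconstruct_object
        self.push_error = push_error


# Production wires push_error to CoreWorker::PushError; tests pass a no-op.
manager = TaskManagerModel(
    retry_task=lambda spec, delay: None,
    check_node_alive=lambda node_id: True,
    reconstruct_object=lambda object_id: None,
    push_error=lambda job_id, error_type, message, timestamp: None,
)
```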
driver.SubmitTask(func2, args, options, &return_ids2, /*max_retries=*/0, + /*retry_exceptions=*/false, std::make_pair(PlacementGroupID::Nil(), -1), true, /*debugger_breakpoint=*/""); ASSERT_EQ(return_ids2.size(), 1); diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 1f367190b..38534ae17 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -91,6 +91,9 @@ class MockTaskFinisher : public TaskFinisherInterface { MOCK_METHOD3(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &, const rpc::Address &addr)); + + MOCK_METHOD1(RetryTaskIfPossible, bool(const TaskID &task_id)); + MOCK_METHOD5(PendingTaskFailed, bool(const TaskID &task_id, rpc::ErrorType error_type, Status *status, const std::shared_ptr<rpc::RayException> &creation_task_exception, diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index facea3f63..b5b24a278 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -64,8 +64,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { return true; } - bool ReplyPushTask(Status status = Status::OK(), bool exit = false, - bool stolen = false) { + bool ReplyPushTask(Status status = Status::OK(), bool exit = false, bool stolen = false, + bool is_application_level_error = false) { if (callbacks.size() == 0) { return false; } @@ -77,6 +77,9 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { if (stolen) { reply.set_task_stolen(true); } + if (is_application_level_error) { + reply.set_is_application_level_error(true); + } callback(status, reply); callbacks.pop_front(); return true; } @@ -101,6 +104,11 @@ class MockTaskFinisher : public TaskFinisherInterface { num_tasks_complete++; } + bool RetryTaskIfPossible(const TaskID &task_id) override { + num_task_retries_attempted++; + return false; + } + bool PendingTaskFailed( const TaskID &task_id, rpc::ErrorType error_type, Status *status, const std::shared_ptr<rpc::RayException> &creation_task_exception = nullptr, @@ -131,6 +139,7 @@ class MockTaskFinisher : public TaskFinisherInterface { int num_tasks_failed = 0; int num_inlined_dependencies = 0; int num_contained_ids = 0; + int num_task_retries_attempted = 0; }; class MockRayletClient : public WorkerLeaseInterface { @@ -441,6 +450,54 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 1); ASSERT_EQ(task_finisher->num_tasks_failed, 0); + ASSERT_EQ(task_finisher->num_task_retries_attempted, 0); + ASSERT_EQ(raylet_client->num_leases_canceled, 0); + ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); + + // Check that there are no entries left in the scheduling_key_entries_ hashmap. These
+ ASSERT_TRUE(submitter.CheckNoSchedulingKeyEntriesPublic()); +} + +TEST(DirectTaskTransportTest, TestRetryTaskApplicationLevelError) { + rpc::Address address; + auto raylet_client = std::make_shared(); + auto worker_client = std::make_shared(); + auto store = std::make_shared(); + auto client_pool = std::make_shared( + [&](const rpc::Address &addr) { return worker_client; }); + auto task_finisher = std::make_shared(); + auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); + TaskSpecification task = BuildEmptyTaskSpec(); + task.GetMutableMessage().set_retry_exceptions(true); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil())); + // Simulate an application-level error. + ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, false, true)); + ASSERT_EQ(raylet_client->num_workers_returned, 1); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 1); + ASSERT_EQ(task_finisher->num_task_retries_attempted, 1); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); + ASSERT_EQ(raylet_client->num_leases_canceled, 0); + ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); + + task.GetMutableMessage().set_retry_exceptions(false); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, NodeID::Nil())); + // Simulate an application-level error. + ASSERT_TRUE(worker_client->ReplyPushTask(Status::OK(), false, false, true)); + ASSERT_EQ(raylet_client->num_workers_returned, 2); + ASSERT_EQ(raylet_client->num_workers_disconnected, 0); + ASSERT_EQ(task_finisher->num_tasks_complete, 2); + ASSERT_EQ(task_finisher->num_task_retries_attempted, 1); + ASSERT_EQ(task_finisher->num_tasks_failed, 0); ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_FALSE(raylet_client->ReplyCancelWorkerLease()); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 6462d5154..e1a75b72b 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -55,7 +55,10 @@ class TaskManagerTest : public ::testing::Test { [this](const NodeID &node_id) { return all_nodes_alive_; }, [this](const ObjectID &object_id) { objects_to_recover_.push_back(object_id); - }) {} + }, + [](const JobID &job_id, const std::string &type, + const std::string &error_message, + double timestamp) { return Status::OK(); }) {} std::shared_ptr store_; std::shared_ptr publisher_; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index c37495802..7be01da15 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -487,8 +487,11 @@ void CoreWorkerDirectTaskReceiver::HandleTask( RAY_CHECK(num_returns >= 0); std::vector> return_objects; - auto status = task_handler_(task_spec, resource_ids, &return_objects, - reply->mutable_borrowed_refs()); + bool is_application_level_error = false; + auto status = + task_handler_(task_spec, resource_ids, &return_objects, + reply->mutable_borrowed_refs(), &is_application_level_error); + reply->set_is_application_level_error(is_application_level_error); bool objects_valid = 
return_objects.size() == num_returns; if (objects_valid) { diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index fd71c8c4c..edc7c18af 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -781,7 +781,8 @@ class CoreWorkerDirectTaskReceiver { std::function<Status(const TaskSpecification &task_spec, const std::shared_ptr<ResourceMappingType> resource_ids, std::vector<std::shared_ptr<RayObject>> *return_objects, - ReferenceCounter::ReferenceTableProto *borrower_refs)>; + ReferenceCounter::ReferenceTableProto *borrower_refs, + bool *is_application_level_error)>; using OnTaskDone = std::function<Status()>; diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 0f57a27dc..e1338f957 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -638,7 +638,11 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( is_actor ? rpc::ErrorType::ACTOR_DIED : rpc::ErrorType::WORKER_DIED, &status)); } else { - task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto()); + if (!task_spec.GetMessage().retry_exceptions() || + !reply.is_application_level_error() || + !task_finisher_->RetryTaskIfPossible(task_id)) { + task_finisher_->CompletePendingTask(task_id, reply, addr.ToProto()); + } } }); } diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 3d092d224..93d4108ec 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -213,6 +213,8 @@ message TaskSpec { string serialized_runtime_env = 24; // The concurrency group name in which this task will be performed. string concurrency_group_name = 25; + // Whether application-level errors (exceptions) should be retried. + bool retry_exceptions = 26; } message Bundle { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 2c58c8ac5..87f5739fd 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -124,6 +124,8 @@ message PushTaskReply { // may now be borrowing. The reference counts also include any new borrowers // that the worker created by passing a borrowed ID into a nested task. repeated ObjectReferenceCount borrowed_refs = 4; + // Whether the result contains an application-level error (exception). + bool is_application_level_error = 5; } message DirectActorCallArgWaitCompleteRequest {
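End to end, the submitter-side change in `PushNormalTask` is a three-way short circuit: a successful RPC reply completes the task unless exception retries are enabled, the reply carries the new `is_application_level_error` flag, and `RetryTaskIfPossible` actually scheduled a retry. A Python model of that condition (dict keys and class names are illustrative):

```python
def on_reply_ok(task_spec, reply, task_finisher):
    # Mirror of the added condition: complete unless a retry was scheduled.
    if (not task_spec["retry_exceptions"]
            or not reply["is_application_level_error"]
            or not task_finisher.retry_if_possible(task_spec["task_id"])):
        task_finisher.complete_pending_task(task_spec["task_id"], reply)


class Finisher:
    def __init__(self, retries_left):
        self.retries_left = retries_left
        self.completed = []

    def retry_if_possible(self, task_id):
        if self.retries_left > 0:
            self.retries_left -= 1
            return True
        return False

    def complete_pending_task(self, task_id, reply):
        self.completed.append(task_id)


f = Finisher(retries_left=1)
spec = {"task_id": "t", "retry_exceptions": True}
on_reply_ok(spec, {"is_application_level_error": True}, f)  # retried, not completed
on_reply_ok(spec, {"is_application_level_error": True}, f)  # budget gone -> completed
assert f.completed == ["t"]
```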