[Core][CoreWorker] Make WorkerContext thread safe, fix race condition. #19343

Why are these changes needed?
The theory around #19270 is there are two create actor requests sent to the same threaded actor due to retry logic. Specifically:

the first request comes and calls CoreWorkerDirectTaskReceiver::HandleTask, it's queued to be executed by thread pool;
then the second request comes and calls CoreWorkerDirectTaskReceiver::HandleTask again, before first request being executed and calls worker_context_.SetCurrentTask;
this fails the current dedupe logic and leads to SetMaxActorConcurrency be called twice, which fails the RAY_CHECK.
In this PR, we fix the dedupe logic by adding SetCurrentActorId and calling it in the task execution thread. this ensures the dedupe logic works for threaded actor.

we also noticed that the WorkerContext is actually not thread safe in threaded actors, thus make it thread safe in this PR as well.

Related issue number
Closes #19270

Checks
 I've run scripts/format.sh to lint the changes in this PR.
 I've included any doc changes needed for https://docs.ray.io/en/master/.
 I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
Testing Strategy
 Unit tests
 Release tests
 This PR is not tested :(
This commit is contained in:
Chen Shen 2021-10-13 16:12:36 -07:00 committed by GitHub
parent b86a5fcb96
commit b8c201b7cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 24 deletions

View file

@ -122,7 +122,8 @@ WorkerContext::WorkerContext(WorkerType worker_type, const WorkerID &worker_id,
current_actor_id_(ActorID::Nil()),
current_actor_placement_group_id_(PlacementGroupID::Nil()),
placement_group_capture_child_tasks_(false),
main_thread_id_(boost::this_thread::get_id()) {
main_thread_id_(boost::this_thread::get_id()),
mutex_() {
// For worker main thread which initializes the WorkerContext,
// set task_id according to whether current worker is a driver.
// (For other threads it's set to random ID via GetThreadContext).
@ -150,6 +151,7 @@ const TaskID &WorkerContext::GetCurrentTaskID() const {
}
const PlacementGroupID &WorkerContext::GetCurrentPlacementGroupId() const {
absl::ReaderMutexLock lock(&mutex_);
// If the worker is an actor, we should return the actor's placement group id.
if (current_actor_id_ != ActorID::Nil()) {
return current_actor_placement_group_id_;
@ -159,6 +161,7 @@ const PlacementGroupID &WorkerContext::GetCurrentPlacementGroupId() const {
}
bool WorkerContext::ShouldCaptureChildTasksInPlacementGroup() const {
absl::ReaderMutexLock lock(&mutex_);
// If the worker is an actor, we should return the actor's placement group id.
if (current_actor_id_ != ActorID::Nil()) {
return placement_group_capture_child_tasks_;
@ -168,6 +171,7 @@ bool WorkerContext::ShouldCaptureChildTasksInPlacementGroup() const {
}
const std::string &WorkerContext::GetCurrentSerializedRuntimeEnv() const {
absl::ReaderMutexLock lock(&mutex_);
return runtime_env_.serialized_runtime_env();
}
@ -175,7 +179,17 @@ void WorkerContext::SetCurrentTaskId(const TaskID &task_id) {
GetThreadContext().SetCurrentTaskId(task_id);
}
void WorkerContext::SetCurrentActorId(const ActorID &actor_id) LOCKS_EXCLUDED(mutex_) {
absl::WriterMutexLock lock(&mutex_);
if (!current_actor_id_.IsNil()) {
RAY_CHECK(current_actor_id_ == actor_id);
return;
}
current_actor_id_ = actor_id;
}
void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
absl::WriterMutexLock lock(&mutex_);
GetThreadContext().SetCurrentTask(task_spec);
RAY_CHECK(current_job_id_ == task_spec.JobId());
if (task_spec.IsNormalTask()) {
@ -185,7 +199,9 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
// never see a new one.
runtime_env_ = task_spec.RuntimeEnv();
} else if (task_spec.IsActorCreationTask()) {
RAY_CHECK(current_actor_id_.IsNil());
if (!current_actor_id_.IsNil()) {
RAY_CHECK(current_actor_id_ == task_spec.ActorCreationId());
}
current_actor_id_ = task_spec.ActorCreationId();
current_actor_is_direct_call_ = true;
current_actor_max_concurrency_ = task_spec.MaxActorConcurrency();
@ -207,7 +223,10 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
return GetThreadContext().GetCurrentTask();
}
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
const ActorID &WorkerContext::GetCurrentActorID() const {
absl::ReaderMutexLock lock(&mutex_);
return current_actor_id_;
}
bool WorkerContext::CurrentThreadIsMain() const {
return boost::this_thread::get_id() == main_thread_id_;
@ -225,20 +244,29 @@ bool WorkerContext::ShouldReleaseResourcesOnBlockingCalls() const {
// TODO(edoakes): simplify these checks now that we only support direct call mode.
bool WorkerContext::CurrentActorIsDirectCall() const {
absl::ReaderMutexLock lock(&mutex_);
return current_actor_is_direct_call_;
}
bool WorkerContext::CurrentTaskIsDirectCall() const {
absl::ReaderMutexLock lock(&mutex_);
return current_task_is_direct_call_ || current_actor_is_direct_call_;
}
int WorkerContext::CurrentActorMaxConcurrency() const {
absl::ReaderMutexLock lock(&mutex_);
return current_actor_max_concurrency_;
}
bool WorkerContext::CurrentActorIsAsync() const { return current_actor_is_asyncio_; }
bool WorkerContext::CurrentActorIsAsync() const {
absl::ReaderMutexLock lock(&mutex_);
return current_actor_is_asyncio_;
}
bool WorkerContext::CurrentActorDetached() const { return is_detached_actor_; }
bool WorkerContext::CurrentActorDetached() const {
absl::ReaderMutexLock lock(&mutex_);
return is_detached_actor_;
}
WorkerThreadContext &WorkerContext::GetThreadContext() {
if (thread_context_ == nullptr) {

View file

@ -16,6 +16,8 @@
#include <boost/thread.hpp>
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/common.h"
@ -36,22 +38,24 @@ class WorkerContext {
const TaskID &GetCurrentTaskID() const;
const PlacementGroupID &GetCurrentPlacementGroupId() const;
const PlacementGroupID &GetCurrentPlacementGroupId() const LOCKS_EXCLUDED(mutex_);
bool ShouldCaptureChildTasksInPlacementGroup() const;
bool ShouldCaptureChildTasksInPlacementGroup() const LOCKS_EXCLUDED(mutex_);
const std::string &GetCurrentSerializedRuntimeEnv() const;
const std::string &GetCurrentSerializedRuntimeEnv() const LOCKS_EXCLUDED(mutex_);
// TODO(edoakes): remove this once Python core worker uses the task interfaces.
void SetCurrentTaskId(const TaskID &task_id);
void SetCurrentTask(const TaskSpecification &task_spec);
void SetCurrentActorId(const ActorID &actor_id) LOCKS_EXCLUDED(mutex_);
void SetCurrentTask(const TaskSpecification &task_spec) LOCKS_EXCLUDED(mutex_);
void ResetCurrentTask();
std::shared_ptr<const TaskSpecification> GetCurrentTask() const;
const ActorID &GetCurrentActorID() const;
const ActorID &GetCurrentActorID() const LOCKS_EXCLUDED(mutex_);
/// Returns whether the current thread is the main worker thread.
bool CurrentThreadIsMain() const;
@ -61,17 +65,17 @@ class WorkerContext {
bool ShouldReleaseResourcesOnBlockingCalls() const;
/// Returns whether we are in a direct call actor.
bool CurrentActorIsDirectCall() const;
bool CurrentActorIsDirectCall() const LOCKS_EXCLUDED(mutex_);
/// Returns whether we are in a direct call task. This encompasses both direct
/// actor and normal tasks.
bool CurrentTaskIsDirectCall() const;
bool CurrentTaskIsDirectCall() const LOCKS_EXCLUDED(mutex_);
int CurrentActorMaxConcurrency() const;
int CurrentActorMaxConcurrency() const LOCKS_EXCLUDED(mutex_);
bool CurrentActorIsAsync() const;
bool CurrentActorIsAsync() const LOCKS_EXCLUDED(mutex_);
bool CurrentActorDetached() const;
bool CurrentActorDetached() const LOCKS_EXCLUDED(mutex_);
uint64_t GetNextTaskIndex();
@ -86,19 +90,21 @@ class WorkerContext {
private:
const WorkerType worker_type_;
const WorkerID worker_id_;
JobID current_job_id_;
ActorID current_actor_id_;
int current_actor_max_concurrency_ = 1;
bool current_actor_is_asyncio_ = false;
bool is_detached_actor_ = false;
const JobID current_job_id_;
ActorID current_actor_id_ GUARDED_BY(mutex_);
int current_actor_max_concurrency_ GUARDED_BY(mutex_) = 1;
bool current_actor_is_asyncio_ GUARDED_BY(mutex_) = false;
bool is_detached_actor_ GUARDED_BY(mutex_) = false;
// The placement group id that the current actor belongs to.
PlacementGroupID current_actor_placement_group_id_;
PlacementGroupID current_actor_placement_group_id_ GUARDED_BY(mutex_);
// Whether or not we should implicitly capture parent's placement group.
bool placement_group_capture_child_tasks_;
bool placement_group_capture_child_tasks_ GUARDED_BY(mutex_);
// The runtime env for the current actor or task.
rpc::RuntimeEnv runtime_env_;
rpc::RuntimeEnv runtime_env_ GUARDED_BY(mutex_);
/// The id of the (main) thread that constructed this worker context.
boost::thread::id main_thread_id_;
const boost::thread::id main_thread_id_;
// To protect access to mutable members;
mutable absl::Mutex mutex_;
private:
static WorkerThreadContext &GetThreadContext();

View file

@ -460,6 +460,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
}
if (task_spec.IsActorCreationTask()) {
worker_context_.SetCurrentActorId(task_spec.ActorCreationId());
SetMaxActorConcurrency(task_spec.IsAsyncioActor(), task_spec.MaxActorConcurrency());
}