Fix asyncio actor race condition (#7335)

This commit is contained in:
Edward Oakes 2020-02-27 10:16:04 -08:00 committed by GitHub
parent 58073f7260
commit 55ccfb6089
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 110 deletions

View file

@ -1,3 +1,4 @@
import asyncio
import logging
import time
@ -300,12 +301,6 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
num_nodes = 5
TIMEOUT_DURATION = 1
# Create a object ID to have the task wait on
WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
# Create a object ID to signal that the task is running
TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
for i in range(num_nodes):
cluster.add_node()
@ -325,29 +320,42 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
# Task to hold the resource till the driver signals to finish
@ray.remote
def wait_func(running_oid, wait_oid):
# Signal that the task is running
ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
# Make the task wait till signalled by driver
ray.get(ray.ObjectID(wait_oid))
def wait_func(running_signal, finish_signal):
# Signal that the task is running.
ray.get(running_signal.send.remote())
# Wait until signaled by driver.
ray.get(finish_signal.wait.remote())
@ray.remote
def test_func():
return 1
@ray.remote(num_cpus=0)
class Signal:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self):
self.ready_event.set()
async def wait(self):
await self.ready_event.wait()
running_signal = Signal.remote()
finish_signal = Signal.remote()
# Launch the task with resource requirement of 4, thus the new available
# capacity becomes 1
task = wait_func._remote(
args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
resources={res_name: 4})
# Wait till wait_func is launched before updating resource
ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
args=[running_signal, finish_signal], resources={res_name: 4})
# Wait until wait_func is launched before updating resource
ray.get(running_signal.wait.remote())
# Update the resource capacity
ray.get(set_res.remote(res_name, updated_capacity, target_node_id))
# Signal task to complete
ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
ray.get(finish_signal.send.remote())
ray.get(task)
# Check if scheduler state is consistent by launching a task requiring
@ -379,12 +387,6 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
num_nodes = 5
TIMEOUT_DURATION = 1
# Create a object ID to have the task wait on
WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
# Create a object ID to signal that the task is running
TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
for i in range(num_nodes):
cluster.add_node()
@ -404,29 +406,42 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
# Task to hold the resource till the driver signals to finish
@ray.remote
def wait_func(running_oid, wait_oid):
# Signal that the task is running
ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
# Make the task wait till signalled by driver
ray.get(ray.ObjectID(wait_oid))
def wait_func(running_signal, finish_signal):
# Signal that the task is running.
ray.get(running_signal.send.remote())
# Wait until signaled by driver.
ray.get(finish_signal.wait.remote())
@ray.remote
def test_func():
return 1
@ray.remote(num_cpus=0)
class Signal:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self):
self.ready_event.set()
async def wait(self):
await self.ready_event.wait()
running_signal = Signal.remote()
finish_signal = Signal.remote()
# Launch the task with resource requirement of 4, thus the new available
# capacity becomes 1
task = wait_func._remote(
args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
resources={res_name: 4})
# Wait till wait_func is launched before updating resource
ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
args=[running_signal, finish_signal], resources={res_name: 4})
# Wait until wait_func is launched before updating resource
ray.get(running_signal.wait.remote())
# Decrease the resource capacity
ray.get(set_res.remote(res_name, updated_capacity, target_node_id))
# Signal task to complete
ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
ray.get(finish_signal.send.remote())
ray.get(task)
# Check if scheduler state is consistent by launching a task requiring
@ -456,12 +471,6 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
num_nodes = 5
TIMEOUT_DURATION = 1
# Create a object ID to have the task wait on
WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
# Create a object ID to signal that the task is running
TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
for i in range(num_nodes):
cluster.add_node()
@ -486,29 +495,42 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
# Task to hold the resource till the driver signals to finish
@ray.remote
def wait_func(running_oid, wait_oid):
# Signal that the task is running
ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
# Make the task wait till signalled by driver
ray.get(ray.ObjectID(wait_oid))
def wait_func(running_signal, finish_signal):
# Signal that the task is running.
ray.get(running_signal.send.remote())
# Wait until signaled by driver.
ray.get(finish_signal.wait.remote())
@ray.remote
def test_func():
return 1
@ray.remote(num_cpus=0)
class Signal:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self):
self.ready_event.set()
async def wait(self):
await self.ready_event.wait()
running_signal = Signal.remote()
finish_signal = Signal.remote()
# Launch the task with resource requirement of 4, thus the new available
# capacity becomes 1
task = wait_func._remote(
args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
resources={res_name: 4})
# Wait till wait_func is launched before updating resource
ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
args=[running_signal, finish_signal], resources={res_name: 4})
# Wait until wait_func is launched before updating resource
ray.get(running_signal.wait.remote())
# Delete the resource
ray.get(delete_res.remote(res_name, target_node_id))
# Signal task to complete
ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
ray.get(finish_signal.send.remote())
ray.get(task)
# Check if scheduler state is consistent by launching a task requiring

View file

@ -23,7 +23,8 @@ class MockWaiter : public DependencyWaiter {
TEST(SchedulingQueueTest, TestInOrder) {
boost::asio::io_service io_service;
MockWaiter waiter;
SchedulingQueue queue(io_service, waiter, nullptr, 0);
WorkerContext context(WorkerType::WORKER, JobID::Nil());
SchedulingQueue queue(io_service, waiter, context, 0);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
@ -43,7 +44,8 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
ObjectID obj3 = ObjectID::FromRandom();
boost::asio::io_service io_service;
MockWaiter waiter;
SchedulingQueue queue(io_service, waiter, nullptr, 0);
WorkerContext context(WorkerType::WORKER, JobID::Nil());
SchedulingQueue queue(io_service, waiter, context, 0);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
@ -68,7 +70,8 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
ObjectID obj1 = ObjectID::FromRandom();
boost::asio::io_service io_service;
MockWaiter waiter;
SchedulingQueue queue(io_service, waiter, nullptr, 0);
WorkerContext context(WorkerType::WORKER, JobID::Nil());
SchedulingQueue queue(io_service, waiter, context, 0);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
@ -85,7 +88,8 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
TEST(SchedulingQueueTest, TestOutOfOrder) {
boost::asio::io_service io_service;
MockWaiter waiter;
SchedulingQueue queue(io_service, waiter, nullptr, 0);
WorkerContext context(WorkerType::WORKER, JobID::Nil());
SchedulingQueue queue(io_service, waiter, context, 0);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
@ -102,7 +106,8 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
boost::asio::io_service io_service;
MockWaiter waiter;
SchedulingQueue queue(io_service, waiter, nullptr, 0);
WorkerContext context(WorkerType::WORKER, JobID::Nil());
SchedulingQueue queue(io_service, waiter, context, 0);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };
@ -124,7 +129,8 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
boost::asio::io_service io_service;
MockWaiter waiter;
SchedulingQueue queue(io_service, waiter, nullptr, 0);
WorkerContext context(WorkerType::WORKER, JobID::Nil());
SchedulingQueue queue(io_service, waiter, context, 0);
int n_ok = 0;
int n_rej = 0;
auto fn_ok = [&n_ok]() { n_ok++; };

View file

@ -178,24 +178,6 @@ void CoreWorkerDirectTaskReceiver::Init(rpc::ClientFactoryFn client_factory,
client_factory_ = client_factory;
}
void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) {
if (max_concurrency != max_concurrency_) {
RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
RAY_CHECK(pool_ == nullptr) << "Cannot change max concurrency at runtime.";
pool_.reset(new BoundedExecutor(max_concurrency));
max_concurrency_ = max_concurrency;
}
}
void CoreWorkerDirectTaskReceiver::SetActorAsAsync(int max_concurrency) {
if (!is_asyncio_) {
RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread.";
fiber_state_.reset(new FiberState(max_concurrency));
max_concurrency_ = max_concurrency;
is_asyncio_ = true;
}
};
void CoreWorkerDirectTaskReceiver::HandlePushTask(
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
@ -208,14 +190,6 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
return;
}
// Only call SetMaxActorConcurrency to configure threadpool size when the
// actor is not async actor. Async actor is single threaded.
if (worker_context_.CurrentActorIsAsync()) {
SetActorAsAsync(worker_context_.CurrentActorMaxConcurrency());
} else {
SetMaxActorConcurrency(worker_context_.CurrentActorMaxConcurrency());
}
std::vector<ObjectID> dependencies;
for (size_t i = 0; i < task_spec.NumArgs(); ++i) {
int count = task_spec.ArgIdCount(i);
@ -325,9 +299,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
auto it = scheduling_queue_.find(task_spec.CallerId());
if (it == scheduling_queue_.end()) {
auto result = scheduling_queue_.emplace(
task_spec.CallerId(),
std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_state_)));
task_spec.CallerId(), std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
task_main_io_service_, *waiter_, worker_context_)));
it = result.first;
}
it->second->Add(request.sequence_number(), request.client_processed_up_to(),

View file

@ -239,17 +239,13 @@ class BoundedExecutor {
class SchedulingQueue {
public:
SchedulingQueue(boost::asio::io_service &main_io_service, DependencyWaiter &waiter,
std::shared_ptr<BoundedExecutor> pool = nullptr,
bool use_asyncio = false,
std::shared_ptr<FiberState> fiber_state = nullptr,
WorkerContext &worker_context,
int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
: wait_timer_(main_io_service),
waiter_(waiter),
reorder_wait_seconds_(reorder_wait_seconds),
main_thread_id_(boost::this_thread::get_id()),
pool_(pool),
use_asyncio_(use_asyncio),
fiber_state_(fiber_state) {}
worker_context_(worker_context) {}
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request, std::function<void()> reject_request,
@ -283,6 +279,24 @@ class SchedulingQueue {
private:
/// Schedules as many requests as possible in sequence.
void ScheduleRequests() {
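// Lazily set up the execution backend from the worker context. An async actor
// is single threaded, so it only gets a FiberState (created once); a non-async
// actor with max_concurrency != 1 gets a BoundedExecutor thread pool (created once).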
int max_concurrency = worker_context_.CurrentActorMaxConcurrency();
if (worker_context_.CurrentActorIsAsync()) {
// If this is an async actor, initialize the fiber state once.
if (!is_asyncio_) {
RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread.";
fiber_state_.reset(new FiberState(max_concurrency));
is_asyncio_ = true;
}
} else {
// If this is a concurrency actor (not async), initialize the thread pool once.
if (max_concurrency != 1 && !pool_) {
RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
pool_.reset(new BoundedExecutor(max_concurrency));
}
}
// Cancel any stale requests that the client doesn't need any longer.
while (!pending_tasks_.empty() && pending_tasks_.begin()->first < next_seq_no_) {
auto head = pending_tasks_.begin();
@ -298,11 +312,14 @@ class SchedulingQueue {
auto head = pending_tasks_.begin();
auto request = head->second;
if (use_asyncio_) {
if (is_asyncio_) {
// Process async actor task.
fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); });
} else if (pool_ != nullptr) {
} else if (pool_) {
// Process concurrent actor task.
pool_->PostBlocking([request]() mutable { request.Accept(); });
} else {
// Process normal actor task.
request.Accept();
}
pending_tasks_.erase(head);
@ -339,6 +356,8 @@ class SchedulingQueue {
}
}
/// Worker context, queried for the current actor's async flag and max concurrency.
WorkerContext &worker_context_;
/// Max time in seconds to wait for dependencies to show up.
const int64_t reorder_wait_seconds_ = 0;
/// Sorted map of (accept, rej) task callbacks keyed by their sequence number.
@ -353,13 +372,13 @@ class SchedulingQueue {
/// Reference to the waiter owned by the task receiver.
DependencyWaiter &waiter_;
/// If concurrent calls are allowed, holds the pool for executing these tasks.
std::shared_ptr<BoundedExecutor> pool_;
std::unique_ptr<BoundedExecutor> pool_;
/// Whether we should enqueue requests into asyncio pool. Setting this to true
/// will instantiate all tasks as fibers that can be yielded.
bool use_asyncio_;
bool is_asyncio_ = false;
/// If is_asyncio_ is true, fiber_state_ contains the running state required
/// to enable continuation and work together with python asyncio.
std::shared_ptr<FiberState> fiber_state_;
std::unique_ptr<FiberState> fiber_state_;
friend class SchedulingQueueTest;
};
@ -403,12 +422,6 @@ class CoreWorkerDirectTaskReceiver {
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Set the max concurrency at runtime. It cannot be changed once set.
void SetMaxActorConcurrency(int max_concurrency);
/// Set the max concurrency and start async actor context.
void SetActorAsAsync(int max_concurrency);
private:
// Worker context.
WorkerContext &worker_context_;
@ -430,18 +443,8 @@ class CoreWorkerDirectTaskReceiver {
/// Queue of pending requests per actor handle.
/// TODO(ekl) GC these queues once the handle is no longer active.
std::unordered_map<TaskID, std::unique_ptr<SchedulingQueue>> scheduling_queue_;
/// The max number of concurrent calls to allow.
int max_concurrency_ = 1;
/// Whether we are shutting down and not running further tasks.
bool exiting_ = false;
/// If concurrent calls are allowed, holds the pool for executing these tasks.
std::shared_ptr<BoundedExecutor> pool_;
/// Whether this actor use asyncio for concurrency.
/// TODO(simon) group all asyncio related fields into a separate struct.
bool is_asyncio_ = false;
/// If use_asyncio_ is true, fiber_state_ contains the running state required
/// to enable continuation and work together with python asyncio.
std::shared_ptr<FiberState> fiber_state_;
};
} // namespace ray