Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
Fix asyncio actor race condition (#7335)
Parent: 58073f7260
Commit: 55ccfb6089
4 changed files with 114 additions and 110 deletions
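For context, the actors affected by this race are Ray asyncio actors: actors whose methods are coroutines and therefore run on a single event loop inside the actor process rather than on a thread pool. The following is a minimal, illustrative sketch of such an actor, not part of this diff; the class name, method name, and the local ray.init() call are assumptions for the example only.

import asyncio

import ray

ray.init()


@ray.remote
class AsyncActor:
    # Because this method is a coroutine, Ray treats the actor as an
    # asyncio actor and runs its calls on one event loop.
    async def work(self, i):
        await asyncio.sleep(0.1)
        return i


actor = AsyncActor.remote()
# Several in-flight calls interleave on the actor's event loop.
print(ray.get([actor.work.remote(i) for i in range(3)]))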
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import time
 
@@ -300,12 +301,6 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
     num_nodes = 5
     TIMEOUT_DURATION = 1
 
-    # Create a object ID to have the task wait on
-    WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
-
-    # Create a object ID to signal that the task is running
-    TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
-
     for i in range(num_nodes):
         cluster.add_node()
 
@@ -325,29 +320,42 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
 
     # Task to hold the resource till the driver signals to finish
     @ray.remote
-    def wait_func(running_oid, wait_oid):
-        # Signal that the task is running
-        ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
-        # Make the task wait till signalled by driver
-        ray.get(ray.ObjectID(wait_oid))
+    def wait_func(running_signal, finish_signal):
+        # Signal that the task is running.
+        ray.get(running_signal.send.remote())
+        # Wait until signaled by driver.
+        ray.get(finish_signal.wait.remote())
 
     @ray.remote
     def test_func():
         return 1
 
+    @ray.remote(num_cpus=0)
+    class Signal:
+        def __init__(self):
+            self.ready_event = asyncio.Event()
+
+        def send(self):
+            self.ready_event.set()
+
+        async def wait(self):
+            await self.ready_event.wait()
+
+    running_signal = Signal.remote()
+    finish_signal = Signal.remote()
+
     # Launch the task with resource requirement of 4, thus the new available
     # capacity becomes 1
     task = wait_func._remote(
-        args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
-        resources={res_name: 4})
-    # Wait till wait_func is launched before updating resource
-    ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
+        args=[running_signal, finish_signal], resources={res_name: 4})
+    # Wait until wait_func is launched before updating resource
+    ray.get(running_signal.wait.remote())
 
     # Update the resource capacity
     ray.get(set_res.remote(res_name, updated_capacity, target_node_id))
 
     # Signal task to complete
-    ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
+    ray.get(finish_signal.send.remote())
     ray.get(task)
 
     # Check if scheduler state is consistent by launching a task requiring
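The tests replace raw object-ID signaling (ray.worker.global_worker.put_object on a hard-coded ObjectID) with a small asyncio-based Signal actor: send() sets an asyncio.Event, and wait() is a coroutine that blocks until the event is set. Pulled out of the test into a self-contained sketch, the pattern looks like this; the Signal class and wait_func mirror the diff, while the surrounding driver script (ray.init(), the plain .remote() call without custom resources) is illustrative only.

import asyncio

import ray

ray.init()


@ray.remote(num_cpus=0)
class Signal:
    """Coordination actor: wait() blocks until some caller invokes send()."""

    def __init__(self):
        self.ready_event = asyncio.Event()

    def send(self):
        self.ready_event.set()

    async def wait(self):
        await self.ready_event.wait()


@ray.remote
def wait_func(running_signal, finish_signal):
    # Tell the driver this task has started, then wait for permission to exit.
    ray.get(running_signal.send.remote())
    ray.get(finish_signal.wait.remote())


running_signal = Signal.remote()
finish_signal = Signal.remote()
task = wait_func.remote(running_signal, finish_signal)

ray.get(running_signal.wait.remote())  # block until the task is running
ray.get(finish_signal.send.remote())   # let the task finish
ray.get(task)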
@@ -379,12 +387,6 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
     num_nodes = 5
     TIMEOUT_DURATION = 1
 
-    # Create a object ID to have the task wait on
-    WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
-
-    # Create a object ID to signal that the task is running
-    TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
-
     for i in range(num_nodes):
         cluster.add_node()
 
@@ -404,29 +406,42 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
 
     # Task to hold the resource till the driver signals to finish
     @ray.remote
-    def wait_func(running_oid, wait_oid):
-        # Signal that the task is running
-        ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
-        # Make the task wait till signalled by driver
-        ray.get(ray.ObjectID(wait_oid))
+    def wait_func(running_signal, finish_signal):
+        # Signal that the task is running.
+        ray.get(running_signal.send.remote())
+        # Wait until signaled by driver.
+        ray.get(finish_signal.wait.remote())
 
     @ray.remote
     def test_func():
         return 1
 
+    @ray.remote(num_cpus=0)
+    class Signal:
+        def __init__(self):
+            self.ready_event = asyncio.Event()
+
+        def send(self):
+            self.ready_event.set()
+
+        async def wait(self):
+            await self.ready_event.wait()
+
+    running_signal = Signal.remote()
+    finish_signal = Signal.remote()
+
     # Launch the task with resource requirement of 4, thus the new available
     # capacity becomes 1
     task = wait_func._remote(
-        args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
-        resources={res_name: 4})
-    # Wait till wait_func is launched before updating resource
-    ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
+        args=[running_signal, finish_signal], resources={res_name: 4})
+    # Wait until wait_func is launched before updating resource
+    ray.get(running_signal.wait.remote())
 
     # Decrease the resource capacity
     ray.get(set_res.remote(res_name, updated_capacity, target_node_id))
 
     # Signal task to complete
-    ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
+    ray.get(finish_signal.send.remote())
     ray.get(task)
 
     # Check if scheduler state is consistent by launching a task requiring
@@ -456,12 +471,6 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
     num_nodes = 5
     TIMEOUT_DURATION = 1
 
-    # Create a object ID to have the task wait on
-    WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii")
-
-    # Create a object ID to signal that the task is running
-    TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii")
-
     for i in range(num_nodes):
         cluster.add_node()
 
@@ -486,29 +495,42 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
 
     # Task to hold the resource till the driver signals to finish
    @ray.remote
-    def wait_func(running_oid, wait_oid):
-        # Signal that the task is running
-        ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid))
-        # Make the task wait till signalled by driver
-        ray.get(ray.ObjectID(wait_oid))
+    def wait_func(running_signal, finish_signal):
+        # Signal that the task is running.
+        ray.get(running_signal.send.remote())
+        # Wait until signaled by driver.
+        ray.get(finish_signal.wait.remote())
 
     @ray.remote
     def test_func():
         return 1
 
+    @ray.remote(num_cpus=0)
+    class Signal:
+        def __init__(self):
+            self.ready_event = asyncio.Event()
+
+        def send(self):
+            self.ready_event.set()
+
+        async def wait(self):
+            await self.ready_event.wait()
+
+    running_signal = Signal.remote()
+    finish_signal = Signal.remote()
+
     # Launch the task with resource requirement of 4, thus the new available
     # capacity becomes 1
     task = wait_func._remote(
-        args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR],
-        resources={res_name: 4})
-    # Wait till wait_func is launched before updating resource
-    ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR))
+        args=[running_signal, finish_signal], resources={res_name: 4})
+    # Wait until wait_func is launched before updating resource
+    ray.get(running_signal.wait.remote())
 
     # Delete the resource
     ray.get(delete_res.remote(res_name, target_node_id))
 
     # Signal task to complete
-    ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR))
+    ray.get(finish_signal.send.remote())
     ray.get(task)
 
     # Check if scheduler state is consistent by launching a task requiring
@@ -23,7 +23,8 @@ class MockWaiter : public DependencyWaiter {
 TEST(SchedulingQueueTest, TestInOrder) {
   boost::asio::io_service io_service;
   MockWaiter waiter;
-  SchedulingQueue queue(io_service, waiter, nullptr, 0);
+  WorkerContext context(WorkerType::WORKER, JobID::Nil());
+  SchedulingQueue queue(io_service, waiter, context, 0);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -43,7 +44,8 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
   ObjectID obj3 = ObjectID::FromRandom();
   boost::asio::io_service io_service;
   MockWaiter waiter;
-  SchedulingQueue queue(io_service, waiter, nullptr, 0);
+  WorkerContext context(WorkerType::WORKER, JobID::Nil());
+  SchedulingQueue queue(io_service, waiter, context, 0);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -68,7 +70,8 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
   ObjectID obj1 = ObjectID::FromRandom();
   boost::asio::io_service io_service;
   MockWaiter waiter;
-  SchedulingQueue queue(io_service, waiter, nullptr, 0);
+  WorkerContext context(WorkerType::WORKER, JobID::Nil());
+  SchedulingQueue queue(io_service, waiter, context, 0);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -85,7 +88,8 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
 TEST(SchedulingQueueTest, TestOutOfOrder) {
   boost::asio::io_service io_service;
   MockWaiter waiter;
-  SchedulingQueue queue(io_service, waiter, nullptr, 0);
+  WorkerContext context(WorkerType::WORKER, JobID::Nil());
+  SchedulingQueue queue(io_service, waiter, context, 0);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -102,7 +106,8 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
 TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
   boost::asio::io_service io_service;
   MockWaiter waiter;
-  SchedulingQueue queue(io_service, waiter, nullptr, 0);
+  WorkerContext context(WorkerType::WORKER, JobID::Nil());
+  SchedulingQueue queue(io_service, waiter, context, 0);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -124,7 +129,8 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
 TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
   boost::asio::io_service io_service;
   MockWaiter waiter;
-  SchedulingQueue queue(io_service, waiter, nullptr, 0);
+  WorkerContext context(WorkerType::WORKER, JobID::Nil());
+  SchedulingQueue queue(io_service, waiter, context, 0);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -178,24 +178,6 @@ void CoreWorkerDirectTaskReceiver::Init(rpc::ClientFactoryFn client_factory,
   client_factory_ = client_factory;
 }
 
-void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) {
-  if (max_concurrency != max_concurrency_) {
-    RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
-    RAY_CHECK(pool_ == nullptr) << "Cannot change max concurrency at runtime.";
-    pool_.reset(new BoundedExecutor(max_concurrency));
-    max_concurrency_ = max_concurrency;
-  }
-}
-
-void CoreWorkerDirectTaskReceiver::SetActorAsAsync(int max_concurrency) {
-  if (!is_asyncio_) {
-    RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread.";
-    fiber_state_.reset(new FiberState(max_concurrency));
-    max_concurrency_ = max_concurrency;
-    is_asyncio_ = true;
-  }
-};
-
 void CoreWorkerDirectTaskReceiver::HandlePushTask(
     const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
     rpc::SendReplyCallback send_reply_callback) {
@@ -208,14 +190,6 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
     return;
   }
 
-  // Only call SetMaxActorConcurrency to configure threadpool size when the
-  // actor is not async actor. Async actor is single threaded.
-  if (worker_context_.CurrentActorIsAsync()) {
-    SetActorAsAsync(worker_context_.CurrentActorMaxConcurrency());
-  } else {
-    SetMaxActorConcurrency(worker_context_.CurrentActorMaxConcurrency());
-  }
-
   std::vector<ObjectID> dependencies;
   for (size_t i = 0; i < task_spec.NumArgs(); ++i) {
     int count = task_spec.ArgIdCount(i);
@@ -325,9 +299,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
   auto it = scheduling_queue_.find(task_spec.CallerId());
   if (it == scheduling_queue_.end()) {
     auto result = scheduling_queue_.emplace(
-        task_spec.CallerId(),
-        std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
-            task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_state_)));
+        task_spec.CallerId(), std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
+                                  task_main_io_service_, *waiter_, worker_context_)));
     it = result.first;
   }
   it->second->Add(request.sequence_number(), request.client_processed_up_to(),
@@ -239,17 +239,13 @@ class BoundedExecutor {
 class SchedulingQueue {
  public:
   SchedulingQueue(boost::asio::io_service &main_io_service, DependencyWaiter &waiter,
-                  std::shared_ptr<BoundedExecutor> pool = nullptr,
-                  bool use_asyncio = false,
-                  std::shared_ptr<FiberState> fiber_state = nullptr,
+                  WorkerContext &worker_context,
                   int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
       : wait_timer_(main_io_service),
         waiter_(waiter),
        reorder_wait_seconds_(reorder_wait_seconds),
         main_thread_id_(boost::this_thread::get_id()),
-        pool_(pool),
-        use_asyncio_(use_asyncio),
-        fiber_state_(fiber_state) {}
+        worker_context_(worker_context) {}
 
   void Add(int64_t seq_no, int64_t client_processed_up_to,
            std::function<void()> accept_request, std::function<void()> reject_request,
@@ -283,6 +279,24 @@ class SchedulingQueue {
  private:
   /// Schedules as many requests as possible in sequence.
   void ScheduleRequests() {
+    // Only call SetMaxActorConcurrency to configure threadpool size when the
+    // actor is not async actor. Async actor is single threaded.
+    int max_concurrency = worker_context_.CurrentActorMaxConcurrency();
+    if (worker_context_.CurrentActorIsAsync()) {
+      // If this is an async actor, initialize the fiber state once.
+      if (!is_asyncio_) {
+        RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread.";
+        fiber_state_.reset(new FiberState(max_concurrency));
+        is_asyncio_ = true;
+      }
+    } else {
+      // If this is a concurrency actor (not async), initialize the thread pool once.
+      if (max_concurrency != 1 && !pool_) {
+        RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
+        pool_.reset(new BoundedExecutor(max_concurrency));
+      }
+    }
+
     // Cancel any stale requests that the client doesn't need any longer.
     while (!pending_tasks_.empty() && pending_tasks_.begin()->first < next_seq_no_) {
       auto head = pending_tasks_.begin();
@@ -298,11 +312,14 @@ class SchedulingQueue {
       auto head = pending_tasks_.begin();
       auto request = head->second;
 
-      if (use_asyncio_) {
+      if (is_asyncio_) {
+        // Process async actor task.
         fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); });
-      } else if (pool_ != nullptr) {
+      } else if (pool_) {
+        // Process concurrent actor task.
         pool_->PostBlocking([request]() mutable { request.Accept(); });
       } else {
+        // Process normal actor task.
         request.Accept();
       }
       pending_tasks_.erase(head);
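The two hunks above are the heart of the fix: the per-actor execution backend (fiber state for asyncio actors, a bounded thread pool when max_concurrency > 1) is now created lazily inside SchedulingQueue::ScheduleRequests, driven by the WorkerContext, instead of being configured on the receiver in HandlePushTask and captured when the queue is constructed. The following rough Python model illustrates that decision flow only; the class, its fields, and the dict-based context are made-up stand-ins, not Ray APIs.

from concurrent.futures import ThreadPoolExecutor


class SchedulingQueueModel:
    """Toy model of SchedulingQueue's lazy backend selection and dispatch."""

    def __init__(self, worker_context):
        self.worker_context = worker_context  # queried lazily, like worker_context_
        self.pool = None            # like pool_ (thread pool for concurrent actors)
        self.is_asyncio = False     # like is_asyncio_
        self.async_backend = None   # stand-in for fiber_state_

    def schedule_requests(self, requests):
        # The backend is chosen on first dispatch, mirroring ScheduleRequests().
        max_concurrency = self.worker_context["max_concurrency"]
        if self.worker_context["is_async"]:
            if not self.is_asyncio:
                self.async_backend = []  # a real impl would start an event loop/fibers
                self.is_asyncio = True
        elif max_concurrency != 1 and self.pool is None:
            self.pool = ThreadPoolExecutor(max_workers=max_concurrency)

        for accept in requests:
            if self.is_asyncio:
                self.async_backend.append(accept)  # enqueue, like EnqueueFiber
            elif self.pool is not None:
                self.pool.submit(accept)           # like PostBlocking
            else:
                accept()                           # plain actor: run inline


# Example: a "concurrent" actor context creates its pool on first use.
queue = SchedulingQueueModel({"is_async": False, "max_concurrency": 2})
queue.schedule_requests([lambda: print("task 1"), lambda: print("task 2")])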
@@ -339,6 +356,8 @@ class SchedulingQueue {
     }
   }
 
+  // Worker context.
+  WorkerContext &worker_context_;
   /// Max time in seconds to wait for dependencies to show up.
   const int64_t reorder_wait_seconds_ = 0;
   /// Sorted map of (accept, rej) task callbacks keyed by their sequence number.
@@ -353,13 +372,13 @@ class SchedulingQueue {
   /// Reference to the waiter owned by the task receiver.
   DependencyWaiter &waiter_;
   /// If concurrent calls are allowed, holds the pool for executing these tasks.
-  std::shared_ptr<BoundedExecutor> pool_;
+  std::unique_ptr<BoundedExecutor> pool_;
   /// Whether we should enqueue requests into asyncio pool. Setting this to true
   /// will instantiate all tasks as fibers that can be yielded.
-  bool use_asyncio_;
+  bool is_asyncio_ = false;
   /// If use_asyncio_ is true, fiber_state_ contains the running state required
   /// to enable continuation and work together with python asyncio.
-  std::shared_ptr<FiberState> fiber_state_;
+  std::unique_ptr<FiberState> fiber_state_;
 
   friend class SchedulingQueueTest;
 };
@@ -403,12 +422,6 @@ class CoreWorkerDirectTaskReceiver {
                                        rpc::DirectActorCallArgWaitCompleteReply *reply,
                                        rpc::SendReplyCallback send_reply_callback);
 
-  /// Set the max concurrency at runtime. It cannot be changed once set.
-  void SetMaxActorConcurrency(int max_concurrency);
-
-  /// Set the max concurrency and start async actor context.
-  void SetActorAsAsync(int max_concurrency);
-
  private:
   // Worker context.
   WorkerContext &worker_context_;
@@ -430,18 +443,8 @@ class CoreWorkerDirectTaskReceiver {
   /// Queue of pending requests per actor handle.
   /// TODO(ekl) GC these queues once the handle is no longer active.
   std::unordered_map<TaskID, std::unique_ptr<SchedulingQueue>> scheduling_queue_;
-  /// The max number of concurrent calls to allow.
-  int max_concurrency_ = 1;
   /// Whether we are shutting down and not running further tasks.
   bool exiting_ = false;
-  /// If concurrent calls are allowed, holds the pool for executing these tasks.
-  std::shared_ptr<BoundedExecutor> pool_;
-  /// Whether this actor use asyncio for concurrency.
-  /// TODO(simon) group all asyncio related fields into a separate struct.
-  bool is_asyncio_ = false;
-  /// If use_asyncio_ is true, fiber_state_ contains the running state required
-  /// to enable continuation and work together with python asyncio.
-  std::shared_ptr<FiberState> fiber_state_;
 };
 
 }  // namespace ray