Fix concurrent actor starting too many threads. (#14927)
parent 12b4560afa
commit 3e1a0439b7
4 changed files with 91 additions and 41 deletions
@@ -561,6 +561,39 @@ def test_actor_concurrent(ray_start_regular_shared):
     assert r1 == r2 == r3
 
 
+def test_actor_max_concurrency(ray_start_regular_shared):
+    """
+    Test that an actor with max_concurrency=N runs at most
+    N tasks concurrently.
+    """
+    CONCURRENCY = 3
+
+    @ray.remote
+    class ConcurrentActor:
+        def __init__(self):
+            self.threads = set()
+
+        def call(self):
+            # Record the thread that runs this method.
+            self.threads.add(threading.current_thread())
+
+        def get_num_threads(self):
+            return len(self.threads)
+
+    @ray.remote
+    def call(actor):
+        for _ in range(CONCURRENCY * 100):
+            ray.get(actor.call.remote())
+        return
+
+    actor = ConcurrentActor.options(max_concurrency=CONCURRENCY).remote()
+    # Start many concurrent tasks that each call the actor many times.
+    ray.get([call.remote(actor) for _ in range(CONCURRENCY * 10)])
+
+    # The actor should have used at most CONCURRENCY threads.
+    assert ray.get(actor.get_num_threads.remote()) <= CONCURRENCY
+
+
 def test_wait(ray_start_regular_shared):
     @ray.remote
     def f(delay):
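
For context, the user-facing knob this new test exercises is max_concurrency, set through .options() on the actor class. A minimal usage sketch, not part of the diff (assumes a local Ray session):

import time

import ray

ray.init()


@ray.remote
class Worker:
    def work(self):
        time.sleep(0.1)
        return "done"


# At most two work() calls execute concurrently inside the actor;
# the remaining calls queue until a slot frees up.
worker = Worker.options(max_concurrency=2).remote()
print(ray.get([worker.work.remote() for _ in range(8)]))
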
@@ -39,8 +39,7 @@ class MockWaiter : public DependencyWaiter {
 TEST(SchedulingQueueTest, TestInOrder) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -60,8 +59,7 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
   ObjectID obj3 = ObjectID::FromRandom();
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -86,8 +84,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
   ObjectID obj1 = ObjectID::FromRandom();
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -104,8 +101,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
 TEST(SchedulingQueueTest, TestOutOfOrder) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -122,8 +118,7 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
 TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -145,8 +140,7 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
 TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -437,6 +437,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
     return;
   }
 
+  if (task_spec.IsActorCreationTask()) {
+    SetMaxActorConcurrency(task_spec.IsAsyncioActor(), task_spec.MaxActorConcurrency());
+  }
+
   // Only assign resources for non-actor tasks. Actor tasks inherit the resources
   // assigned at initial actor creation time.
   std::shared_ptr<ResourceMappingType> resource_ids;
@@ -530,7 +534,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
     auto result = actor_scheduling_queues_.emplace(
         task_spec.CallerWorkerId(),
         std::unique_ptr<SchedulingQueue>(new ActorSchedulingQueue(
-            task_main_io_service_, *waiter_, worker_context_)));
+            task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_state_)));
     it = result.first;
   }
 
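
This hunk is the heart of the fix. Scheduling queues are keyed by CallerWorkerId, so one actor serves a separate queue per caller; previously each queue lazily built its own thread pool of size max_concurrency (see the code removed from ScheduleRequests further below), so an actor reached from K callers could start roughly K x max_concurrency threads. Now the receiver owns a single pool_/fiber_state_ and injects it into every queue. A rough Python sketch of the corrected ownership (class names are illustrative, not Ray's):

from concurrent.futures import ThreadPoolExecutor


class SchedulingQueue:
    """One queue per caller; the executor is injected, not owned."""

    def __init__(self, pool):
        self.pool = pool  # shared with every other queue

    def submit(self, task):
        self.pool.submit(task)


class TaskReceiver:
    """One receiver per actor process, holding the single shared pool."""

    def __init__(self, max_concurrency):
        # The fix in one line: the pool is created once, here,
        # instead of once per caller-specific queue.
        self.pool = ThreadPoolExecutor(max_workers=max_concurrency)
        self.queues = {}  # caller_worker_id -> SchedulingQueue

    def handle_task(self, caller_id, task):
        if caller_id not in self.queues:
            self.queues[caller_id] = SchedulingQueue(self.pool)
        self.queues[caller_id].submit(task)
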
@@ -563,4 +567,25 @@ bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) {
   return normal_scheduling_queue_->CancelTaskIfFound(task_id);
 }
 
+void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(bool is_asyncio,
+                                                          int max_concurrency) {
+  RAY_CHECK(max_concurrency_ == 0)
+      << "SetMaxActorConcurrency should only be called at most once.";
+  RAY_CHECK(fiber_state_ == nullptr);
+  RAY_CHECK(pool_ == nullptr);
+  RAY_CHECK(max_concurrency >= 1);
+  if (max_concurrency > 1) {
+    max_concurrency_ = max_concurrency;
+    is_asyncio_ = is_asyncio;
+    if (is_asyncio_) {
+      RAY_LOG(INFO) << "Setting actor as async with max_concurrency=" << max_concurrency
+                    << ", creating new fiber thread.";
+      fiber_state_.reset(new FiberState(max_concurrency));
+    } else {
+      RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
+      pool_.reset(new BoundedExecutor(max_concurrency));
+    }
+  }
+}
+
 } // namespace ray
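
SetMaxActorConcurrency runs once, from the actor creation task, and picks the backend: a FiberState for asyncio actors, a BoundedExecutor thread pool for threaded actors, and nothing for the default max_concurrency=1 path. A hedged Python analogue of the same set-once initialization (names invented for illustration, not Ray's API):

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor


class ConcurrencyBackends:
    """Illustrative analogue of SetMaxActorConcurrency."""

    def __init__(self):
        self.max_concurrency = 0  # 0 means "not configured yet"
        self.pool = None          # stands in for BoundedExecutor
        self.loop = None          # stands in for FiberState

    def set_max_concurrency(self, is_asyncio, max_concurrency):
        # Mirrors the RAY_CHECKs: this may run at most once, during
        # the actor creation task, before any actor task executes.
        assert self.max_concurrency == 0, "configured twice"
        assert self.pool is None and self.loop is None
        assert max_concurrency >= 1
        if max_concurrency > 1:
            self.max_concurrency = max_concurrency
            if is_asyncio:
                # Async actors multiplex coroutines on one event-loop
                # thread, roughly what FiberState does with fibers.
                self.loop = asyncio.new_event_loop()
                threading.Thread(target=self.loop.run_forever,
                                 daemon=True).start()
            else:
                # Threaded actors get a single bounded pool.
                self.pool = ThreadPoolExecutor(max_workers=max_concurrency)
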
@@ -386,13 +386,17 @@ class SchedulingQueue {
 class ActorSchedulingQueue : public SchedulingQueue {
  public:
   ActorSchedulingQueue(instrumented_io_context &main_io_service, DependencyWaiter &waiter,
-                       WorkerContext &worker_context,
+                       std::shared_ptr<BoundedExecutor> pool = nullptr,
+                       bool is_asyncio = false,
+                       std::shared_ptr<FiberState> fiber_state = nullptr,
                        int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
-      : worker_context_(worker_context),
-        reorder_wait_seconds_(reorder_wait_seconds),
+      : reorder_wait_seconds_(reorder_wait_seconds),
         wait_timer_(main_io_service),
         main_thread_id_(boost::this_thread::get_id()),
-        waiter_(waiter) {}
+        waiter_(waiter),
+        pool_(pool),
+        is_asyncio_(is_asyncio),
+        fiber_state_(fiber_state) {}
 
   bool TaskQueueEmpty() const { return pending_actor_tasks_.empty(); }
 
@@ -437,24 +441,6 @@ class ActorSchedulingQueue : public SchedulingQueue {
 
   /// Schedules as many requests as possible in sequence.
   void ScheduleRequests() {
-    // Only call SetMaxActorConcurrency to configure threadpool size when the
-    // actor is not async actor. Async actor is single threaded.
-    int max_concurrency = worker_context_.CurrentActorMaxConcurrency();
-    if (worker_context_.CurrentActorIsAsync()) {
-      // If this is an async actor, initialize the fiber state once.
-      if (!is_asyncio_) {
-        RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread.";
-        fiber_state_.reset(new FiberState(max_concurrency));
-        is_asyncio_ = true;
-      }
-    } else {
-      // If this is a concurrency actor (not async), initialize the thread pool once.
-      if (max_concurrency != 1 && !pool_) {
-        RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
-        pool_.reset(new BoundedExecutor(max_concurrency));
-      }
-    }
-
     // Cancel any stale requests that the client doesn't need any longer.
     while (!pending_actor_tasks_.empty() &&
            pending_actor_tasks_.begin()->first < next_seq_no_) {
@@ -475,7 +461,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
     if (is_asyncio_) {
       // Process async actor task.
       fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); });
-    } else if (pool_) {
+    } else if (pool_ != nullptr) {
       // Process concurrent actor task.
       pool_->PostBlocking([request]() mutable { request.Accept(); });
     } else {
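
With the backends injected up front, accepting a request is a plain three-way dispatch. A compact Python rendering of the accept path above (the enqueue/submit method names are illustrative):

def accept(request, is_asyncio=False, fiber_state=None, pool=None):
    """Mirrors the accept path: fiber, shared pool, or inline."""
    if is_asyncio:
        # Async actor: run as a cooperatively scheduled fiber/coroutine.
        fiber_state.enqueue(request)
    elif pool is not None:
        # Threaded actor: run on the shared bounded pool.
        pool.submit(request)
    else:
        # Plain actor: execute inline on the main task thread.
        request()
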
@@ -518,8 +504,6 @@ class ActorSchedulingQueue : public SchedulingQueue {
     }
   }
 
-  // Worker context.
-  WorkerContext &worker_context_;
   /// Max time in seconds to wait for dependencies to show up.
   const int64_t reorder_wait_seconds_ = 0;
   /// Sorted map of (accept, rej) task callbacks keyed by their sequence number.
@@ -534,13 +518,13 @@ class ActorSchedulingQueue : public SchedulingQueue {
   /// Reference to the waiter owned by the task receiver.
   DependencyWaiter &waiter_;
   /// If concurrent calls are allowed, holds the pool for executing these tasks.
-  std::unique_ptr<BoundedExecutor> pool_;
+  std::shared_ptr<BoundedExecutor> pool_;
   /// Whether we should enqueue requests into asyncio pool. Setting this to true
   /// will instantiate all tasks as fibers that can be yielded.
   bool is_asyncio_ = false;
-  /// If use_asyncio_ is true, fiber_state_ contains the running state required
+  /// If is_asyncio_ is true, fiber_state_ contains the running state required
   /// to enable continuation and work together with python asyncio.
-  std::unique_ptr<FiberState> fiber_state_;
+  std::shared_ptr<FiberState> fiber_state_;
 
   friend class SchedulingQueueTest;
 };
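
pool_ and fiber_state_ switch from unique_ptr to shared_ptr because ownership moves to CoreWorkerDirectTaskReceiver while every per-caller queue keeps a reference. BoundedExecutor's internals are not part of this diff; judging from the PostBlocking call above, it behaves like a pool that blocks the submitter once max_concurrency tasks are in flight. A plausible Python equivalent under that assumption:

import threading
from concurrent.futures import ThreadPoolExecutor


class BoundedPool:
    """Sketch of a BoundedExecutor-like pool: at most `limit` tasks
    are in flight at once; extra submitters block until a slot frees."""

    def __init__(self, limit):
        self._pool = ThreadPoolExecutor(max_workers=limit)
        self._slots = threading.Semaphore(limit)

    def post_blocking(self, fn):
        self._slots.acquire()  # blocks while `limit` tasks are in flight

        def run():
            try:
                fn()
            finally:
                self._slots.release()

        return self._pool.submit(run)
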
@@ -667,6 +651,20 @@ class CoreWorkerDirectTaskReceiver {
   // Queue of pending normal (non-actor) tasks.
   std::unique_ptr<SchedulingQueue> normal_scheduling_queue_ =
       std::unique_ptr<SchedulingQueue>(new NormalSchedulingQueue());
+  /// The max number of concurrent calls to allow.
+  /// 0 indicates that the value is not set yet.
+  int max_concurrency_ = 0;
+  /// If concurrent calls are allowed, holds the pool for executing these tasks.
+  std::shared_ptr<BoundedExecutor> pool_;
+  /// Whether this actor uses asyncio for concurrency.
+  bool is_asyncio_ = false;
+  /// If is_asyncio_ is true, fiber_state_ contains the running state required
+  /// to enable continuation and work together with python asyncio.
+  std::shared_ptr<FiberState> fiber_state_;
+
+  /// Set the max concurrency of an actor.
+  /// This should be called once for the actor creation task.
+  void SetMaxActorConcurrency(bool is_asyncio, int max_concurrency);
 };
 
 } // namespace ray