Fix concurrent actor starting too many threads. (#14927)

Hao Chen, 2021-04-01 19:58:18 +08:00, committed by GitHub
parent 12b4560afa
commit 3e1a0439b7
4 changed files with 91 additions and 41 deletions
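
Why this fixes the bug: CoreWorkerDirectTaskReceiver keeps one ActorSchedulingQueue per caller worker (keyed by CallerWorkerId(), see below), and each queue previously created its own BoundedExecutor or FiberState lazily inside ScheduleRequests(). An actor called from N distinct workers could therefore end up with N pools of max_concurrency threads each. After this change the executor is created exactly once, in SetMaxActorConcurrency() while the actor-creation task is handled, and shared with every per-caller queue via shared_ptr.

For reference, a minimal sketch of the user-facing setting involved (assumes a running Ray cluster; the actor and method names are illustrative, not from this commit):

    import threading
    import ray

    ray.init()

    @ray.remote
    class Pinger:
        def __init__(self):
            self.seen = set()

        def ping(self):
            # With max_concurrency=4, the receiver should create exactly one
            # 4-thread pool for this actor, so at most four distinct thread
            # ids can ever be recorded here.
            self.seen.add(threading.get_ident())
            return len(self.seen)

    pinger = Pinger.options(max_concurrency=4).remote()
    assert max(ray.get([pinger.ping.remote() for _ in range(100)])) <= 4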

@@ -561,6 +561,39 @@ def test_actor_concurrent(ray_start_regular_shared):
     assert r1 == r2 == r3


+def test_actor_max_concurrency(ray_start_regular_shared):
+    """
+    Test that an actor with max_concurrency=N runs at most
+    N tasks concurrently.
+    """
+    CONCURRENCY = 3
+
+    @ray.remote
+    class ConcurrentActor:
+        def __init__(self):
+            self.threads = set()
+
+        def call(self):
+            # Record the current thread that runs this function.
+            self.threads.add(threading.current_thread())
+
+        def get_num_threads(self):
+            return len(self.threads)
+
+    @ray.remote
+    def call(actor):
+        for _ in range(CONCURRENCY * 100):
+            ray.get(actor.call.remote())
+        return
+
+    actor = ConcurrentActor.options(max_concurrency=CONCURRENCY).remote()
+
+    # Start many concurrent tasks that will call the actor many times.
+    ray.get([call.remote(actor) for _ in range(CONCURRENCY * 10)])
+
+    # Check that the number of threads is not greater than CONCURRENCY.
+    assert ray.get(actor.get_num_threads.remote()) <= CONCURRENCY
+
+
 def test_wait(ray_start_regular_shared):
     @ray.remote
     def f(delay):
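
Note on the test's shape: routing the calls through CONCURRENCY * 10 intermediate tasks matters because each concurrently running top-level task typically occupies its own worker process, so the actor is reached through many distinct per-caller scheduling queues, which is exactly the pattern that used to multiply thread pools. A quick illustrative check (not part of the test suite; assumes ray.init() has run) that concurrent tasks land on distinct workers:

    import os
    import time
    import ray

    @ray.remote
    def worker_pid():
        # Stay alive long enough that the tasks overlap and cannot
        # all share a single worker process.
        time.sleep(0.5)
        return os.getpid()

    pids = ray.get([worker_pid.remote() for _ in range(5)])
    print(len(set(pids)), "distinct caller worker processes")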


@@ -39,8 +39,7 @@ class MockWaiter : public DependencyWaiter {
 TEST(SchedulingQueueTest, TestInOrder) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -60,8 +59,7 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
   ObjectID obj3 = ObjectID::FromRandom();
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -86,8 +84,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
   ObjectID obj1 = ObjectID::FromRandom();
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -104,8 +101,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
 TEST(SchedulingQueueTest, TestOutOfOrder) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -122,8 +118,7 @@ TEST(SchedulingQueueTest, TestOutOfOrder) {
 TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };
@@ -145,8 +140,7 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) {
 TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
   instrumented_io_context io_service;
   MockWaiter waiter;
-  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil());
-  ActorSchedulingQueue queue(io_service, waiter, context);
+  ActorSchedulingQueue queue(io_service, waiter);
   int n_ok = 0;
   int n_rej = 0;
   auto fn_ok = [&n_ok]() { n_ok++; };


@@ -437,6 +437,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
     return;
   }

+  if (task_spec.IsActorCreationTask()) {
+    SetMaxActorConcurrency(task_spec.IsAsyncioActor(), task_spec.MaxActorConcurrency());
+  }
+
   // Only assign resources for non-actor tasks. Actor tasks inherit the resources
   // assigned at initial actor creation time.
   std::shared_ptr<ResourceMappingType> resource_ids;
@@ -530,7 +534,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
       auto result = actor_scheduling_queues_.emplace(
           task_spec.CallerWorkerId(),
           std::unique_ptr<SchedulingQueue>(new ActorSchedulingQueue(
-              task_main_io_service_, *waiter_, worker_context_)));
+              task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_state_)));
       it = result.first;
     }
@@ -563,4 +567,25 @@ bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) {
   return normal_scheduling_queue_->CancelTaskIfFound(task_id);
 }

+void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(bool is_asyncio,
+                                                          int max_concurrency) {
+  RAY_CHECK(max_concurrency_ == 0)
+      << "SetMaxActorConcurrency should be called at most once.";
+  RAY_CHECK(fiber_state_ == nullptr);
+  RAY_CHECK(pool_ == nullptr);
+  RAY_CHECK(max_concurrency >= 1);
+  if (max_concurrency > 1) {
+    max_concurrency_ = max_concurrency;
+    is_asyncio_ = is_asyncio;
+    if (is_asyncio_) {
+      RAY_LOG(INFO) << "Setting actor as async with max_concurrency=" << max_concurrency
+                    << ", creating new fiber thread.";
+      fiber_state_.reset(new FiberState(max_concurrency));
+    } else {
+      RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
+      pool_.reset(new BoundedExecutor(max_concurrency));
+    }
+  }
+}
+
 } // namespace ray
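
The RAY_CHECKs above guarantee the executor is built at most once per actor process rather than once per caller queue. For async actors the same setting bounds how many coroutines run at once instead of how many threads exist; a sketch of that path from the Python side (illustrative names, assumes Ray is initialized):

    import asyncio
    import ray

    @ray.remote
    class AsyncWorker:
        # An async method makes this an asyncio actor, so the receiver
        # creates a FiberState here instead of a BoundedExecutor.
        async def work(self, i):
            await asyncio.sleep(0.1)
            return i

    worker = AsyncWorker.options(max_concurrency=2).remote()
    # At most two work() calls are in flight inside the actor at a time.
    print(ray.get([worker.work.remote(i) for i in range(6)]))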


@@ -386,13 +386,17 @@ class SchedulingQueue {
 class ActorSchedulingQueue : public SchedulingQueue {
  public:
   ActorSchedulingQueue(instrumented_io_context &main_io_service, DependencyWaiter &waiter,
-                       WorkerContext &worker_context,
+                       std::shared_ptr<BoundedExecutor> pool = nullptr,
+                       bool is_asyncio = false,
+                       std::shared_ptr<FiberState> fiber_state = nullptr,
                        int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
-      : worker_context_(worker_context),
-        reorder_wait_seconds_(reorder_wait_seconds),
+      : reorder_wait_seconds_(reorder_wait_seconds),
         wait_timer_(main_io_service),
         main_thread_id_(boost::this_thread::get_id()),
-        waiter_(waiter) {}
+        waiter_(waiter),
+        pool_(pool),
+        is_asyncio_(is_asyncio),
+        fiber_state_(fiber_state) {}

   bool TaskQueueEmpty() const { return pending_actor_tasks_.empty(); }

@@ -437,24 +441,6 @@ class ActorSchedulingQueue : public SchedulingQueue {
   /// Schedules as many requests as possible in sequence.
   void ScheduleRequests() {
-    // Only call SetMaxActorConcurrency to configure threadpool size when the
-    // actor is not async actor. Async actor is single threaded.
-    int max_concurrency = worker_context_.CurrentActorMaxConcurrency();
-    if (worker_context_.CurrentActorIsAsync()) {
-      // If this is an async actor, initialize the fiber state once.
-      if (!is_asyncio_) {
-        RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread.";
-        fiber_state_.reset(new FiberState(max_concurrency));
-        is_asyncio_ = true;
-      }
-    } else {
-      // If this is a concurrency actor (not async), initialize the thread pool once.
-      if (max_concurrency != 1 && !pool_) {
-        RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
-        pool_.reset(new BoundedExecutor(max_concurrency));
-      }
-    }
     // Cancel any stale requests that the client doesn't need any longer.
     while (!pending_actor_tasks_.empty() &&
            pending_actor_tasks_.begin()->first < next_seq_no_) {
@@ -475,7 +461,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
     if (is_asyncio_) {
       // Process async actor task.
       fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); });
-    } else if (pool_) {
+    } else if (pool_ != nullptr) {
       // Process concurrent actor task.
       pool_->PostBlocking([request]() mutable { request.Accept(); });
     } else {
@@ -518,8 +504,6 @@ class ActorSchedulingQueue : public SchedulingQueue {
     }
   }

-  // Worker context.
-  WorkerContext &worker_context_;
   /// Max time in seconds to wait for dependencies to show up.
   const int64_t reorder_wait_seconds_ = 0;
   /// Sorted map of (accept, rej) task callbacks keyed by their sequence number.
@@ -534,13 +518,13 @@ class ActorSchedulingQueue : public SchedulingQueue {
   /// Reference to the waiter owned by the task receiver.
   DependencyWaiter &waiter_;
   /// If concurrent calls are allowed, holds the pool for executing these tasks.
-  std::unique_ptr<BoundedExecutor> pool_;
+  std::shared_ptr<BoundedExecutor> pool_;
   /// Whether we should enqueue requests into asyncio pool. Setting this to true
   /// will instantiate all tasks as fibers that can be yielded.
   bool is_asyncio_ = false;
-  /// If use_asyncio_ is true, fiber_state_ contains the running state required
+  /// If is_asyncio_ is true, fiber_state_ contains the running state required
   /// to enable continuation and work together with python asyncio.
-  std::unique_ptr<FiberState> fiber_state_;
+  std::shared_ptr<FiberState> fiber_state_;

   friend class SchedulingQueueTest;
 };
@@ -667,6 +651,20 @@ class CoreWorkerDirectTaskReceiver {
   // Queue of pending normal (non-actor) tasks.
   std::unique_ptr<SchedulingQueue> normal_scheduling_queue_ =
       std::unique_ptr<SchedulingQueue>(new NormalSchedulingQueue());
+
+  /// The max number of concurrent calls to allow.
+  /// 0 indicates that the value is not set yet.
+  int max_concurrency_ = 0;
+  /// If concurrent calls are allowed, holds the pool for executing these tasks.
+  std::shared_ptr<BoundedExecutor> pool_;
+  /// Whether this actor uses asyncio for concurrency.
+  bool is_asyncio_ = false;
+  /// If is_asyncio_ is true, fiber_state_ contains the running state required
+  /// to enable continuation and work together with python asyncio.
+  std::shared_ptr<FiberState> fiber_state_;
+
+  /// Set the max concurrency of an actor.
+  /// This should be called once, when handling the actor creation task.
+  void SetMaxActorConcurrency(bool is_asyncio, int max_concurrency);
 };

 } // namespace ray
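
For non-async concurrent actors, the queue submits work through pool_->PostBlocking(...), meaning submission blocks once max_concurrency tasks are in flight. A rough Python analogue of that contract, written as an assumption for illustration rather than a description of Ray's actual implementation:

    import threading
    from concurrent.futures import ThreadPoolExecutor

    class BoundedExecutorSketch:
        """Approximates the C++ BoundedExecutor: at most max_concurrency
        tasks run at once, and submission blocks while the pool is full."""

        def __init__(self, max_concurrency):
            self._pool = ThreadPoolExecutor(max_workers=max_concurrency)
            self._slots = threading.Semaphore(max_concurrency)

        def post_blocking(self, fn):
            # Block the submitter until a slot frees up.
            self._slots.acquire()

            def run():
                try:
                    fn()
                finally:
                    self._slots.release()

            return self._pool.submit(run)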