diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py index 48676da1f..5f6ee8370 100644 --- a/python/ray/tests/test_basic_2.py +++ b/python/ray/tests/test_basic_2.py @@ -561,6 +561,39 @@ def test_actor_concurrent(ray_start_regular_shared): assert r1 == r2 == r3 +def test_actor_max_concurrency(ray_start_regular_shared): + """ + Test that an actor of max_concurrency=N should only run + N tasks at most concurrently. + """ + CONCURRENCY = 3 + + @ray.remote + class ConcurentActor: + def __init__(self): + self.threads = set() + + def call(self): + # Record the current thread that runs this function. + self.threads.add(threading.current_thread()) + + def get_num_threads(self): + return len(self.threads) + + @ray.remote + def call(actor): + for _ in range(CONCURRENCY * 100): + ray.get(actor.call.remote()) + return + + actor = ConcurentActor.options(max_concurrency=CONCURRENCY).remote() + # Start many concurrent tasks that will call the actor many times. + ray.get([call.remote(actor) for _ in range(CONCURRENCY * 10)]) + + # Check that the number of threads shouldn't be greater than CONCURRENCY. + assert ray.get(actor.get_num_threads.remote()) <= CONCURRENCY + + def test_wait(ray_start_regular_shared): @ray.remote def f(delay): diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index c31d428ef..9f3066965 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -39,8 +39,7 @@ class MockWaiter : public DependencyWaiter { TEST(SchedulingQueueTest, TestInOrder) { instrumented_io_context io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - ActorSchedulingQueue queue(io_service, waiter, context); + ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -60,8 +59,7 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { ObjectID obj3 = ObjectID::FromRandom(); instrumented_io_context io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - ActorSchedulingQueue queue(io_service, waiter, context); + ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -86,8 +84,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { ObjectID obj1 = ObjectID::FromRandom(); instrumented_io_context io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - ActorSchedulingQueue queue(io_service, waiter, context); + ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -104,8 +101,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { TEST(SchedulingQueueTest, TestOutOfOrder) { instrumented_io_context io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - ActorSchedulingQueue queue(io_service, waiter, context); + ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -122,8 +118,7 @@ TEST(SchedulingQueueTest, TestOutOfOrder) { TEST(SchedulingQueueTest, TestSeqWaitTimeout) { instrumented_io_context io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - ActorSchedulingQueue queue(io_service, waiter, context); + ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -145,8 +140,7 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) { TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) { instrumented_io_context io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); - ActorSchedulingQueue queue(io_service, waiter, context); + ActorSchedulingQueue queue(io_service, waiter); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index b6b21ce54..a8c5ae00a 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -437,6 +437,10 @@ void CoreWorkerDirectTaskReceiver::HandleTask( return; } + if (task_spec.IsActorCreationTask()) { + SetMaxActorConcurrency(task_spec.IsAsyncioActor(), task_spec.MaxActorConcurrency()); + } + // Only assign resources for non-actor tasks. Actor tasks inherit the resources // assigned at initial actor creation time. std::shared_ptr resource_ids; @@ -530,7 +534,7 @@ void CoreWorkerDirectTaskReceiver::HandleTask( auto result = actor_scheduling_queues_.emplace( task_spec.CallerWorkerId(), std::unique_ptr(new ActorSchedulingQueue( - task_main_io_service_, *waiter_, worker_context_))); + task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_state_))); it = result.first; } @@ -563,4 +567,25 @@ bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) { return normal_scheduling_queue_->CancelTaskIfFound(task_id); } +void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(bool is_asyncio, + int max_concurrency) { + RAY_CHECK(max_concurrency_ == 0) + << "SetMaxActorConcurrency should only be called at most once."; + RAY_CHECK(fiber_state_ == nullptr); + RAY_CHECK(pool_ == nullptr); + RAY_CHECK(max_concurrency >= 1); + if (max_concurrency > 1) { + max_concurrency_ = max_concurrency; + is_asyncio_ = is_asyncio; + if (is_asyncio_) { + RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency; + fiber_state_.reset(new FiberState(max_concurrency)); + } else { + RAY_LOG(INFO) << "Setting actor as async with max_concurrency=" << max_concurrency + << ", creating new fiber thread."; + pool_.reset(new BoundedExecutor(max_concurrency)); + } + } +} + } // namespace ray diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 2fca3cac9..02cf374e8 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -386,13 +386,17 @@ class SchedulingQueue { class ActorSchedulingQueue : public SchedulingQueue { public: ActorSchedulingQueue(instrumented_io_context &main_io_service, DependencyWaiter &waiter, - WorkerContext &worker_context, + std::shared_ptr pool = nullptr, + bool is_asyncio = false, + std::shared_ptr fiber_state = nullptr, int64_t reorder_wait_seconds = kMaxReorderWaitSeconds) - : worker_context_(worker_context), - reorder_wait_seconds_(reorder_wait_seconds), + : reorder_wait_seconds_(reorder_wait_seconds), wait_timer_(main_io_service), main_thread_id_(boost::this_thread::get_id()), - waiter_(waiter) {} + waiter_(waiter), + pool_(pool), + is_asyncio_(is_asyncio), + fiber_state_(fiber_state) {} bool TaskQueueEmpty() const { return pending_actor_tasks_.empty(); } @@ -437,24 +441,6 @@ class ActorSchedulingQueue : public SchedulingQueue { /// Schedules as many requests as possible in sequence. void ScheduleRequests() { - // Only call SetMaxActorConcurrency to configure threadpool size when the - // actor is not async actor. Async actor is single threaded. - int max_concurrency = worker_context_.CurrentActorMaxConcurrency(); - if (worker_context_.CurrentActorIsAsync()) { - // If this is an async actor, initialize the fiber state once. - if (!is_asyncio_) { - RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread."; - fiber_state_.reset(new FiberState(max_concurrency)); - is_asyncio_ = true; - } - } else { - // If this is a concurrency actor (not async), initialize the thread pool once. - if (max_concurrency != 1 && !pool_) { - RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency; - pool_.reset(new BoundedExecutor(max_concurrency)); - } - } - // Cancel any stale requests that the client doesn't need any longer. while (!pending_actor_tasks_.empty() && pending_actor_tasks_.begin()->first < next_seq_no_) { @@ -475,7 +461,7 @@ class ActorSchedulingQueue : public SchedulingQueue { if (is_asyncio_) { // Process async actor task. fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); }); - } else if (pool_) { + } else if (pool_ != nullptr) { // Process concurrent actor task. pool_->PostBlocking([request]() mutable { request.Accept(); }); } else { @@ -518,8 +504,6 @@ class ActorSchedulingQueue : public SchedulingQueue { } } - // Worker context. - WorkerContext &worker_context_; /// Max time in seconds to wait for dependencies to show up. const int64_t reorder_wait_seconds_ = 0; /// Sorted map of (accept, rej) task callbacks keyed by their sequence number. @@ -534,13 +518,13 @@ class ActorSchedulingQueue : public SchedulingQueue { /// Reference to the waiter owned by the task receiver. DependencyWaiter &waiter_; /// If concurrent calls are allowed, holds the pool for executing these tasks. - std::unique_ptr pool_; + std::shared_ptr pool_; /// Whether we should enqueue requests into asyncio pool. Setting this to true /// will instantiate all tasks as fibers that can be yielded. bool is_asyncio_ = false; - /// If use_asyncio_ is true, fiber_state_ contains the running state required + /// If is_asyncio_ is true, fiber_state_ contains the running state required /// to enable continuation and work together with python asyncio. - std::unique_ptr fiber_state_; + std::shared_ptr fiber_state_; friend class SchedulingQueueTest; }; @@ -667,6 +651,20 @@ class CoreWorkerDirectTaskReceiver { // Queue of pending normal (non-actor) tasks. std::unique_ptr normal_scheduling_queue_ = std::unique_ptr(new NormalSchedulingQueue()); + /// The max number of concurrent calls to allow. + /// 0 indicates that the value is not set yet. + int max_concurrency_ = 0; + /// If concurrent calls are allowed, holds the pool for executing these tasks. + std::shared_ptr pool_; + /// Whether this actor use asyncio for concurrency. + bool is_asyncio_ = false; + /// If use_asyncio_ is true, fiber_state_ contains the running state required + /// to enable continuation and work together with python asyncio. + std::shared_ptr fiber_state_; + + /// Set the max concurrency of an actor. + /// This should be called once for the actor creation task. + void SetMaxActorConcurrency(bool is_asyncio, int max_concurrency); }; } // namespace ray