diff --git a/BUILD.bazel b/BUILD.bazel index 3a635a948..ef4507940 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -405,6 +405,16 @@ cc_binary( ], ) +cc_test( + name = "direct_actor_transport_test", + srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "direct_task_transport_test", srcs = ["src/ray/core_worker/test/direct_task_transport_test.cc"], diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 401ebc0f5..1ce5bf856 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1349,6 +1349,31 @@ def test_direct_actor_enabled(ray_start_regular): assert ray.get(obj_id) == 2 +def test_direct_actor_order(shutdown_only): + ray.init(num_cpus=4) + + @ray.remote + def small_value(): + time.sleep(0.01 * np.random.randint(0, 10)) + return 0 + + @ray.remote + class Actor(object): + def __init__(self): + self.count = 0 + + def inc(self, count, dependency): + assert count == self.count + self.count += 1 + return count + + a = Actor._remote(is_direct_call=True) + assert ray.get([ + a.inc.remote(i, small_value.options(is_direct_call=True).remote()) + for i in range(100) + ]) == list(range(100)) + + def test_direct_actor_large_objects(ray_start_regular): @ray.remote class Actor(object): diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index ebcc1f831..60fc373f9 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -10,6 +10,7 @@ #include "ray/util/util.h" namespace ray { + using WorkerType = rpc::WorkerType; // Return a string representation of the worker type. diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 5e424363a..36f58150e 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -62,37 +62,6 @@ void GroupObjectIdsByStoreProvider(const std::vector &object_ids, namespace ray { -// Prepare direct call args for sending to a direct call *actor*. Direct call actors -// always resolve their dependencies remotely, so we need some client-side preprocessing -// to ensure they don't try to resolve a direct call object ID remotely (which is -// impossible). -// - Direct call args that are local and small will be inlined. -// - Direct call args that are non-local or large will be promoted to plasma. -// Note that args for direct call *tasks* are handled by LocalDependencyResolver. -std::vector PrepareDirectActorCallArgs( - const std::vector &args, - std::shared_ptr memory_store) { - std::vector out; - for (const auto &arg : args) { - if (arg.IsPassedByReference() && arg.GetReference().IsDirectCallType()) { - const ObjectID &obj_id = arg.GetReference(); - // TODO(ekl) we should consider resolving these dependencies on the client side - // for actor calls. It is a little tricky since we have to also preserve the - // task ordering so we can't simply use LocalDependencyResolver. - std::shared_ptr obj = memory_store->GetOrPromoteToPlasma(obj_id); - if (obj != nullptr) { - out.push_back(TaskArg::PassByValue(obj)); - } else { - out.push_back(TaskArg::PassByReference( - obj_id.WithTransportType(TaskTransportType::RAYLET))); - } - } else { - out.push_back(arg); - } - } - return out; -} - CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, const std::string &store_socket, const std::string &raylet_socket, const JobID &job_id, const gcs::GcsClientOptions &gcs_options, @@ -221,17 +190,16 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, SetCurrentTaskId(task_id); } + auto client_factory = [this](const rpc::WorkerAddress &addr) { + return std::shared_ptr( + new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_)); + }; direct_actor_submitter_ = std::unique_ptr( - new CoreWorkerDirectActorTaskSubmitter(*client_call_manager_, - memory_store_provider_)); + new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_provider_)); direct_task_submitter_ = std::unique_ptr(new CoreWorkerDirectTaskSubmitter( - raylet_client_, - [this](WorkerAddress addr) { - return std::shared_ptr(new rpc::CoreWorkerClient( - addr.first, addr.second, *client_call_manager_)); - }, + raylet_client_, client_factory, [this](const rpc::Address &address) { auto grpc_client = rpc::NodeManagerWorkerClient::make( address.ip_address(), address.port(), *client_call_manager_); @@ -657,11 +625,10 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f const TaskID actor_task_id = TaskID::ForActorTask( worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index, actor_handle->GetActorID()); - BuildCommonTaskSpec( - builder, actor_handle->CreationJobID(), actor_task_id, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, - function, is_direct_call ? PrepareDirectActorCallArgs(args, memory_store_) : args, - num_returns, task_options.resources, {}, transport_type, return_ids); + BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, + worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), + rpc_address_, function, args, num_returns, task_options.resources, + {}, transport_type, return_ids); const ObjectID new_cursor = return_ids->back(); actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor); diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc new file mode 100644 index 000000000..0a8707f07 --- /dev/null +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -0,0 +1,130 @@ +#include "gtest/gtest.h" + +#include "ray/common/task/task_spec.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/transport/direct_task_transport.h" +#include "ray/raylet/raylet_client.h" +#include "ray/rpc/worker/core_worker_client.h" +#include "src/ray/util/test_util.h" + +namespace ray { + +class MockWorkerClient : public rpc::CoreWorkerClientInterface { + public: + ray::Status PushActorTask( + std::unique_ptr request, + const rpc::ClientCallback &callback) override { + RAY_CHECK(counter == request->task_spec().actor_task_spec().actor_counter()); + counter++; + callbacks.push_back(callback); + return Status::OK(); + } + + std::vector> callbacks; + uint64_t counter = 0; +}; + +TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { + TaskSpecification task; + task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task.GetMutableMessage().set_type(TaskType::ACTOR_TASK); + task.GetMutableMessage().mutable_actor_task_spec()->set_actor_id(actor_id.Binary()); + task.GetMutableMessage().mutable_actor_task_spec()->set_actor_counter(counter); + return task; +} + +class DirectActorTransportTest : public ::testing::Test { + public: + DirectActorTransportTest() + : worker_client_(std::shared_ptr(new MockWorkerClient())), + ptr_(std::shared_ptr(new CoreWorkerMemoryStore())), + store_(std::make_shared(ptr_)), + submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; }, + store_) {} + + std::shared_ptr worker_client_; + std::shared_ptr ptr_; + std::shared_ptr store_; + CoreWorkerDirectActorTaskSubmitter submitter_; +}; + +TEST_F(DirectActorTransportTest, TestSubmitTask) { + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + + auto task = CreateActorTaskHelper(actor_id, 0); + ASSERT_TRUE(submitter_.SubmitTask(task).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + + gcs::ActorTableData actor_data; + submitter_.HandleActorUpdate(actor_id, actor_data); + ASSERT_EQ(worker_client_->callbacks.size(), 1); + + task = CreateActorTaskHelper(actor_id, 1); + ASSERT_TRUE(submitter_.SubmitTask(task).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 2); +} + +TEST_F(DirectActorTransportTest, TestDependencies) { + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + gcs::ActorTableData actor_data; + submitter_.HandleActorUpdate(actor_id, actor_data); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + + // Create two tasks for the actor with different arguments. + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + auto task1 = CreateActorTaskHelper(actor_id, 0); + task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + auto task2 = CreateActorTaskHelper(actor_id, 1); + task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + + // Neither task can be submitted yet because they are still waiting on + // dependencies. + ASSERT_TRUE(submitter_.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + + // Put the dependencies in the store in the same order as task submission. + auto data = GenerateRandomObject(); + ASSERT_TRUE(store_->Put(*data, obj1).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 1); + ASSERT_TRUE(store_->Put(*data, obj2).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 2); +} + +TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) { + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + gcs::ActorTableData actor_data; + submitter_.HandleActorUpdate(actor_id, actor_data); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + + // Create two tasks for the actor with different arguments. + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + auto task1 = CreateActorTaskHelper(actor_id, 0); + task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + auto task2 = CreateActorTaskHelper(actor_id, 1); + task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + + // Neither task can be submitted yet because they are still waiting on + // dependencies. + ASSERT_TRUE(submitter_.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + + // Put the dependencies in the store in the opposite order of task + // submission. + auto data = GenerateRandomObject(); + ASSERT_TRUE(store_->Put(*data, obj2).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + ASSERT_TRUE(store_->Put(*data, obj1).ok()); + ASSERT_EQ(worker_client_->callbacks.size(), 2); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 4ffcb93bf..1bb35d0f0 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -192,7 +192,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); - auto factory = [&](WorkerAddress addr) { return worker_client; }; + auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -214,7 +214,7 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); - auto factory = [&](WorkerAddress addr) { return worker_client; }; + auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -232,7 +232,7 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); - auto factory = [&](WorkerAddress addr) { return worker_client; }; + auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; TaskSpecification task2; @@ -273,7 +273,7 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); - auto factory = [&](WorkerAddress addr) { return worker_client; }; + auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; TaskSpecification task2; @@ -316,7 +316,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); - auto factory = [&](WorkerAddress addr) { return worker_client; }; + auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; TaskSpecification task2; @@ -349,7 +349,7 @@ TEST(DirectTaskTransportTest, TestSpillback) { auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); - auto factory = [&](WorkerAddress addr) { return worker_client; }; + auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; }; std::unordered_map> remote_lease_clients; auto lease_client_factory = [&](const rpc::Address &addr) { diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc new file mode 100644 index 000000000..a30acade8 --- /dev/null +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -0,0 +1,89 @@ +#include "ray/core_worker/transport/dependency_resolver.h" + +namespace ray { + +struct TaskState { + /// The task to be run. + TaskSpecification task; + /// The remaining dependencies to resolve for this task. + absl::flat_hash_set local_dependencies; +}; + +void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr value, + TaskSpecification &task) { + auto &msg = task.GetMutableMessage(); + bool found = false; + for (size_t i = 0; i < task.NumArgs(); i++) { + auto count = task.ArgIdCount(i); + if (count > 0) { + const auto &id = task.ArgId(i, 0); + if (id == obj_id) { + auto *mutable_arg = msg.mutable_args(i); + mutable_arg->clear_object_ids(); + if (value->IsInPlasmaError()) { + // Promote the object id to plasma. + mutable_arg->add_object_ids( + obj_id.WithTransportType(TaskTransportType::RAYLET).Binary()); + } else { + // Inline the object value. + if (value->HasData()) { + const auto &data = value->GetData(); + mutable_arg->set_data(data->Data(), data->Size()); + } + if (value->HasMetadata()) { + const auto &metadata = value->GetMetadata(); + mutable_arg->set_metadata(metadata->Data(), metadata->Size()); + } + } + found = true; + } + } + } + RAY_CHECK(found) << "obj id " << obj_id << " not found"; +} + +void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task, + std::function on_complete) { + absl::flat_hash_set local_dependencies; + for (size_t i = 0; i < task.NumArgs(); i++) { + auto count = task.ArgIdCount(i); + if (count > 0) { + RAY_CHECK(count <= 1) << "multi args not implemented"; + const auto &id = task.ArgId(i, 0); + if (id.IsDirectCallType()) { + local_dependencies.insert(id); + } + } + } + if (local_dependencies.empty()) { + on_complete(); + return; + } + + // This is deleted when the last dependency fetch callback finishes. + std::shared_ptr state = + std::shared_ptr(new TaskState{task, std::move(local_dependencies)}); + num_pending_ += 1; + + for (const auto &obj_id : state->local_dependencies) { + in_memory_store_->GetAsync( + obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { + RAY_CHECK(obj != nullptr); + bool complete = false; + { + absl::MutexLock lock(&mu_); + state->local_dependencies.erase(obj_id); + DoInlineObjectValue(obj_id, obj, state->task); + if (state->local_dependencies.empty()) { + complete = true; + num_pending_ -= 1; + } + } + if (complete) { + on_complete(); + } + }); + } +} + +} // namespace ray diff --git a/src/ray/core_worker/transport/dependency_resolver.h b/src/ray/core_worker/transport/dependency_resolver.h new file mode 100644 index 000000000..f6ee49d9d --- /dev/null +++ b/src/ray/core_worker/transport/dependency_resolver.h @@ -0,0 +1,45 @@ +#ifndef RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H +#define RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H + +#include + +#include "ray/common/id.h" +#include "ray/common/task/task_spec.h" +#include "ray/core_worker/store_provider/memory_store_provider.h" + +namespace ray { + +// This class is thread-safe. +class LocalDependencyResolver { + public: + LocalDependencyResolver(std::shared_ptr store_provider) + : in_memory_store_(store_provider), num_pending_(0) {} + + /// Resolve all local and remote dependencies for the task, calling the specified + /// callback when done. Direct call ids in the task specification will be resolved + /// to concrete values and inlined. + // + /// Note: This method **will mutate** the given TaskSpecification. + /// + /// Postcondition: all direct call ids in arguments are converted to values. + void ResolveDependencies(const TaskSpecification &task, + std::function on_complete); + + /// Return the number of tasks pending dependency resolution. + /// TODO(ekl) this should be exposed in worker stats. + int NumPendingTasks() const { return num_pending_; } + + private: + /// The store provider. + std::shared_ptr in_memory_store_; + + /// Number of tasks pending dependency resolution. + std::atomic num_pending_; + + /// Protects against concurrent access to internal state. + absl::Mutex mu_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 7570c6db0..539808b65 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -5,6 +5,10 @@ using ray::rpc::ActorTableData; namespace ray { +int64_t GetRequestNumber(const std::unique_ptr &request) { + return request->task_spec().actor_task_spec().actor_counter(); +} + void TreatTaskAsFailed(const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type, std::shared_ptr &in_memory_store) { @@ -57,52 +61,45 @@ void WriteObjectsToMemoryStore( } } -CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( - rpc::ClientCallManager &client_call_manager, - std::shared_ptr store_provider) - : client_call_manager_(client_call_manager), in_memory_store_(store_provider) {} - Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) { RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId(); - RAY_CHECK(task_spec.IsActorTask()); - const auto &actor_id = task_spec.ActorId(); - const auto task_id = task_spec.TaskId(); - const auto num_returns = task_spec.NumReturns(); + resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable { + const auto &actor_id = task_spec.ActorId(); + const auto task_id = task_spec.TaskId(); + const auto num_returns = task_spec.NumReturns(); - auto request = std::unique_ptr(new rpc::PushTaskRequest); - request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); + auto request = std::unique_ptr(new rpc::PushTaskRequest); + request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); - std::unique_lock guard(mutex_); + std::unique_lock guard(mutex_); - auto iter = actor_states_.find(actor_id); - if (iter == actor_states_.end() || - iter->second.state_ == ActorTableData::RECONSTRUCTING) { - // Actor is not yet created, or is being reconstructed, cache the request - // and submit after actor is alive. - // TODO(zhijunfu): it might be possible for a user to specify an invalid - // actor handle (e.g. from unpickling), in that case it might be desirable - // to have a timeout to mark it as invalid if it doesn't show up in the - // specified time. - pending_requests_[actor_id].emplace_back(std::move(request)); - RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created."; - } else if (iter->second.state_ == ActorTableData::ALIVE) { - // Actor is alive, submit the request. - if (rpc_clients_.count(actor_id) == 0) { - // If rpc client is not available, then create it. - ConnectAndSendPendingTasks(actor_id, iter->second.location_.first, - iter->second.location_.second); + auto iter = actor_states_.find(actor_id); + if (iter == actor_states_.end() || + iter->second.state_ == ActorTableData::RECONSTRUCTING) { + // Actor is not yet created, or is being reconstructed, cache the request + // and submit after actor is alive. + // TODO(zhijunfu): it might be possible for a user to specify an invalid + // actor handle (e.g. from unpickling), in that case it might be desirable + // to have a timeout to mark it as invalid if it doesn't show up in the + // specified time. + auto inserted = pending_requests_[actor_id].emplace(GetRequestNumber(request), + std::move(request)); + RAY_CHECK(inserted.second); + RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created."; + } else if (iter->second.state_ == ActorTableData::ALIVE) { + auto inserted = pending_requests_[actor_id].emplace(GetRequestNumber(request), + std::move(request)); + RAY_CHECK(inserted.second); + SendPendingTasks(actor_id); + } else { + // Actor is dead, treat the task as failure. + RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, + in_memory_store_); } - - // Submit request. - auto &client = rpc_clients_[actor_id]; - PushActorTask(*client, std::move(request), actor_id, task_id, num_returns); - } else { - // Actor is dead, treat the task as failure. - RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_); - } + }); // If the task submission subsequently fails, then the client will receive // the error in a callback. @@ -118,12 +115,15 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( actor_data.address().port())); if (actor_data.state() == ActorTableData::ALIVE) { - // Check if this actor is the one that we're interested, if we already have - // a connection to the actor, or have pending requests for it, we should - // create a new connection. - if (pending_requests_.count(actor_id) > 0 && rpc_clients_.count(actor_id) == 0) { - ConnectAndSendPendingTasks(actor_id, actor_data.address().ip_address(), - actor_data.address().port()); + // Create a new connection to the actor. + if (rpc_clients_.count(actor_id) == 0) { + rpc::WorkerAddress addr = {actor_data.address().ip_address(), + actor_data.address().port()}; + rpc_clients_[actor_id] = + std::shared_ptr(client_factory_(addr)); + } + if (pending_requests_.count(actor_id) > 0) { + SendPendingTasks(actor_id); } } else { // Remove rpc client if it's dead or being reconstructed. @@ -145,40 +145,49 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( // If there are pending requests, treat the pending tasks as failed. auto pending_it = pending_requests_.find(actor_id); if (pending_it != pending_requests_.end()) { - for (const auto &request : pending_it->second) { + auto head = pending_it->second.begin(); + while (head != pending_it->second.end()) { + auto request = std::move(head->second); + head = pending_it->second.erase(head); + TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()), request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED, in_memory_store_); } pending_requests_.erase(pending_it); } + + next_sequence_number_.erase(actor_id); } } -void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( - const ActorID &actor_id, std::string ip_address, int port) { - std::shared_ptr grpc_client = - std::make_shared(ip_address, port, client_call_manager_); - RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second); - - // Submit all pending requests. +void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(const ActorID &actor_id) { auto &client = rpc_clients_[actor_id]; + RAY_CHECK(client); + // Submit all pending requests. auto &requests = pending_requests_[actor_id]; - while (!requests.empty()) { - auto request = std::move(requests.front()); + auto head = requests.begin(); + while (head != requests.end() && head->first == next_sequence_number_[actor_id]) { + auto request = std::move(head->second); + head = requests.erase(head); + auto num_returns = request->task_spec().num_returns(); auto task_id = TaskID::FromBinary(request->task_spec().task_id()); PushActorTask(*client, std::move(request), actor_id, task_id, num_returns); - requests.pop_front(); } } void CoreWorkerDirectActorTaskSubmitter::PushActorTask( - rpc::CoreWorkerClient &client, std::unique_ptr request, + rpc::CoreWorkerClientInterface &client, std::unique_ptr request, const ActorID &actor_id, const TaskID &task_id, int num_returns) { RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id; waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); + auto task_number = GetRequestNumber(request); + RAY_CHECK(next_sequence_number_[actor_id] == task_number) + << "Counter was " << task_number << " expected " << next_sequence_number_[actor_id]; + next_sequence_number_[actor_id]++; + auto status = client.PushActorTask( std::move(request), [this, actor_id, task_id, num_returns]( Status status, const rpc::PushTaskReply &reply) { diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 9c116c728..5b0b3acf2 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -4,8 +4,10 @@ #include #include #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" @@ -13,10 +15,13 @@ #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" #include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/transport/dependency_resolver.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/worker/core_worker_client.h" +namespace {} // namespace + namespace ray { /// The max time to wait for out-of-order tasks. @@ -61,8 +66,11 @@ struct ActorStateData { class CoreWorkerDirectActorTaskSubmitter { public: CoreWorkerDirectActorTaskSubmitter( - rpc::ClientCallManager &client_call_manager, - std::shared_ptr store_provider); + rpc::ClientFactoryFn client_factory, + std::shared_ptr store_provider) + : client_factory_(client_factory), + in_memory_store_(store_provider), + resolver_(in_memory_store_) {} /// Submit a task to an actor for execution. /// @@ -87,20 +95,17 @@ class CoreWorkerDirectActorTaskSubmitter { /// \param[in] task_id The ID of a task. /// \param[in] num_returns Number of return objects. /// \return Void. - void PushActorTask(rpc::CoreWorkerClient &client, + void PushActorTask(rpc::CoreWorkerClientInterface &client, std::unique_ptr request, const ActorID &actor_id, const TaskID &task_id, int num_returns); - /// Create connection to actor and send all pending tasks. + /// Send all pending tasks for an actor. /// Note that this function doesn't take lock, the caller is expected to hold /// `mutex_` before calling this function. /// /// \param[in] actor_id Actor ID. - /// \param[in] ip_address The ip address of the node that the actor is running on. - /// \param[in] port The port that the actor is listening on. /// \return Void. - void ConnectAndSendPendingTasks(const ActorID &actor_id, std::string ip_address, - int port); + void SendPendingTasks(const ActorID &actor_id); /// Whether the specified actor is alive. /// @@ -108,8 +113,8 @@ class CoreWorkerDirectActorTaskSubmitter { /// \return Whether this actor is alive. bool IsActorAlive(const ActorID &actor_id) const; - /// The shared `ClientCallManager` object. - rpc::ClientCallManager &client_call_manager_; + /// Factory for producing new core worker clients. + rpc::ClientFactoryFn client_factory_; /// Mutex to proect the various maps below. mutable std::mutex mutex_; @@ -122,18 +127,27 @@ class CoreWorkerDirectActorTaskSubmitter { /// /// TODO(zhijunfu): this will be moved into `actor_states_` later when we can /// subscribe updates for a specific actor. - std::unordered_map> rpc_clients_; + std::unordered_map> + rpc_clients_; - /// Map from actor id to the actor's pending requests. - std::unordered_map>> + /// Map from actor id to the actor's pending requests. Each actor's requests + /// are ordered by the task number in the request. + absl::flat_hash_map>> pending_requests_; + /// Map from actor id to the sequence number of the next task to send to that + /// actor. + std::unordered_map next_sequence_number_; + /// Map from actor id to the tasks that are waiting for reply. std::unordered_map> waiting_reply_tasks_; /// The store provider. std::shared_ptr in_memory_store_; + /// Resolve direct call object dependencies; + LocalDependencyResolver resolver_; + friend class CoreWorkerTest; }; diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index 069c7154c..d5d1316e3 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -1,85 +1,9 @@ #include "ray/core_worker/transport/direct_task_transport.h" +#include "ray/core_worker/transport/dependency_resolver.h" #include "ray/core_worker/transport/direct_actor_transport.h" namespace ray { -void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr value, - TaskSpecification &task) { - auto &msg = task.GetMutableMessage(); - bool found = false; - for (size_t i = 0; i < task.NumArgs(); i++) { - auto count = task.ArgIdCount(i); - if (count > 0) { - const auto &id = task.ArgId(i, 0); - if (id == obj_id) { - auto *mutable_arg = msg.mutable_args(i); - mutable_arg->clear_object_ids(); - if (value->IsInPlasmaError()) { - // Promote the object id to plasma. - mutable_arg->add_object_ids( - obj_id.WithTransportType(TaskTransportType::RAYLET).Binary()); - } else { - // Inline the object value. - if (value->HasData()) { - const auto &data = value->GetData(); - mutable_arg->set_data(data->Data(), data->Size()); - } - if (value->HasMetadata()) { - const auto &metadata = value->GetMetadata(); - mutable_arg->set_metadata(metadata->Data(), metadata->Size()); - } - } - found = true; - } - } - } - RAY_CHECK(found) << "obj id " << obj_id << " not found"; -} - -void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task, - std::function on_complete) { - absl::flat_hash_set local_dependencies; - for (size_t i = 0; i < task.NumArgs(); i++) { - auto count = task.ArgIdCount(i); - if (count > 0) { - RAY_CHECK(count <= 1) << "multi args not implemented"; - const auto &id = task.ArgId(i, 0); - if (id.IsDirectCallType()) { - local_dependencies.insert(id); - } - } - } - if (local_dependencies.empty()) { - on_complete(); - return; - } - - // This is deleted when the last dependency fetch callback finishes. - std::shared_ptr state = - std::shared_ptr(new TaskState{task, std::move(local_dependencies)}); - num_pending_ += 1; - - for (const auto &obj_id : state->local_dependencies) { - in_memory_store_->GetAsync( - obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { - RAY_CHECK(obj != nullptr); - bool complete = false; - { - absl::MutexLock lock(&mu_); - state->local_dependencies.erase(obj_id); - DoInlineObjectValue(obj_id, obj, state->task); - if (state->local_dependencies.empty()) { - complete = true; - num_pending_ -= 1; - } - } - if (complete) { - on_complete(); - } - }); - } -} - Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { resolver_.ResolveDependencies(task_spec, [this, task_spec]() { // TODO(ekl) should have a queue per distinct resource type required @@ -91,7 +15,7 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { } void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted( - const WorkerAddress &addr, std::shared_ptr lease_client) { + const rpc::WorkerAddress &addr, std::shared_ptr lease_client) { // Setup client state for this worker. { absl::MutexLock lock(&mu_); @@ -110,7 +34,7 @@ void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted( OnWorkerIdle(addr, /*error=*/false); } -void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr, +void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr, bool was_error) { absl::MutexLock lock(&mu_); if (queued_tasks_.empty() || was_error) { @@ -199,7 +123,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( worker_request_pending_ = true; } -void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr, +void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, TaskSpecification &task_spec) { auto task_id = task_spec.TaskId(); diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index fabedc903..079fa5958 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -7,53 +7,13 @@ #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" #include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/transport/dependency_resolver.h" #include "ray/core_worker/transport/direct_actor_transport.h" #include "ray/raylet/raylet_client.h" #include "ray/rpc/worker/core_worker_client.h" namespace ray { -struct TaskState { - /// The task to be run. - TaskSpecification task; - /// The remaining dependencies to resolve for this task. - absl::flat_hash_set local_dependencies; -}; - -// This class is thread-safe. -class LocalDependencyResolver { - public: - LocalDependencyResolver(std::shared_ptr store_provider) - : in_memory_store_(store_provider), num_pending_(0) {} - - /// Resolve all local and remote dependencies for the task, calling the specified - /// callback when done. Direct call ids in the task specification will be resolved - /// to concrete values and inlined. - // - /// Note: This method **will mutate** the given TaskSpecification. - /// - /// Postcondition: all direct call ids in arguments are converted to values. - void ResolveDependencies(const TaskSpecification &task, - std::function on_complete); - - /// Return the number of tasks pending dependency resolution. - /// TODO(ekl) this should be exposed in worker stats. - int NumPendingTasks() const { return num_pending_; } - - private: - /// The store provider. - std::shared_ptr in_memory_store_; - - /// Number of tasks pending dependency resolution. - std::atomic num_pending_; - - /// Protects against concurrent access to internal state. - absl::Mutex mu_; -}; - -typedef std::pair WorkerAddress; -typedef std::function(WorkerAddress)> - ClientFactoryFn; typedef std::function(const rpc::Address &)> LeaseClientFactoryFn; @@ -61,8 +21,8 @@ typedef std::function(const rpc::Address & class CoreWorkerDirectTaskSubmitter { public: CoreWorkerDirectTaskSubmitter( - std::shared_ptr lease_client, ClientFactoryFn client_factory, - LeaseClientFactoryFn lease_client_factory, + std::shared_ptr lease_client, + rpc::ClientFactoryFn client_factory, LeaseClientFactoryFn lease_client_factory, std::shared_ptr store_provider) : local_lease_client_(lease_client), client_factory_(client_factory), @@ -79,7 +39,7 @@ class CoreWorkerDirectTaskSubmitter { /// Schedule more work onto an idle worker or return it back to the raylet if /// no more tasks are queued for submission. If an error was encountered /// processing the worker, we don't attempt to re-use the worker. - void OnWorkerIdle(const WorkerAddress &addr, bool was_error); + void OnWorkerIdle(const rpc::WorkerAddress &addr, bool was_error); /// Get an existing lease client or connect a new one. If a raylet_address is /// provided, this connects to a remote raylet. Else, this connects to the @@ -98,11 +58,12 @@ class CoreWorkerDirectTaskSubmitter { /// Callback for when the raylet grants us a worker lease. The worker is returned /// to the raylet via the given lease client once the task queue is empty. /// TODO: Implement a lease term by which we need to return the worker. - void HandleWorkerLeaseGranted(const WorkerAddress &addr, + void HandleWorkerLeaseGranted(const rpc::WorkerAddress &addr, std::shared_ptr lease_client); /// Push a task to a specific worker. - void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, + void PushNormalTask(const rpc::WorkerAddress &addr, + rpc::CoreWorkerClientInterface &client, TaskSpecification &task_spec); // Client that can be used to lease and return workers from the local raylet. @@ -113,7 +74,7 @@ class CoreWorkerDirectTaskSubmitter { remote_lease_clients_ GUARDED_BY(mu_); /// Factory for producing new core worker clients. - ClientFactoryFn client_factory_; + rpc::ClientFactoryFn client_factory_; /// Factory for producing new clients to request leases from remote nodes. LeaseClientFactoryFn lease_client_factory_; @@ -128,12 +89,12 @@ class CoreWorkerDirectTaskSubmitter { absl::Mutex mu_; /// Cache of gRPC clients to other workers. - absl::flat_hash_map> + absl::flat_hash_map> client_cache_ GUARDED_BY(mu_); /// Map from worker address to the lease client through which it should be /// returned. - absl::flat_hash_map> + absl::flat_hash_map> worker_to_lease_client_ GUARDED_BY(mu_); // Whether we have a request to the Raylet to acquire a new worker in flight. diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index c58573371..40ad8f808 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -33,6 +33,13 @@ const static int64_t RequestSizeInBytes(const PushTaskRequest &request) { return size; } +// Shared between direct actor and task submitters. +// TODO(swang): Remove and replace with rpc::Address. +class CoreWorkerClientInterface; +typedef std::pair WorkerAddress; +typedef std::function(const WorkerAddress &)> + ClientFactoryFn; + /// Abstract client interface for testing. class CoreWorkerClientInterface { public: