Resolve dependencies locally before submitting direct actor tasks (#6191)

* Priority queue in direct actor transport by task number

* Move LocalDependencyResolver out to separate file, share with direct actor transport

* works

* Test case for ordering

* Cleanups

* Remove priority queue

* comment

* Share ClientFactoryFn with direct actor transport

* Unit test

* fix
Stephanie Wang 2019-11-20 16:45:19 -08:00 committed by GitHub
parent 33c768ebe4
commit c0be9e6738
13 changed files with 429 additions and 247 deletions
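
At the core of the change, each actor's pending PushTaskRequests are buffered in an ordered map keyed by actor_counter, and requests are sent only while the smallest buffered counter matches the actor's next expected sequence number. The self-contained sketch below illustrates just that gating idea; OrderedSender and Request are illustrative stand-ins, not Ray's classes:

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Illustrative stand-in for rpc::PushTaskRequest.
struct Request {
  int64_t counter;      // Position in the actor's submission order.
  std::string payload;  // Stand-in for the serialized task.
};

class OrderedSender {
 public:
  // Called when a request's dependencies have been resolved; resolution may
  // finish out of submission order, so the request is buffered first.
  void OnDependenciesResolved(std::unique_ptr<Request> request) {
    int64_t counter = request->counter;
    pending_.emplace(counter, std::move(request));
    SendReady();
  }

 private:
  // Drain the head of the ordered map only while it matches the next
  // expected sequence number, preserving submission order on the wire.
  void SendReady() {
    auto head = pending_.begin();
    while (head != pending_.end() && head->first == next_sequence_number_) {
      std::cout << "sending task " << head->second->payload << std::endl;
      head = pending_.erase(head);
      next_sequence_number_++;
    }
  }

  std::map<int64_t, std::unique_ptr<Request>> pending_;  // Ordered by counter.
  int64_t next_sequence_number_ = 0;
};

int main() {
  OrderedSender sender;
  // Task 1's dependencies resolve first; it is buffered, not sent.
  sender.OnDependenciesResolved(std::unique_ptr<Request>(new Request{1, "b"}));
  // Task 0 arrives; both are now sent, in order: "a" then "b".
  sender.OnDependenciesResolved(std::unique_ptr<Request>(new Request{0, "a"}));
  return 0;
}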

BUILD.bazel

@@ -405,6 +405,16 @@ cc_binary(
    ],
)

cc_test(
    name = "direct_actor_transport_test",
    srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
    copts = COPTS,
    deps = [
        ":core_worker_lib",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "direct_task_transport_test",
    srcs = ["src/ray/core_worker/test/direct_task_transport_test.cc"],

python/ray/tests/test_basic.py

@@ -1349,6 +1349,31 @@ def test_direct_actor_enabled(ray_start_regular):
    assert ray.get(obj_id) == 2


def test_direct_actor_order(shutdown_only):
    ray.init(num_cpus=4)

    @ray.remote
    def small_value():
        time.sleep(0.01 * np.random.randint(0, 10))
        return 0

    @ray.remote
    class Actor(object):
        def __init__(self):
            self.count = 0

        def inc(self, count, dependency):
            assert count == self.count
            self.count += 1
            return count

    a = Actor._remote(is_direct_call=True)
    assert ray.get([
        a.inc.remote(i, small_value.options(is_direct_call=True).remote())
        for i in range(100)
    ]) == list(range(100))


def test_direct_actor_large_objects(ray_start_regular):
    @ray.remote
    class Actor(object):

src/ray/core_worker/common.h

@@ -10,6 +10,7 @@
#include "ray/util/util.h"
namespace ray {
using WorkerType = rpc::WorkerType;
// Return a string representation of the worker type.

src/ray/core_worker/core_worker.cc

@@ -62,37 +62,6 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,

namespace ray {

// Prepare direct call args for sending to a direct call *actor*. Direct call actors
// always resolve their dependencies remotely, so we need some client-side preprocessing
// to ensure they don't try to resolve a direct call object ID remotely (which is
// impossible).
// - Direct call args that are local and small will be inlined.
// - Direct call args that are non-local or large will be promoted to plasma.
// Note that args for direct call *tasks* are handled by LocalDependencyResolver.
std::vector<TaskArg> PrepareDirectActorCallArgs(
    const std::vector<TaskArg> &args,
    std::shared_ptr<CoreWorkerMemoryStore> memory_store) {
  std::vector<TaskArg> out;
  for (const auto &arg : args) {
    if (arg.IsPassedByReference() && arg.GetReference().IsDirectCallType()) {
      const ObjectID &obj_id = arg.GetReference();
      // TODO(ekl) we should consider resolving these dependencies on the client side
      // for actor calls. It is a little tricky since we have to also preserve the
      // task ordering so we can't simply use LocalDependencyResolver.
      std::shared_ptr<RayObject> obj = memory_store->GetOrPromoteToPlasma(obj_id);
      if (obj != nullptr) {
        out.push_back(TaskArg::PassByValue(obj));
      } else {
        out.push_back(TaskArg::PassByReference(
            obj_id.WithTransportType(TaskTransportType::RAYLET)));
      }
    } else {
      out.push_back(arg);
    }
  }
  return out;
}
CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
@@ -221,17 +190,16 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
SetCurrentTaskId(task_id);
}
auto client_factory = [this](const rpc::WorkerAddress &addr) {
return std::shared_ptr<rpc::CoreWorkerClient>(
new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_));
};
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
new CoreWorkerDirectActorTaskSubmitter(*client_call_manager_,
memory_store_provider_));
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_provider_));
direct_task_submitter_ =
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
raylet_client_,
[this](WorkerAddress addr) {
return std::shared_ptr<rpc::CoreWorkerClient>(new rpc::CoreWorkerClient(
addr.first, addr.second, *client_call_manager_));
},
raylet_client_, client_factory,
[this](const rpc::Address &address) {
auto grpc_client = rpc::NodeManagerWorkerClient::make(
address.ip_address(), address.port(), *client_call_manager_);
@@ -657,11 +625,10 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
const TaskID actor_task_id = TaskID::ForActorTask(
worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
next_task_index, actor_handle->GetActorID());
BuildCommonTaskSpec(
builder, actor_handle->CreationJobID(), actor_task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_,
function, is_direct_call ? PrepareDirectActorCallArgs(args, memory_store_) : args,
num_returns, task_options.resources, {}, transport_type, return_ids);
BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
rpc_address_, function, args, num_returns, task_options.resources,
{}, transport_type, return_ids);
const ObjectID new_cursor = return_ids->back();
actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor);

src/ray/core_worker/test/direct_actor_transport_test.cc (new file)

@@ -0,0 +1,130 @@
#include "gtest/gtest.h"

#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/core_worker/transport/direct_task_transport.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "src/ray/util/test_util.h"

namespace ray {

class MockWorkerClient : public rpc::CoreWorkerClientInterface {
 public:
  ray::Status PushActorTask(
      std::unique_ptr<rpc::PushTaskRequest> request,
      const rpc::ClientCallback<rpc::PushTaskReply> &callback) override {
    RAY_CHECK(counter == request->task_spec().actor_task_spec().actor_counter());
    counter++;
    callbacks.push_back(callback);
    return Status::OK();
  }

  std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
  uint64_t counter = 0;
};

TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
  TaskSpecification task;
  task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
  task.GetMutableMessage().set_type(TaskType::ACTOR_TASK);
  task.GetMutableMessage().mutable_actor_task_spec()->set_actor_id(actor_id.Binary());
  task.GetMutableMessage().mutable_actor_task_spec()->set_actor_counter(counter);
  return task;
}

class DirectActorTransportTest : public ::testing::Test {
 public:
  DirectActorTransportTest()
      : worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
        ptr_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
        store_(std::make_shared<CoreWorkerMemoryStoreProvider>(ptr_)),
        submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; },
                   store_) {}

  std::shared_ptr<MockWorkerClient> worker_client_;
  std::shared_ptr<CoreWorkerMemoryStore> ptr_;
  std::shared_ptr<CoreWorkerMemoryStoreProvider> store_;
  CoreWorkerDirectActorTaskSubmitter submitter_;
};

TEST_F(DirectActorTransportTest, TestSubmitTask) {
  ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);

  auto task = CreateActorTaskHelper(actor_id, 0);
  ASSERT_TRUE(submitter_.SubmitTask(task).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 1);

  task = CreateActorTaskHelper(actor_id, 1);
  ASSERT_TRUE(submitter_.SubmitTask(task).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);
}

TEST_F(DirectActorTransportTest, TestDependencies) {
  ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Create two tasks for the actor with different arguments.
  ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  auto task1 = CreateActorTaskHelper(actor_id, 0);
  task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
  auto task2 = CreateActorTaskHelper(actor_id, 1);
  task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());

  // Neither task can be submitted yet because they are still waiting on
  // dependencies.
  ASSERT_TRUE(submitter_.SubmitTask(task1).ok());
  ASSERT_TRUE(submitter_.SubmitTask(task2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Put the dependencies in the store in the same order as task submission.
  auto data = GenerateRandomObject();
  ASSERT_TRUE(store_->Put(*data, obj1).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 1);
  ASSERT_TRUE(store_->Put(*data, obj2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);
}

TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
  ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Create two tasks for the actor with different arguments.
  ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  auto task1 = CreateActorTaskHelper(actor_id, 0);
  task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
  auto task2 = CreateActorTaskHelper(actor_id, 1);
  task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());

  // Neither task can be submitted yet because they are still waiting on
  // dependencies.
  ASSERT_TRUE(submitter_.SubmitTask(task1).ok());
  ASSERT_TRUE(submitter_.SubmitTask(task2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Put the dependencies in the store in the opposite order of task
  // submission.
  auto data = GenerateRandomObject();
  ASSERT_TRUE(store_->Put(*data, obj2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);
  ASSERT_TRUE(store_->Put(*data, obj1).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);
}

}  // namespace ray

int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
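
Note that MockWorkerClient doubles as the ordering check: its RAY_CHECK fails the test if PushActorTask ever receives a request whose actor_counter is not the next one expected. Assuming the cc_test target added to the top-level BUILD.bazel above, the suite should run with something like:

bazel test //:direct_actor_transport_test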

src/ray/core_worker/test/direct_task_transport_test.cc

@@ -192,7 +192,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
@@ -214,7 +214,7 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
@@ -232,7 +232,7 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
TaskSpecification task1;
TaskSpecification task2;
@@ -273,7 +273,7 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
TaskSpecification task1;
TaskSpecification task2;
@@ -316,7 +316,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
TaskSpecification task1;
TaskSpecification task2;
@@ -349,7 +349,7 @@ TEST(DirectTaskTransportTest, TestSpillback) {
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
std::unordered_map<ClientID, std::shared_ptr<MockRayletClient>> remote_lease_clients;
auto lease_client_factory = [&](const rpc::Address &addr) {

src/ray/core_worker/transport/dependency_resolver.cc (new file)

@@ -0,0 +1,89 @@
#include "ray/core_worker/transport/dependency_resolver.h"

namespace ray {

struct TaskState {
  /// The task to be run.
  TaskSpecification task;
  /// The remaining dependencies to resolve for this task.
  absl::flat_hash_set<ObjectID> local_dependencies;
};

void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
                         TaskSpecification &task) {
  auto &msg = task.GetMutableMessage();
  bool found = false;
  for (size_t i = 0; i < task.NumArgs(); i++) {
    auto count = task.ArgIdCount(i);
    if (count > 0) {
      const auto &id = task.ArgId(i, 0);
      if (id == obj_id) {
        auto *mutable_arg = msg.mutable_args(i);
        mutable_arg->clear_object_ids();
        if (value->IsInPlasmaError()) {
          // Promote the object id to plasma.
          mutable_arg->add_object_ids(
              obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
        } else {
          // Inline the object value.
          if (value->HasData()) {
            const auto &data = value->GetData();
            mutable_arg->set_data(data->Data(), data->Size());
          }
          if (value->HasMetadata()) {
            const auto &metadata = value->GetMetadata();
            mutable_arg->set_metadata(metadata->Data(), metadata->Size());
          }
        }
        found = true;
      }
    }
  }
  RAY_CHECK(found) << "obj id " << obj_id << " not found";
}

void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
                                                  std::function<void()> on_complete) {
  absl::flat_hash_set<ObjectID> local_dependencies;
  for (size_t i = 0; i < task.NumArgs(); i++) {
    auto count = task.ArgIdCount(i);
    if (count > 0) {
      RAY_CHECK(count <= 1) << "multi args not implemented";
      const auto &id = task.ArgId(i, 0);
      if (id.IsDirectCallType()) {
        local_dependencies.insert(id);
      }
    }
  }
  if (local_dependencies.empty()) {
    on_complete();
    return;
  }

  // This is deleted when the last dependency fetch callback finishes.
  std::shared_ptr<TaskState> state =
      std::shared_ptr<TaskState>(new TaskState{task, std::move(local_dependencies)});
  num_pending_ += 1;

  for (const auto &obj_id : state->local_dependencies) {
    in_memory_store_->GetAsync(
        obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
          RAY_CHECK(obj != nullptr);
          bool complete = false;
          {
            absl::MutexLock lock(&mu_);
            state->local_dependencies.erase(obj_id);
            DoInlineObjectValue(obj_id, obj, state->task);
            if (state->local_dependencies.empty()) {
              complete = true;
              num_pending_ -= 1;
            }
          }
          if (complete) {
            on_complete();
          }
        });
  }
}

}  // namespace ray

src/ray/core_worker/transport/dependency_resolver.h (new file)

@@ -0,0 +1,45 @@
#ifndef RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H
#define RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H

#include <memory>

#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"

namespace ray {

// This class is thread-safe.
class LocalDependencyResolver {
 public:
  LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
      : in_memory_store_(store_provider), num_pending_(0) {}

  /// Resolve all local and remote dependencies for the task, calling the specified
  /// callback when done. Direct call ids in the task specification will be resolved
  /// to concrete values and inlined.
  ///
  /// Note: This method **will mutate** the given TaskSpecification.
  ///
  /// Postcondition: all direct call ids in arguments are converted to values.
  void ResolveDependencies(const TaskSpecification &task,
                           std::function<void()> on_complete);

  /// Return the number of tasks pending dependency resolution.
  /// TODO(ekl) this should be exposed in worker stats.
  int NumPendingTasks() const { return num_pending_; }

 private:
  /// The store provider.
  std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
  /// Number of tasks pending dependency resolution.
  std::atomic<int> num_pending_;
  /// Protects against concurrent access to internal state.
  absl::Mutex mu_;
};

}  // namespace ray

#endif  // RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H
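
The resolver fans out one GetAsync call per direct-call dependency and fires on_complete from whichever callback erases the last pending id, as the implementation above shows. The following self-contained sketch illustrates that fan-out/countdown pattern in isolation; AsyncStore and ResolveAll are toy stand-ins for CoreWorkerMemoryStoreProvider and LocalDependencyResolver, not Ray code:

#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for the memory store: a registered callback fires as soon as
// a value for its id is Put().
class AsyncStore {
 public:
  void GetAsync(const std::string &id, std::function<void(std::string)> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = values_.find(id);
    if (it != values_.end()) {
      cb(it->second);
    } else {
      waiting_[id].push_back(cb);
    }
  }

  void Put(const std::string &id, const std::string &value) {
    std::vector<std::function<void(std::string)>> cbs;
    {
      std::lock_guard<std::mutex> lock(mu_);
      values_[id] = value;
      cbs.swap(waiting_[id]);
    }
    for (auto &cb : cbs) cb(value);  // Fire outside the lock.
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::string> values_;
  std::unordered_map<std::string, std::vector<std::function<void(std::string)>>>
      waiting_;
};

// Fan-out/countdown: resolve all ids, then call on_complete exactly once,
// from whichever callback erases the last pending id.
void ResolveAll(AsyncStore &store, const std::set<std::string> &ids,
                std::function<void()> on_complete) {
  if (ids.empty()) {
    on_complete();
    return;
  }
  auto remaining = std::make_shared<std::set<std::string>>(ids);
  auto mu = std::make_shared<std::mutex>();
  for (const auto &id : ids) {
    store.GetAsync(id, [remaining, mu, id, on_complete](std::string value) {
      bool complete = false;
      {
        std::lock_guard<std::mutex> lock(*mu);
        remaining->erase(id);
        complete = remaining->empty();
      }
      if (complete) on_complete();
    });
  }
}

int main() {
  AsyncStore store;
  ResolveAll(store, {"obj1", "obj2"},
             [] { std::cout << "all dependencies resolved" << std::endl; });
  store.Put("obj2", "y");  // Out-of-order resolution is fine.
  store.Put("obj1", "x");  // on_complete fires exactly once, here.
  return 0;
}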

src/ray/core_worker/transport/direct_actor_transport.cc

@@ -5,6 +5,10 @@ using ray::rpc::ActorTableData;
namespace ray {
int64_t GetRequestNumber(const std::unique_ptr<rpc::PushTaskRequest> &request) {
return request->task_spec().actor_task_spec().actor_counter();
}
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type,
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store) {
@@ -57,17 +61,12 @@ void WriteObjectsToMemoryStore(
}
}
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
rpc::ClientCallManager &client_call_manager,
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: client_call_manager_(client_call_manager), in_memory_store_(store_provider) {}
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
RAY_CHECK(task_spec.IsActorTask());
const auto &actor_id = task_spec.ActorId();
resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable {
const auto &actor_id = task_spec.ActorId();
const auto task_id = task_spec.TaskId();
const auto num_returns = task_spec.NumReturns();
@@ -85,24 +84,22 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
// actor handle (e.g. from unpickling), in that case it might be desirable
// to have a timeout to mark it as invalid if it doesn't show up in the
// specified time.
pending_requests_[actor_id].emplace_back(std::move(request));
auto inserted = pending_requests_[actor_id].emplace(GetRequestNumber(request),
std::move(request));
RAY_CHECK(inserted.second);
RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created.";
} else if (iter->second.state_ == ActorTableData::ALIVE) {
// Actor is alive, submit the request.
if (rpc_clients_.count(actor_id) == 0) {
// If rpc client is not available, then create it.
ConnectAndSendPendingTasks(actor_id, iter->second.location_.first,
iter->second.location_.second);
}
// Submit request.
auto &client = rpc_clients_[actor_id];
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
auto inserted = pending_requests_[actor_id].emplace(GetRequestNumber(request),
std::move(request));
RAY_CHECK(inserted.second);
SendPendingTasks(actor_id);
} else {
// Actor is dead, treat the task as failure.
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
}
});
// If the task submission subsequently fails, then the client will receive
// the error in a callback.
@@ -118,12 +115,15 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
actor_data.address().port()));
if (actor_data.state() == ActorTableData::ALIVE) {
// Check if this actor is one that we're interested in: if we don't already
// have a connection to the actor but do have pending requests for it, we
// should create a new connection.
if (pending_requests_.count(actor_id) > 0 && rpc_clients_.count(actor_id) == 0) {
ConnectAndSendPendingTasks(actor_id, actor_data.address().ip_address(),
actor_data.address().port());
// Create a new connection to the actor.
if (rpc_clients_.count(actor_id) == 0) {
rpc::WorkerAddress addr = {actor_data.address().ip_address(),
actor_data.address().port()};
rpc_clients_[actor_id] =
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(addr));
}
if (pending_requests_.count(actor_id) > 0) {
SendPendingTasks(actor_id);
}
} else {
// Remove rpc client if it's dead or being reconstructed.
@@ -145,40 +145,49 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
// If there are pending requests, treat the pending tasks as failed.
auto pending_it = pending_requests_.find(actor_id);
if (pending_it != pending_requests_.end()) {
for (const auto &request : pending_it->second) {
auto head = pending_it->second.begin();
while (head != pending_it->second.end()) {
auto request = std::move(head->second);
head = pending_it->second.erase(head);
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
}
pending_requests_.erase(pending_it);
}
next_sequence_number_.erase(actor_id);
}
}
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
const ActorID &actor_id, std::string ip_address, int port) {
std::shared_ptr<rpc::CoreWorkerClient> grpc_client =
std::make_shared<rpc::CoreWorkerClient>(ip_address, port, client_call_manager_);
RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second);
// Submit all pending requests.
void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(const ActorID &actor_id) {
auto &client = rpc_clients_[actor_id];
RAY_CHECK(client);
// Submit all pending requests.
auto &requests = pending_requests_[actor_id];
while (!requests.empty()) {
auto request = std::move(requests.front());
auto head = requests.begin();
while (head != requests.end() && head->first == next_sequence_number_[actor_id]) {
auto request = std::move(head->second);
head = requests.erase(head);
auto num_returns = request->task_spec().num_returns();
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
requests.pop_front();
}
}
void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
rpc::CoreWorkerClient &client, std::unique_ptr<rpc::PushTaskRequest> request,
rpc::CoreWorkerClientInterface &client, std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
auto task_number = GetRequestNumber(request);
RAY_CHECK(next_sequence_number_[actor_id] == task_number)
<< "Counter was " << task_number << " expected " << next_sequence_number_[actor_id];
next_sequence_number_[actor_id]++;
auto status = client.PushActorTask(
std::move(request), [this, actor_id, task_id, num_returns](
Status status, const rpc::PushTaskReply &reply) {

src/ray/core_worker/transport/direct_actor_transport.h

@@ -4,8 +4,10 @@
#include <boost/asio/thread_pool.hpp>
#include <boost/thread.hpp>
#include <list>
#include <queue>
#include <set>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
@@ -13,10 +15,13 @@
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/core_worker/transport/dependency_resolver.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/worker/core_worker_client.h"
namespace {} // namespace
namespace ray {
/// The max time to wait for out-of-order tasks.
@@ -61,8 +66,11 @@ struct ActorStateData {
class CoreWorkerDirectActorTaskSubmitter {
public:
CoreWorkerDirectActorTaskSubmitter(
rpc::ClientCallManager &client_call_manager,
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider);
rpc::ClientFactoryFn client_factory,
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: client_factory_(client_factory),
in_memory_store_(store_provider),
resolver_(in_memory_store_) {}
/// Submit a task to an actor for execution.
///
@@ -87,20 +95,17 @@ class CoreWorkerDirectActorTaskSubmitter {
/// \param[in] task_id The ID of a task.
/// \param[in] num_returns Number of return objects.
/// \return Void.
void PushActorTask(rpc::CoreWorkerClient &client,
void PushActorTask(rpc::CoreWorkerClientInterface &client,
std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns);
/// Create connection to actor and send all pending tasks.
/// Send all pending tasks for an actor.
/// Note that this function doesn't take lock, the caller is expected to hold
/// `mutex_` before calling this function.
///
/// \param[in] actor_id Actor ID.
/// \param[in] ip_address The ip address of the node that the actor is running on.
/// \param[in] port The port that the actor is listening on.
/// \return Void.
void ConnectAndSendPendingTasks(const ActorID &actor_id, std::string ip_address,
int port);
void SendPendingTasks(const ActorID &actor_id);
/// Whether the specified actor is alive.
///
@@ -108,8 +113,8 @@ class CoreWorkerDirectActorTaskSubmitter {
/// \return Whether this actor is alive.
bool IsActorAlive(const ActorID &actor_id) const;
/// The shared `ClientCallManager` object.
rpc::ClientCallManager &client_call_manager_;
/// Factory for producing new core worker clients.
rpc::ClientFactoryFn client_factory_;
/// Mutex to protect the various maps below.
mutable std::mutex mutex_;
@@ -122,18 +127,27 @@ class CoreWorkerDirectActorTaskSubmitter {
///
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
/// subscribe updates for a specific actor.
std::unordered_map<ActorID, std::shared_ptr<rpc::CoreWorkerClient>> rpc_clients_;
std::unordered_map<ActorID, std::shared_ptr<rpc::CoreWorkerClientInterface>>
rpc_clients_;
/// Map from actor id to the actor's pending requests.
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
/// Map from actor id to the actor's pending requests. Each actor's requests
/// are ordered by the task number in the request.
absl::flat_hash_map<ActorID, std::map<int64_t, std::unique_ptr<rpc::PushTaskRequest>>>
pending_requests_;
/// Map from actor id to the sequence number of the next task to send to that
/// actor.
std::unordered_map<ActorID, int64_t> next_sequence_number_;
/// Map from actor id to the tasks that are waiting for reply.
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
/// The store provider.
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
/// Resolves direct call object dependencies.
LocalDependencyResolver resolver_;
friend class CoreWorkerTest;
};

src/ray/core_worker/transport/direct_task_transport.cc

@@ -1,85 +1,9 @@
#include "ray/core_worker/transport/direct_task_transport.h"

#include "ray/core_worker/transport/dependency_resolver.h"
#include "ray/core_worker/transport/direct_actor_transport.h"

namespace ray {

void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
                         TaskSpecification &task) {
  auto &msg = task.GetMutableMessage();
  bool found = false;
  for (size_t i = 0; i < task.NumArgs(); i++) {
    auto count = task.ArgIdCount(i);
    if (count > 0) {
      const auto &id = task.ArgId(i, 0);
      if (id == obj_id) {
        auto *mutable_arg = msg.mutable_args(i);
        mutable_arg->clear_object_ids();
        if (value->IsInPlasmaError()) {
          // Promote the object id to plasma.
          mutable_arg->add_object_ids(
              obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
        } else {
          // Inline the object value.
          if (value->HasData()) {
            const auto &data = value->GetData();
            mutable_arg->set_data(data->Data(), data->Size());
          }
          if (value->HasMetadata()) {
            const auto &metadata = value->GetMetadata();
            mutable_arg->set_metadata(metadata->Data(), metadata->Size());
          }
        }
        found = true;
      }
    }
  }
  RAY_CHECK(found) << "obj id " << obj_id << " not found";
}

void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
                                                  std::function<void()> on_complete) {
  absl::flat_hash_set<ObjectID> local_dependencies;
  for (size_t i = 0; i < task.NumArgs(); i++) {
    auto count = task.ArgIdCount(i);
    if (count > 0) {
      RAY_CHECK(count <= 1) << "multi args not implemented";
      const auto &id = task.ArgId(i, 0);
      if (id.IsDirectCallType()) {
        local_dependencies.insert(id);
      }
    }
  }
  if (local_dependencies.empty()) {
    on_complete();
    return;
  }

  // This is deleted when the last dependency fetch callback finishes.
  std::shared_ptr<TaskState> state =
      std::shared_ptr<TaskState>(new TaskState{task, std::move(local_dependencies)});
  num_pending_ += 1;

  for (const auto &obj_id : state->local_dependencies) {
    in_memory_store_->GetAsync(
        obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
          RAY_CHECK(obj != nullptr);
          bool complete = false;
          {
            absl::MutexLock lock(&mu_);
            state->local_dependencies.erase(obj_id);
            DoInlineObjectValue(obj_id, obj, state->task);
            if (state->local_dependencies.empty()) {
              complete = true;
              num_pending_ -= 1;
            }
          }
          if (complete) {
            on_complete();
          }
        });
  }
}
Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
// TODO(ekl) should have a queue per distinct resource type required
@@ -91,7 +15,7 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
}
void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(
const WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
// Setup client state for this worker.
{
absl::MutexLock lock(&mu_);
@@ -110,7 +34,7 @@ void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(
OnWorkerIdle(addr, /*error=*/false);
}
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr,
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr,
bool was_error) {
absl::MutexLock lock(&mu_);
if (queued_tasks_.empty() || was_error) {
@@ -199,7 +123,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
worker_request_pending_ = true;
}
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &addr,
rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec) {
auto task_id = task_spec.TaskId();

src/ray/core_worker/transport/direct_task_transport.h

@@ -7,53 +7,13 @@
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/core_worker/transport/dependency_resolver.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/core_worker_client.h"
namespace ray {
struct TaskState {
/// The task to be run.
TaskSpecification task;
/// The remaining dependencies to resolve for this task.
absl::flat_hash_set<ObjectID> local_dependencies;
};
// This class is thread-safe.
class LocalDependencyResolver {
public:
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: in_memory_store_(store_provider), num_pending_(0) {}
/// Resolve all local and remote dependencies for the task, calling the specified
/// callback when done. Direct call ids in the task specification will be resolved
/// to concrete values and inlined.
///
/// Note: This method **will mutate** the given TaskSpecification.
///
/// Postcondition: all direct call ids in arguments are converted to values.
void ResolveDependencies(const TaskSpecification &task,
std::function<void()> on_complete);
/// Return the number of tasks pending dependency resolution.
/// TODO(ekl) this should be exposed in worker stats.
int NumPendingTasks() const { return num_pending_; }
private:
/// The store provider.
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
/// Number of tasks pending dependency resolution.
std::atomic<int> num_pending_;
/// Protects against concurrent access to internal state.
absl::Mutex mu_;
};
typedef std::pair<std::string, int> WorkerAddress;
typedef std::function<std::shared_ptr<rpc::CoreWorkerClientInterface>(WorkerAddress)>
ClientFactoryFn;
typedef std::function<std::shared_ptr<WorkerLeaseInterface>(const rpc::Address &)>
LeaseClientFactoryFn;
@@ -61,8 +21,8 @@ typedef std::function<std::shared_ptr<WorkerLeaseInterface>(const rpc::Address &
class CoreWorkerDirectTaskSubmitter {
public:
CoreWorkerDirectTaskSubmitter(
std::shared_ptr<WorkerLeaseInterface> lease_client, ClientFactoryFn client_factory,
LeaseClientFactoryFn lease_client_factory,
std::shared_ptr<WorkerLeaseInterface> lease_client,
rpc::ClientFactoryFn client_factory, LeaseClientFactoryFn lease_client_factory,
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: local_lease_client_(lease_client),
client_factory_(client_factory),
@@ -79,7 +39,7 @@ class CoreWorkerDirectTaskSubmitter {
/// Schedule more work onto an idle worker or return it back to the raylet if
/// no more tasks are queued for submission. If an error was encountered
/// processing the worker, we don't attempt to re-use the worker.
void OnWorkerIdle(const WorkerAddress &addr, bool was_error);
void OnWorkerIdle(const rpc::WorkerAddress &addr, bool was_error);
/// Get an existing lease client or connect a new one. If a raylet_address is
/// provided, this connects to a remote raylet. Else, this connects to the
@@ -98,11 +58,12 @@ class CoreWorkerDirectTaskSubmitter {
/// Callback for when the raylet grants us a worker lease. The worker is returned
/// to the raylet via the given lease client once the task queue is empty.
/// TODO: Implement a lease term by which we need to return the worker.
void HandleWorkerLeaseGranted(const WorkerAddress &addr,
void HandleWorkerLeaseGranted(const rpc::WorkerAddress &addr,
std::shared_ptr<WorkerLeaseInterface> lease_client);
/// Push a task to a specific worker.
void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client,
void PushNormalTask(const rpc::WorkerAddress &addr,
rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec);
// Client that can be used to lease and return workers from the local raylet.
@@ -113,7 +74,7 @@ class CoreWorkerDirectTaskSubmitter {
remote_lease_clients_ GUARDED_BY(mu_);
/// Factory for producing new core worker clients.
ClientFactoryFn client_factory_;
rpc::ClientFactoryFn client_factory_;
/// Factory for producing new clients to request leases from remote nodes.
LeaseClientFactoryFn lease_client_factory_;
@@ -128,12 +89,12 @@ class CoreWorkerDirectTaskSubmitter {
absl::Mutex mu_;
/// Cache of gRPC clients to other workers.
absl::flat_hash_map<WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
client_cache_ GUARDED_BY(mu_);
/// Map from worker address to the lease client through which it should be
/// returned.
absl::flat_hash_map<WorkerAddress, std::shared_ptr<WorkerLeaseInterface>>
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<WorkerLeaseInterface>>
worker_to_lease_client_ GUARDED_BY(mu_);
// Whether we have a request to the Raylet to acquire a new worker in flight.

src/ray/rpc/worker/core_worker_client.h

@@ -33,6 +33,13 @@ const static int64_t RequestSizeInBytes(const PushTaskRequest &request) {
return size;
}
// Shared between direct actor and task submitters.
// TODO(swang): Remove and replace with rpc::Address.
class CoreWorkerClientInterface;
typedef std::pair<std::string, int> WorkerAddress;
typedef std::function<std::shared_ptr<CoreWorkerClientInterface>(const WorkerAddress &)>
ClientFactoryFn;
/// Abstract client interface for testing.
class CoreWorkerClientInterface {
public: