mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
Resolve dependencies locally before submitting direct actor tasks (#6191)
* Priority queue in direct actor transport by task number * Move LocalDependencyResolver out to separate file, share with direct actor transport * works * Test case for ordering * Cleanups * Remove priority queue * comment * Share ClientFactoryFn with direct actor transport * Unit test * fix
This commit is contained in:
parent
33c768ebe4
commit
c0be9e6738
13 changed files with 429 additions and 247 deletions
10
BUILD.bazel
10
BUILD.bazel
|
@ -405,6 +405,16 @@ cc_binary(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "direct_actor_transport_test",
|
||||
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":core_worker_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "direct_task_transport_test",
|
||||
srcs = ["src/ray/core_worker/test/direct_task_transport_test.cc"],
|
||||
|
|
|
@ -1349,6 +1349,31 @@ def test_direct_actor_enabled(ray_start_regular):
|
|||
assert ray.get(obj_id) == 2
|
||||
|
||||
|
||||
def test_direct_actor_order(shutdown_only):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
@ray.remote
|
||||
def small_value():
|
||||
time.sleep(0.01 * np.random.randint(0, 10))
|
||||
return 0
|
||||
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def inc(self, count, dependency):
|
||||
assert count == self.count
|
||||
self.count += 1
|
||||
return count
|
||||
|
||||
a = Actor._remote(is_direct_call=True)
|
||||
assert ray.get([
|
||||
a.inc.remote(i, small_value.options(is_direct_call=True).remote())
|
||||
for i in range(100)
|
||||
]) == list(range(100))
|
||||
|
||||
|
||||
def test_direct_actor_large_objects(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "ray/util/util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
using WorkerType = rpc::WorkerType;
|
||||
|
||||
// Return a string representation of the worker type.
|
||||
|
|
|
@ -62,37 +62,6 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
|
|||
|
||||
namespace ray {
|
||||
|
||||
// Prepare direct call args for sending to a direct call *actor*. Direct call actors
|
||||
// always resolve their dependencies remotely, so we need some client-side preprocessing
|
||||
// to ensure they don't try to resolve a direct call object ID remotely (which is
|
||||
// impossible).
|
||||
// - Direct call args that are local and small will be inlined.
|
||||
// - Direct call args that are non-local or large will be promoted to plasma.
|
||||
// Note that args for direct call *tasks* are handled by LocalDependencyResolver.
|
||||
std::vector<TaskArg> PrepareDirectActorCallArgs(
|
||||
const std::vector<TaskArg> &args,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> memory_store) {
|
||||
std::vector<TaskArg> out;
|
||||
for (const auto &arg : args) {
|
||||
if (arg.IsPassedByReference() && arg.GetReference().IsDirectCallType()) {
|
||||
const ObjectID &obj_id = arg.GetReference();
|
||||
// TODO(ekl) we should consider resolving these dependencies on the client side
|
||||
// for actor calls. It is a little tricky since we have to also preserve the
|
||||
// task ordering so we can't simply use LocalDependencyResolver.
|
||||
std::shared_ptr<RayObject> obj = memory_store->GetOrPromoteToPlasma(obj_id);
|
||||
if (obj != nullptr) {
|
||||
out.push_back(TaskArg::PassByValue(obj));
|
||||
} else {
|
||||
out.push_back(TaskArg::PassByReference(
|
||||
obj_id.WithTransportType(TaskTransportType::RAYLET)));
|
||||
}
|
||||
} else {
|
||||
out.push_back(arg);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
|
||||
|
@ -221,17 +190,16 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
SetCurrentTaskId(task_id);
|
||||
}
|
||||
|
||||
auto client_factory = [this](const rpc::WorkerAddress &addr) {
|
||||
return std::shared_ptr<rpc::CoreWorkerClient>(
|
||||
new rpc::CoreWorkerClient(addr.first, addr.second, *client_call_manager_));
|
||||
};
|
||||
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
|
||||
new CoreWorkerDirectActorTaskSubmitter(*client_call_manager_,
|
||||
memory_store_provider_));
|
||||
new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_provider_));
|
||||
|
||||
direct_task_submitter_ =
|
||||
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
|
||||
raylet_client_,
|
||||
[this](WorkerAddress addr) {
|
||||
return std::shared_ptr<rpc::CoreWorkerClient>(new rpc::CoreWorkerClient(
|
||||
addr.first, addr.second, *client_call_manager_));
|
||||
},
|
||||
raylet_client_, client_factory,
|
||||
[this](const rpc::Address &address) {
|
||||
auto grpc_client = rpc::NodeManagerWorkerClient::make(
|
||||
address.ip_address(), address.port(), *client_call_manager_);
|
||||
|
@ -657,11 +625,10 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
|||
const TaskID actor_task_id = TaskID::ForActorTask(
|
||||
worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
|
||||
next_task_index, actor_handle->GetActorID());
|
||||
BuildCommonTaskSpec(
|
||||
builder, actor_handle->CreationJobID(), actor_task_id,
|
||||
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_,
|
||||
function, is_direct_call ? PrepareDirectActorCallArgs(args, memory_store_) : args,
|
||||
num_returns, task_options.resources, {}, transport_type, return_ids);
|
||||
BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id,
|
||||
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
|
||||
rpc_address_, function, args, num_returns, task_options.resources,
|
||||
{}, transport_type, return_ids);
|
||||
|
||||
const ObjectID new_cursor = return_ids->back();
|
||||
actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor);
|
||||
|
|
130
src/ray/core_worker/test/direct_actor_transport_test.cc
Normal file
130
src/ray/core_worker/test/direct_actor_transport_test.cc
Normal file
|
@ -0,0 +1,130 @@
|
|||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||
public:
|
||||
ray::Status PushActorTask(
|
||||
std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const rpc::ClientCallback<rpc::PushTaskReply> &callback) override {
|
||||
RAY_CHECK(counter == request->task_spec().actor_task_spec().actor_counter());
|
||||
counter++;
|
||||
callbacks.push_back(callback);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
uint64_t counter = 0;
|
||||
};
|
||||
|
||||
/// Builds a minimal actor task specification for tests.
///
/// \param[in] actor_id The actor the task is addressed to.
/// \param[in] counter Value stored as the task's actor counter.
/// \return A spec of type ACTOR_TASK with a nil task id and no arguments.
TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) {
  TaskSpecification spec;
  auto &message = spec.GetMutableMessage();
  message.set_task_id(TaskID::Nil().Binary());
  message.set_type(TaskType::ACTOR_TASK);

  auto *actor_spec = message.mutable_actor_task_spec();
  actor_spec->set_actor_id(actor_id.Binary());
  actor_spec->set_actor_counter(counter);
  return spec;
}
|
||||
|
||||
class DirectActorTransportTest : public ::testing::Test {
|
||||
public:
|
||||
DirectActorTransportTest()
|
||||
: worker_client_(std::shared_ptr<MockWorkerClient>(new MockWorkerClient())),
|
||||
ptr_(std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore())),
|
||||
store_(std::make_shared<CoreWorkerMemoryStoreProvider>(ptr_)),
|
||||
submitter_([&](const rpc::WorkerAddress &addr) { return worker_client_; },
|
||||
store_) {}
|
||||
|
||||
std::shared_ptr<MockWorkerClient> worker_client_;
|
||||
std::shared_ptr<CoreWorkerMemoryStore> ptr_;
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_;
|
||||
CoreWorkerDirectActorTaskSubmitter submitter_;
|
||||
};
|
||||
|
||||
// A task submitted before any actor update is held back; it is pushed once the
// update arrives, and a task submitted afterwards is pushed right away.
TEST_F(DirectActorTransportTest, TestSubmitTask) {
  const ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);

  // No actor update has been delivered yet, so nothing reaches the client.
  auto task = CreateActorTaskHelper(actor_id, 0);
  ASSERT_TRUE(submitter_.SubmitTask(task).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Delivering the actor update flushes the queued task.
  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 1);

  // The next task in sequence is pushed immediately.
  task = CreateActorTaskHelper(actor_id, 1);
  ASSERT_TRUE(submitter_.SubmitTask(task).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);
}
|
||||
|
||||
// When dependencies become available in submission order, each task is pushed
// as soon as its own dependency is put into the store.
TEST_F(DirectActorTransportTest, TestDependencies) {
  const ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Two consecutive actor tasks, each taking one distinct direct-call object
  // id as its only argument.
  const ObjectID obj1 =
      ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  const ObjectID obj2 =
      ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  auto task1 = CreateActorTaskHelper(actor_id, 0);
  task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
  auto task2 = CreateActorTaskHelper(actor_id, 1);
  task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());

  // Both submissions are accepted, but neither task is pushed while its
  // dependency is still missing from the store.
  ASSERT_TRUE(submitter_.SubmitTask(task1).ok());
  ASSERT_TRUE(submitter_.SubmitTask(task2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Supplying the dependencies in submission order releases one task per Put.
  auto data = GenerateRandomObject();
  ASSERT_TRUE(store_->Put(*data, obj1).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 1);
  ASSERT_TRUE(store_->Put(*data, obj2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);
}
|
||||
|
||||
// When the second task's dependency arrives first, nothing is pushed until the
// first task's dependency arrives too; then both tasks go out together.
TEST_F(DirectActorTransportTest, TestOutOfOrderDependencies) {
  const ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0);
  gcs::ActorTableData actor_data;
  submitter_.HandleActorUpdate(actor_id, actor_data);
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Two consecutive actor tasks, each taking one distinct direct-call object
  // id as its only argument.
  const ObjectID obj1 =
      ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  const ObjectID obj2 =
      ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  auto task1 = CreateActorTaskHelper(actor_id, 0);
  task1.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
  auto task2 = CreateActorTaskHelper(actor_id, 1);
  task2.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());

  // Both submissions are accepted, but neither task is pushed while its
  // dependency is still missing from the store.
  ASSERT_TRUE(submitter_.SubmitTask(task1).ok());
  ASSERT_TRUE(submitter_.SubmitTask(task2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);

  // Supplying the dependencies in reverse order: the second task stays queued
  // behind the first, and both are released once the first one is ready.
  auto data = GenerateRandomObject();
  ASSERT_TRUE(store_->Put(*data, obj2).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 0);
  ASSERT_TRUE(store_->Put(*data, obj1).ok());
  ASSERT_EQ(worker_client_->callbacks.size(), 2);
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
|
@ -192,7 +192,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) {
|
|||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
@ -214,7 +214,7 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
|||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
@ -232,7 +232,7 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
|||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
|
@ -273,7 +273,7 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
|||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
|
@ -316,7 +316,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
|||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
|
@ -349,7 +349,7 @@ TEST(DirectTaskTransportTest, TestSpillback) {
|
|||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
auto factory = [&](const rpc::WorkerAddress &addr) { return worker_client; };
|
||||
|
||||
std::unordered_map<ClientID, std::shared_ptr<MockRayletClient>> remote_lease_clients;
|
||||
auto lease_client_factory = [&](const rpc::Address &addr) {
|
||||
|
|
89
src/ray/core_worker/transport/dependency_resolver.cc
Normal file
89
src/ray/core_worker/transport/dependency_resolver.cc
Normal file
|
@ -0,0 +1,89 @@
|
|||
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct TaskState {
|
||||
/// The task to be run.
|
||||
TaskSpecification task;
|
||||
/// The remaining dependencies to resolve for this task.
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
};
|
||||
|
||||
void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
|
||||
TaskSpecification &task) {
|
||||
auto &msg = task.GetMutableMessage();
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id == obj_id) {
|
||||
auto *mutable_arg = msg.mutable_args(i);
|
||||
mutable_arg->clear_object_ids();
|
||||
if (value->IsInPlasmaError()) {
|
||||
// Promote the object id to plasma.
|
||||
mutable_arg->add_object_ids(
|
||||
obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
|
||||
} else {
|
||||
// Inline the object value.
|
||||
if (value->HasData()) {
|
||||
const auto &data = value->GetData();
|
||||
mutable_arg->set_data(data->Data(), data->Size());
|
||||
}
|
||||
if (value->HasMetadata()) {
|
||||
const auto &metadata = value->GetMetadata();
|
||||
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
||||
}
|
||||
}
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
RAY_CHECK(found) << "obj id " << obj_id << " not found";
|
||||
}
|
||||
|
||||
void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
|
||||
std::function<void()> on_complete) {
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
RAY_CHECK(count <= 1) << "multi args not implemented";
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id.IsDirectCallType()) {
|
||||
local_dependencies.insert(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (local_dependencies.empty()) {
|
||||
on_complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// This is deleted when the last dependency fetch callback finishes.
|
||||
std::shared_ptr<TaskState> state =
|
||||
std::shared_ptr<TaskState>(new TaskState{task, std::move(local_dependencies)});
|
||||
num_pending_ += 1;
|
||||
|
||||
for (const auto &obj_id : state->local_dependencies) {
|
||||
in_memory_store_->GetAsync(
|
||||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||
RAY_CHECK(obj != nullptr);
|
||||
bool complete = false;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
state->local_dependencies.erase(obj_id);
|
||||
DoInlineObjectValue(obj_id, obj, state->task);
|
||||
if (state->local_dependencies.empty()) {
|
||||
complete = true;
|
||||
num_pending_ -= 1;
|
||||
}
|
||||
}
|
||||
if (complete) {
|
||||
on_complete();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ray
|
45
src/ray/core_worker/transport/dependency_resolver.h
Normal file
45
src/ray/core_worker/transport/dependency_resolver.h
Normal file
|
@ -0,0 +1,45 @@
|
|||
#ifndef RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H
|
||||
#define RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
// This class is thread-safe.
|
||||
class LocalDependencyResolver {
|
||||
public:
|
||||
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: in_memory_store_(store_provider), num_pending_(0) {}
|
||||
|
||||
/// Resolve all local and remote dependencies for the task, calling the specified
|
||||
/// callback when done. Direct call ids in the task specification will be resolved
|
||||
/// to concrete values and inlined.
|
||||
//
|
||||
/// Note: This method **will mutate** the given TaskSpecification.
|
||||
///
|
||||
/// Postcondition: all direct call ids in arguments are converted to values.
|
||||
void ResolveDependencies(const TaskSpecification &task,
|
||||
std::function<void()> on_complete);
|
||||
|
||||
/// Return the number of tasks pending dependency resolution.
|
||||
/// TODO(ekl) this should be exposed in worker stats.
|
||||
int NumPendingTasks() const { return num_pending_; }
|
||||
|
||||
private:
|
||||
/// The store provider.
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
|
||||
/// Number of tasks pending dependency resolution.
|
||||
std::atomic<int> num_pending_;
|
||||
|
||||
/// Protects against concurrent access to internal state.
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_CORE_WORKER_DEPENDENCY_RESOLVER_H
|
|
@ -5,6 +5,10 @@ using ray::rpc::ActorTableData;
|
|||
|
||||
namespace ray {
|
||||
|
||||
int64_t GetRequestNumber(const std::unique_ptr<rpc::PushTaskRequest> &request) {
|
||||
return request->task_spec().actor_task_spec().actor_counter();
|
||||
}
|
||||
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store) {
|
||||
|
@ -57,17 +61,12 @@ void WriteObjectsToMemoryStore(
|
|||
}
|
||||
}
|
||||
|
||||
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
|
||||
rpc::ClientCallManager &client_call_manager,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: client_call_manager_(client_call_manager), in_memory_store_(store_provider) {}
|
||||
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
||||
|
||||
RAY_CHECK(task_spec.IsActorTask());
|
||||
const auto &actor_id = task_spec.ActorId();
|
||||
|
||||
resolver_.ResolveDependencies(task_spec, [this, task_spec]() mutable {
|
||||
const auto &actor_id = task_spec.ActorId();
|
||||
const auto task_id = task_spec.TaskId();
|
||||
const auto num_returns = task_spec.NumReturns();
|
||||
|
||||
|
@ -85,24 +84,22 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
|||
// actor handle (e.g. from unpickling), in that case it might be desirable
|
||||
// to have a timeout to mark it as invalid if it doesn't show up in the
|
||||
// specified time.
|
||||
pending_requests_[actor_id].emplace_back(std::move(request));
|
||||
auto inserted = pending_requests_[actor_id].emplace(GetRequestNumber(request),
|
||||
std::move(request));
|
||||
RAY_CHECK(inserted.second);
|
||||
RAY_LOG(DEBUG) << "Actor " << actor_id << " is not yet created.";
|
||||
} else if (iter->second.state_ == ActorTableData::ALIVE) {
|
||||
// Actor is alive, submit the request.
|
||||
if (rpc_clients_.count(actor_id) == 0) {
|
||||
// If rpc client is not available, then create it.
|
||||
ConnectAndSendPendingTasks(actor_id, iter->second.location_.first,
|
||||
iter->second.location_.second);
|
||||
}
|
||||
|
||||
// Submit request.
|
||||
auto &client = rpc_clients_[actor_id];
|
||||
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
auto inserted = pending_requests_[actor_id].emplace(GetRequestNumber(request),
|
||||
std::move(request));
|
||||
RAY_CHECK(inserted.second);
|
||||
SendPendingTasks(actor_id);
|
||||
} else {
|
||||
// Actor is dead, treat the task as failure.
|
||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
}
|
||||
});
|
||||
|
||||
// If the task submission subsequently fails, then the client will receive
|
||||
// the error in a callback.
|
||||
|
@ -118,12 +115,15 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
|||
actor_data.address().port()));
|
||||
|
||||
if (actor_data.state() == ActorTableData::ALIVE) {
|
||||
// Check if this actor is the one that we're interested, if we already have
|
||||
// a connection to the actor, or have pending requests for it, we should
|
||||
// create a new connection.
|
||||
if (pending_requests_.count(actor_id) > 0 && rpc_clients_.count(actor_id) == 0) {
|
||||
ConnectAndSendPendingTasks(actor_id, actor_data.address().ip_address(),
|
||||
actor_data.address().port());
|
||||
// Create a new connection to the actor.
|
||||
if (rpc_clients_.count(actor_id) == 0) {
|
||||
rpc::WorkerAddress addr = {actor_data.address().ip_address(),
|
||||
actor_data.address().port()};
|
||||
rpc_clients_[actor_id] =
|
||||
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(addr));
|
||||
}
|
||||
if (pending_requests_.count(actor_id) > 0) {
|
||||
SendPendingTasks(actor_id);
|
||||
}
|
||||
} else {
|
||||
// Remove rpc client if it's dead or being reconstructed.
|
||||
|
@ -145,40 +145,49 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
|||
// If there are pending requests, treat the pending tasks as failed.
|
||||
auto pending_it = pending_requests_.find(actor_id);
|
||||
if (pending_it != pending_requests_.end()) {
|
||||
for (const auto &request : pending_it->second) {
|
||||
auto head = pending_it->second.begin();
|
||||
while (head != pending_it->second.end()) {
|
||||
auto request = std::move(head->second);
|
||||
head = pending_it->second.erase(head);
|
||||
|
||||
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
|
||||
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
}
|
||||
pending_requests_.erase(pending_it);
|
||||
}
|
||||
|
||||
next_sequence_number_.erase(actor_id);
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
const ActorID &actor_id, std::string ip_address, int port) {
|
||||
std::shared_ptr<rpc::CoreWorkerClient> grpc_client =
|
||||
std::make_shared<rpc::CoreWorkerClient>(ip_address, port, client_call_manager_);
|
||||
RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second);
|
||||
|
||||
// Submit all pending requests.
|
||||
void CoreWorkerDirectActorTaskSubmitter::SendPendingTasks(const ActorID &actor_id) {
|
||||
auto &client = rpc_clients_[actor_id];
|
||||
RAY_CHECK(client);
|
||||
// Submit all pending requests.
|
||||
auto &requests = pending_requests_[actor_id];
|
||||
while (!requests.empty()) {
|
||||
auto request = std::move(requests.front());
|
||||
auto head = requests.begin();
|
||||
while (head != requests.end() && head->first == next_sequence_number_[actor_id]) {
|
||||
auto request = std::move(head->second);
|
||||
head = requests.erase(head);
|
||||
|
||||
auto num_returns = request->task_spec().num_returns();
|
||||
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
|
||||
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
requests.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
||||
rpc::CoreWorkerClient &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
rpc::CoreWorkerClientInterface &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
|
||||
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
||||
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
|
||||
|
||||
auto task_number = GetRequestNumber(request);
|
||||
RAY_CHECK(next_sequence_number_[actor_id] == task_number)
|
||||
<< "Counter was " << task_number << " expected " << next_sequence_number_[actor_id];
|
||||
next_sequence_number_[actor_id]++;
|
||||
|
||||
auto status = client.PushActorTask(
|
||||
std::move(request), [this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::PushTaskReply &reply) {
|
||||
|
|
|
@ -4,8 +4,10 @@
|
|||
#include <boost/asio/thread_pool.hpp>
|
||||
#include <boost/thread.hpp>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
@ -13,10 +15,13 @@
|
|||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||
#include "ray/gcs/redis_gcs_client.h"
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
|
||||
namespace {} // namespace
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// The max time to wait for out-of-order tasks.
|
||||
|
@ -61,8 +66,11 @@ struct ActorStateData {
|
|||
class CoreWorkerDirectActorTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectActorTaskSubmitter(
|
||||
rpc::ClientCallManager &client_call_manager,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider);
|
||||
rpc::ClientFactoryFn client_factory,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: client_factory_(client_factory),
|
||||
in_memory_store_(store_provider),
|
||||
resolver_(in_memory_store_) {}
|
||||
|
||||
/// Submit a task to an actor for execution.
|
||||
///
|
||||
|
@ -87,20 +95,17 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
/// \param[in] task_id The ID of a task.
|
||||
/// \param[in] num_returns Number of return objects.
|
||||
/// \return Void.
|
||||
void PushActorTask(rpc::CoreWorkerClient &client,
|
||||
void PushActorTask(rpc::CoreWorkerClientInterface &client,
|
||||
std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns);
|
||||
|
||||
/// Create connection to actor and send all pending tasks.
|
||||
/// Send all pending tasks for an actor.
|
||||
/// Note that this function doesn't take lock, the caller is expected to hold
|
||||
/// `mutex_` before calling this function.
|
||||
///
|
||||
/// \param[in] actor_id Actor ID.
|
||||
/// \param[in] ip_address The ip address of the node that the actor is running on.
|
||||
/// \param[in] port The port that the actor is listening on.
|
||||
/// \return Void.
|
||||
void ConnectAndSendPendingTasks(const ActorID &actor_id, std::string ip_address,
|
||||
int port);
|
||||
void SendPendingTasks(const ActorID &actor_id);
|
||||
|
||||
/// Whether the specified actor is alive.
|
||||
///
|
||||
|
@ -108,8 +113,8 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
/// \return Whether this actor is alive.
|
||||
bool IsActorAlive(const ActorID &actor_id) const;
|
||||
|
||||
/// The shared `ClientCallManager` object.
|
||||
rpc::ClientCallManager &client_call_manager_;
|
||||
/// Factory for producing new core worker clients.
|
||||
rpc::ClientFactoryFn client_factory_;
|
||||
|
||||
/// Mutex to protect the various maps below.
|
||||
mutable std::mutex mutex_;
|
||||
|
@ -122,18 +127,27 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
///
|
||||
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
|
||||
/// subscribe updates for a specific actor.
|
||||
std::unordered_map<ActorID, std::shared_ptr<rpc::CoreWorkerClient>> rpc_clients_;
|
||||
std::unordered_map<ActorID, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
rpc_clients_;
|
||||
|
||||
/// Map from actor id to the actor's pending requests.
|
||||
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
|
||||
/// Map from actor id to the actor's pending requests. Each actor's requests
|
||||
/// are ordered by the task number in the request.
|
||||
absl::flat_hash_map<ActorID, std::map<int64_t, std::unique_ptr<rpc::PushTaskRequest>>>
|
||||
pending_requests_;
|
||||
|
||||
/// Map from actor id to the sequence number of the next task to send to that
|
||||
/// actor.
|
||||
std::unordered_map<ActorID, int64_t> next_sequence_number_;
|
||||
|
||||
/// Map from actor id to the tasks that are waiting for reply.
|
||||
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
|
||||
|
||||
/// The store provider.
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
|
||||
/// Resolves direct call object dependencies.
|
||||
LocalDependencyResolver resolver_;
|
||||
|
||||
friend class CoreWorkerTest;
|
||||
};
|
||||
|
||||
|
|
|
@ -1,85 +1,9 @@
|
|||
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
|
||||
TaskSpecification &task) {
|
||||
auto &msg = task.GetMutableMessage();
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id == obj_id) {
|
||||
auto *mutable_arg = msg.mutable_args(i);
|
||||
mutable_arg->clear_object_ids();
|
||||
if (value->IsInPlasmaError()) {
|
||||
// Promote the object id to plasma.
|
||||
mutable_arg->add_object_ids(
|
||||
obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
|
||||
} else {
|
||||
// Inline the object value.
|
||||
if (value->HasData()) {
|
||||
const auto &data = value->GetData();
|
||||
mutable_arg->set_data(data->Data(), data->Size());
|
||||
}
|
||||
if (value->HasMetadata()) {
|
||||
const auto &metadata = value->GetMetadata();
|
||||
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
||||
}
|
||||
}
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
RAY_CHECK(found) << "obj id " << obj_id << " not found";
|
||||
}
|
||||
|
||||
void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
|
||||
std::function<void()> on_complete) {
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
RAY_CHECK(count <= 1) << "multi args not implemented";
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id.IsDirectCallType()) {
|
||||
local_dependencies.insert(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (local_dependencies.empty()) {
|
||||
on_complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// This is deleted when the last dependency fetch callback finishes.
|
||||
std::shared_ptr<TaskState> state =
|
||||
std::shared_ptr<TaskState>(new TaskState{task, std::move(local_dependencies)});
|
||||
num_pending_ += 1;
|
||||
|
||||
for (const auto &obj_id : state->local_dependencies) {
|
||||
in_memory_store_->GetAsync(
|
||||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||
RAY_CHECK(obj != nullptr);
|
||||
bool complete = false;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
state->local_dependencies.erase(obj_id);
|
||||
DoInlineObjectValue(obj_id, obj, state->task);
|
||||
if (state->local_dependencies.empty()) {
|
||||
complete = true;
|
||||
num_pending_ -= 1;
|
||||
}
|
||||
}
|
||||
if (complete) {
|
||||
on_complete();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
|
||||
// TODO(ekl) should have a queue per distinct resource type required
|
||||
|
@ -91,7 +15,7 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
|||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(
|
||||
const WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
|
||||
const rpc::WorkerAddress &addr, std::shared_ptr<WorkerLeaseInterface> lease_client) {
|
||||
// Setup client state for this worker.
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
@ -110,7 +34,7 @@ void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(
|
|||
OnWorkerIdle(addr, /*error=*/false);
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr,
|
||||
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const rpc::WorkerAddress &addr,
|
||||
bool was_error) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
if (queued_tasks_.empty() || was_error) {
|
||||
|
@ -199,7 +123,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
|
|||
worker_request_pending_ = true;
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
|
||||
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &addr,
|
||||
rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec) {
|
||||
auto task_id = task_spec.TaskId();
|
||||
|
|
|
@ -7,53 +7,13 @@
|
|||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/core_worker/transport/dependency_resolver.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct TaskState {
|
||||
/// The task to be run.
|
||||
TaskSpecification task;
|
||||
/// The remaining dependencies to resolve for this task.
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
};
|
||||
|
||||
// This class is thread-safe.
|
||||
class LocalDependencyResolver {
|
||||
public:
|
||||
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: in_memory_store_(store_provider), num_pending_(0) {}
|
||||
|
||||
/// Resolve all local and remote dependencies for the task, calling the specified
|
||||
/// callback when done. Direct call ids in the task specification will be resolved
|
||||
/// to concrete values and inlined.
|
||||
///
|
||||
/// Note: This method **will mutate** the given TaskSpecification.
|
||||
///
|
||||
/// Postcondition: all direct call ids in arguments are converted to values.
|
||||
void ResolveDependencies(const TaskSpecification &task,
|
||||
std::function<void()> on_complete);
|
||||
|
||||
/// Return the number of tasks pending dependency resolution.
|
||||
/// TODO(ekl) this should be exposed in worker stats.
|
||||
int NumPendingTasks() const { return num_pending_; }
|
||||
|
||||
private:
|
||||
/// The store provider.
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
|
||||
/// Number of tasks pending dependency resolution.
|
||||
std::atomic<int> num_pending_;
|
||||
|
||||
/// Protects against concurrent access to internal state.
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
typedef std::pair<std::string, int> WorkerAddress;
|
||||
typedef std::function<std::shared_ptr<rpc::CoreWorkerClientInterface>(WorkerAddress)>
|
||||
ClientFactoryFn;
|
||||
typedef std::function<std::shared_ptr<WorkerLeaseInterface>(const rpc::Address &)>
|
||||
LeaseClientFactoryFn;
|
||||
|
||||
|
@ -61,8 +21,8 @@ typedef std::function<std::shared_ptr<WorkerLeaseInterface>(const rpc::Address &
|
|||
class CoreWorkerDirectTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectTaskSubmitter(
|
||||
std::shared_ptr<WorkerLeaseInterface> lease_client, ClientFactoryFn client_factory,
|
||||
LeaseClientFactoryFn lease_client_factory,
|
||||
std::shared_ptr<WorkerLeaseInterface> lease_client,
|
||||
rpc::ClientFactoryFn client_factory, LeaseClientFactoryFn lease_client_factory,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: local_lease_client_(lease_client),
|
||||
client_factory_(client_factory),
|
||||
|
@ -79,7 +39,7 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
/// Schedule more work onto an idle worker or return it back to the raylet if
|
||||
/// no more tasks are queued for submission. If an error was encountered
|
||||
/// processing the worker, we don't attempt to re-use the worker.
|
||||
void OnWorkerIdle(const WorkerAddress &addr, bool was_error);
|
||||
void OnWorkerIdle(const rpc::WorkerAddress &addr, bool was_error);
|
||||
|
||||
/// Get an existing lease client or connect a new one. If a raylet_address is
|
||||
/// provided, this connects to a remote raylet. Else, this connects to the
|
||||
|
@ -98,11 +58,12 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
/// Callback for when the raylet grants us a worker lease. The worker is returned
|
||||
/// to the raylet via the given lease client once the task queue is empty.
|
||||
/// TODO: Implement a lease term by which we need to return the worker.
|
||||
void HandleWorkerLeaseGranted(const WorkerAddress &addr,
|
||||
void HandleWorkerLeaseGranted(const rpc::WorkerAddress &addr,
|
||||
std::shared_ptr<WorkerLeaseInterface> lease_client);
|
||||
|
||||
/// Push a task to a specific worker.
|
||||
void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client,
|
||||
void PushNormalTask(const rpc::WorkerAddress &addr,
|
||||
rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec);
|
||||
|
||||
// Client that can be used to lease and return workers from the local raylet.
|
||||
|
@ -113,7 +74,7 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
remote_lease_clients_ GUARDED_BY(mu_);
|
||||
|
||||
/// Factory for producing new core worker clients.
|
||||
ClientFactoryFn client_factory_;
|
||||
rpc::ClientFactoryFn client_factory_;
|
||||
|
||||
/// Factory for producing new clients to request leases from remote nodes.
|
||||
LeaseClientFactoryFn lease_client_factory_;
|
||||
|
@ -128,12 +89,12 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
absl::Mutex mu_;
|
||||
|
||||
/// Cache of gRPC clients to other workers.
|
||||
absl::flat_hash_map<WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
client_cache_ GUARDED_BY(mu_);
|
||||
|
||||
/// Map from worker address to the lease client through which it should be
|
||||
/// returned.
|
||||
absl::flat_hash_map<WorkerAddress, std::shared_ptr<WorkerLeaseInterface>>
|
||||
absl::flat_hash_map<rpc::WorkerAddress, std::shared_ptr<WorkerLeaseInterface>>
|
||||
worker_to_lease_client_ GUARDED_BY(mu_);
|
||||
|
||||
// Whether we have a request to the Raylet to acquire a new worker in flight.
|
||||
|
|
|
@ -33,6 +33,13 @@ const static int64_t RequestSizeInBytes(const PushTaskRequest &request) {
|
|||
return size;
|
||||
}
|
||||
|
||||
// Shared between direct actor and task submitters.
|
||||
// TODO(swang): Remove and replace with rpc::Address.
|
||||
class CoreWorkerClientInterface;
|
||||
typedef std::pair<std::string, int> WorkerAddress;
|
||||
typedef std::function<std::shared_ptr<CoreWorkerClientInterface>(const WorkerAddress &)>
|
||||
ClientFactoryFn;
|
||||
|
||||
/// Abstract client interface for testing.
|
||||
class CoreWorkerClientInterface {
|
||||
public:
|
||||
|
|
Loading…
Add table
Reference in a new issue