Move top level RayletClient to ray::raylet::RayletClient (#6404)

This commit is contained in:
Rong Rong 2019-12-09 21:08:59 -08:00 committed by Philipp Moritz
parent 8c34e8391c
commit c1d4ab8bb4
13 changed files with 136 additions and 119 deletions

View file

@ -41,7 +41,7 @@ ctypedef pair[c_vector[CObjectID], c_vector[CObjectID]] WaitResultPair
cdef extern from "ray/raylet/raylet_client.h" nogil:
cdef cppclass CRayletClient "RayletClient":
cdef cppclass CRayletClient "ray::raylet::RayletClient":
CRayletClient(const c_string &raylet_socket,
const CWorkerID &worker_id,
c_bool is_worker, const CJobID &job_id,

View file

@ -128,7 +128,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
auto grpc_client = rpc::NodeManagerWorkerClient::make(
node_ip_address, node_manager_port, *client_call_manager_);
ClientID local_raylet_id;
local_raylet_client_ = std::shared_ptr<RayletClient>(new RayletClient(
local_raylet_client_ = std::shared_ptr<raylet::RayletClient>(new raylet::RayletClient(
std::move(grpc_client), raylet_socket,
WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()),
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
@ -210,8 +210,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
[this](const rpc::Address &address) {
auto grpc_client = rpc::NodeManagerWorkerClient::make(
address.ip_address(), address.port(), *client_call_manager_);
return std::shared_ptr<RayletClient>(
new RayletClient(std::move(grpc_client)));
return std::shared_ptr<raylet::RayletClient>(
new raylet::RayletClient(std::move(grpc_client)));
},
memory_store_, task_manager_, local_raylet_id,
RayConfig::instance().worker_lease_timeout_milliseconds()));

View file

@ -89,7 +89,7 @@ class CoreWorker {
/// Accessor for this worker's context; returns the `worker_context_` member by
/// (non-const) reference.
WorkerContext &GetWorkerContext() { return worker_context_; }
/// Accessor for the local raylet client. Dereferences `local_raylet_client_`
/// directly; no null check is performed here.
/// NOTE(review): this text is a diff view, so the pre-change declaration
/// (unqualified `RayletClient`) and the post-change declaration
/// (`raylet::RayletClient`) appear on consecutive lines.
RayletClient &GetRayletClient() { return *local_raylet_client_; }
raylet::RayletClient &GetRayletClient() { return *local_raylet_client_; }
/// Returns the current task ID as reported by `worker_context_.GetCurrentTaskID()`.
const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); }
@ -525,7 +525,7 @@ class CoreWorker {
// shared_ptr for direct calls because we can lease multiple workers through
// one client, and we need to keep the connection alive until we return all
// of the workers.
std::shared_ptr<RayletClient> local_raylet_client_;
std::shared_ptr<raylet::RayletClient> local_raylet_client_;
// Thread that runs a boost::asio service to process IO events.
std::thread io_thread_;

View file

@ -109,7 +109,7 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
/// Constructs the memory store. The constructor body is empty: it only copies the
/// three arguments into `store_in_plasma_`, `ref_counter_` and `raylet_client_`.
///
/// \param store_in_plasma Callback taking a RayObject and its ObjectID; stored as-is.
/// \param counter Shared reference counter; stored as-is.
/// \param raylet_client Shared raylet client; stored as-is.
///
/// NOTE(review): this text is a diff view. The two consecutive `raylet_client`
/// parameter lines are the pre-change (`RayletClient`) and post-change
/// (`raylet::RayletClient`) versions of the same parameter.
CoreWorkerMemoryStore::CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
std::shared_ptr<ReferenceCounter> counter,
std::shared_ptr<RayletClient> raylet_client)
std::shared_ptr<raylet::RayletClient> raylet_client)
: store_in_plasma_(store_in_plasma),
ref_counter_(counter),
raylet_client_(raylet_client) {}

View file

@ -29,7 +29,7 @@ class CoreWorkerMemoryStore {
CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
std::shared_ptr<ReferenceCounter> counter = nullptr,
std::shared_ptr<RayletClient> raylet_client = nullptr);
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr);
~CoreWorkerMemoryStore(){};
/// Put an object with specified ID into object store.
@ -124,7 +124,7 @@ class CoreWorkerMemoryStore {
std::shared_ptr<ReferenceCounter> ref_counter_ = nullptr;
// If set, this will be used to notify worker blocked / unblocked on get calls.
std::shared_ptr<RayletClient> raylet_client_ = nullptr;
std::shared_ptr<raylet::RayletClient> raylet_client_ = nullptr;
/// Protects the data structures below.
absl::Mutex mu_;

View file

@ -7,7 +7,8 @@
namespace ray {
CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider(
const std::string &store_socket, const std::shared_ptr<RayletClient> raylet_client,
const std::string &store_socket,
const std::shared_ptr<raylet::RayletClient> raylet_client,
std::function<Status()> check_signals)
: raylet_client_(raylet_client) {
check_signals_ = check_signals;
@ -128,7 +129,7 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
return Status::OK();
}
Status UnblockIfNeeded(const std::shared_ptr<RayletClient> &client,
Status UnblockIfNeeded(const std::shared_ptr<raylet::RayletClient> &client,
const WorkerContext &ctx) {
if (ctx.CurrentTaskIsDirectCall()) {
if (ctx.ShouldReleaseResourcesOnBlockingCalls()) {

View file

@ -20,7 +20,7 @@ namespace ray {
class CoreWorkerPlasmaStoreProvider {
public:
CoreWorkerPlasmaStoreProvider(const std::string &store_socket,
const std::shared_ptr<RayletClient> raylet_client,
const std::shared_ptr<raylet::RayletClient> raylet_client,
std::function<Status()> check_signals);
~CoreWorkerPlasmaStoreProvider();
@ -83,7 +83,7 @@ class CoreWorkerPlasmaStoreProvider {
static void WarnIfAttemptedTooManyTimes(int num_attempts,
const absl::flat_hash_set<ObjectID> &remaining);
const std::shared_ptr<RayletClient> raylet_client_;
const std::shared_ptr<raylet::RayletClient> raylet_client_;
plasma::PlasmaClient store_client_;
std::mutex store_client_mutex_;
std::function<Status()> check_signals_;

View file

@ -151,7 +151,7 @@ CoreWorkerDirectTaskReceiver::CoreWorkerDirectTaskReceiver(
exit_handler_(exit_handler),
task_main_io_service_(main_io_service) {}
/// Initializes the receiver by replacing `waiter_` with a new DependencyWaiterImpl
/// constructed from `raylet_client`.
///
/// \param raylet_client Raylet client handed to the dependency waiter by reference.
///
/// NOTE(review): this text is a diff view; the first signature line is the
/// pre-change form and the second is the `raylet::`-qualified replacement.
void CoreWorkerDirectTaskReceiver::Init(RayletClient &raylet_client) {
void CoreWorkerDirectTaskReceiver::Init(raylet::RayletClient &raylet_client) {
waiter_.reset(new DependencyWaiterImpl(raylet_client));
}

View file

@ -167,7 +167,8 @@ class DependencyWaiter {
class DependencyWaiterImpl : public DependencyWaiter {
public:
DependencyWaiterImpl(RayletClient &raylet_client) : raylet_client_(raylet_client) {}
DependencyWaiterImpl(raylet::RayletClient &raylet_client)
: raylet_client_(raylet_client) {}
void Wait(const std::vector<ObjectID> &dependencies,
std::function<void()> on_dependencies_available) override {
@ -187,7 +188,7 @@ class DependencyWaiterImpl : public DependencyWaiter {
private:
int64_t next_request_id_ = 0;
std::unordered_map<int64_t, std::function<void()>> requests_;
RayletClient &raylet_client_;
raylet::RayletClient &raylet_client_;
};
/// Wraps a thread-pool to block posts until the pool has free slots. This is used
@ -436,7 +437,7 @@ class CoreWorkerDirectTaskReceiver {
}
/// Initialize this receiver. This must be called prior to use.
void Init(RayletClient &client);
void Init(raylet::RayletClient &client);
/// Handle a `PushTask` request.
///

View file

@ -6,7 +6,7 @@
namespace ray {
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
const WorkerID &worker_id, std::shared_ptr<RayletClient> &raylet_client,
const WorkerID &worker_id, std::shared_ptr<raylet::RayletClient> &raylet_client,
const TaskHandler &task_handler, const std::function<void()> &exit_handler)
: worker_id_(worker_id),
raylet_client_(raylet_client),

View file

@ -17,7 +17,7 @@ class CoreWorkerRayletTaskReceiver {
std::vector<std::shared_ptr<RayObject>> *return_objects)>;
CoreWorkerRayletTaskReceiver(const WorkerID &worker_id,
std::shared_ptr<RayletClient> &raylet_client,
std::shared_ptr<raylet::RayletClient> &raylet_client,
const TaskHandler &task_handler,
const std::function<void()> &exit_handler);
@ -37,7 +37,7 @@ class CoreWorkerRayletTaskReceiver {
WorkerID worker_id_;
/// Reference to the core worker's raylet client. This is a pointer ref so that it
/// can be initialized by core worker after this class is constructed.
std::shared_ptr<RayletClient> &raylet_client_;
std::shared_ptr<raylet::RayletClient> &raylet_client_;
/// The callback function to process a task.
TaskHandler task_handler_;
/// The callback function to exit the worker.

View file

@ -92,8 +92,10 @@ int write_bytes(Socket &conn, uint8_t *cursor, size_t length) {
return 0;
}
RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries,
int64_t timeout) {
namespace ray {
raylet::RayletConnection::RayletConnection(const std::string &raylet_socket,
int num_retries, int64_t timeout) {
// Pick the default values if the user did not specify.
if (num_retries < 0) {
num_retries = RayConfig::instance().num_connect_attempts();
@ -122,9 +124,9 @@ RayletConnection::RayletConnection(const std::string &raylet_socket, int num_ret
}
}
ray::Status RayletConnection::Disconnect() {
Status raylet::RayletConnection::Disconnect() {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateDisconnectClient(fbb);
auto message = protocol::CreateDisconnectClient(fbb);
fbb.Finish(message);
auto status = WriteMessage(MessageType::IntentionalDisconnectClient, &fbb);
// Don't be too strict for disconnection errors.
@ -133,11 +135,11 @@ ray::Status RayletConnection::Disconnect() {
RAY_LOG(ERROR) << status.ToString()
<< " [RayletClient] Failed to disconnect from raylet.";
}
return ray::Status::OK();
return Status::OK();
}
ray::Status RayletConnection::ReadMessage(MessageType type,
std::unique_ptr<uint8_t[]> &message) {
Status raylet::RayletConnection::ReadMessage(MessageType type,
std::unique_ptr<uint8_t[]> &message) {
int64_t cookie;
int64_t type_field;
int64_t length;
@ -159,26 +161,26 @@ ray::Status RayletConnection::ReadMessage(MessageType type,
length = 0;
}
if (type_field == static_cast<int64_t>(MessageType::DisconnectClient)) {
return ray::Status::IOError("[RayletClient] Raylet connection closed.");
return Status::IOError("[RayletClient] Raylet connection closed.");
}
if (type_field != static_cast<int64_t>(type)) {
return ray::Status::TypeError(
return Status::TypeError(
std::string("[RayletClient] Raylet connection corrupted. ") +
"Expected message type: " + std::to_string(static_cast<int64_t>(type)) +
"; got message type: " + std::to_string(type_field) +
". Check logs or dmesg for previous errors.");
}
return ray::Status::OK();
return Status::OK();
}
ray::Status RayletConnection::WriteMessage(MessageType type,
flatbuffers::FlatBufferBuilder *fbb) {
Status raylet::RayletConnection::WriteMessage(MessageType type,
flatbuffers::FlatBufferBuilder *fbb) {
std::unique_lock<std::mutex> guard(write_mutex_);
int64_t cookie = RayConfig::instance().ray_cookie();
int64_t length = fbb ? fbb->GetSize() : 0;
uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr;
int64_t type_field = static_cast<int64_t>(type);
auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly.");
auto io_error = Status::IOError("[RayletClient] Connection closed unexpectedly.");
int closed;
closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie));
if (closed) return io_error;
@ -188,10 +190,10 @@ ray::Status RayletConnection::WriteMessage(MessageType type,
if (closed) return io_error;
closed = write_bytes(conn_, bytes, length * sizeof(char));
if (closed) return io_error;
return ray::Status::OK();
return Status::OK();
}
ray::Status RayletConnection::AtomicRequestReply(
Status raylet::RayletConnection::AtomicRequestReply(
MessageType request_type, MessageType reply_type,
std::unique_ptr<uint8_t[]> &reply_message, flatbuffers::FlatBufferBuilder *fbb) {
std::unique_lock<std::mutex> guard(mutex_);
@ -200,19 +202,21 @@ ray::Status RayletConnection::AtomicRequestReply(
return ReadMessage(reply_type, reply_message);
}
/// Constructs a client from a gRPC client only. The argument is moved into
/// `grpc_client_` and the body is empty, so this constructor does not create a
/// RayletConnection or register with the raylet.
///
/// \param grpc_client gRPC client for the node manager; ownership is shared.
///
/// NOTE(review): this text is a diff view; the first line is the pre-change
/// signature and the next two lines are its namespaced replacement.
RayletClient::RayletClient(std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client)
raylet::RayletClient::RayletClient(
std::shared_ptr<rpc::NodeManagerWorkerClient> grpc_client)
: grpc_client_(std::move(grpc_client)) {}
RayletClient::RayletClient(std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
ClientID *raylet_id, int port)
raylet::RayletClient::RayletClient(
std::shared_ptr<rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker,
const JobID &job_id, const Language &language, ClientID *raylet_id, int port)
: grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) {
// For C++14, we could use std::make_unique
conn_ = std::unique_ptr<RayletConnection>(new RayletConnection(raylet_socket, -1, -1));
conn_ = std::unique_ptr<raylet::RayletConnection>(
new raylet::RayletConnection(raylet_socket, -1, -1));
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateRegisterClientRequest(
auto message = protocol::CreateRegisterClientRequest(
fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id),
language, port);
fbb.Finish(message);
@ -222,12 +226,11 @@ RayletClient::RayletClient(std::shared_ptr<ray::rpc::NodeManagerWorkerClient> gr
auto status = conn_->AtomicRequestReply(MessageType::RegisterClientRequest,
MessageType::RegisterClientReply, reply, &fbb);
RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet.");
auto reply_message =
flatbuffers::GetRoot<ray::protocol::RegisterClientReply>(reply.get());
auto reply_message = flatbuffers::GetRoot<protocol::RegisterClientReply>(reply.get());
*raylet_id = ClientID::FromBinary(reply_message->raylet_id()->str());
}
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
Status raylet::RayletClient::SubmitTask(const TaskSpecification &task_spec) {
for (size_t i = 0; i < task_spec.NumArgs(); i++) {
if (task_spec.ArgByRef(i)) {
for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) {
@ -237,58 +240,57 @@ ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
}
}
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateSubmitTaskRequest(
fbb, fbb.CreateString(task_spec.Serialize()));
auto message =
protocol::CreateSubmitTaskRequest(fbb, fbb.CreateString(task_spec.Serialize()));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::SubmitTask, &fbb);
}
/// Writes a `MessageType::TaskDone` message on the raylet connection. No
/// flatbuffer is passed, so WriteMessage's default (null builder) is used.
///
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the two signature lines are the old and new forms.
ray::Status RayletClient::TaskDone() {
Status raylet::RayletClient::TaskDone() {
return conn_->WriteMessage(MessageType::TaskDone);
}
/// Serializes the arguments into a FetchOrReconstruct flatbuffer message and
/// writes it on the raylet connection. Only WriteMessage is called; no reply is
/// read in this function.
///
/// \param object_ids Object IDs placed in the request.
/// \param fetch_only Flag copied into the request unchanged.
/// \param mark_worker_blocked Flag copied into the request unchanged.
/// \param current_task_id Task ID copied into the request.
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the signature and the `auto message = ...` statement
/// each appear twice (pre-change with `ray::` qualifiers, then post-change).
ray::Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
bool fetch_only, bool mark_worker_blocked,
const TaskID &current_task_id) {
Status raylet::RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
bool fetch_only, bool mark_worker_blocked,
const TaskID &current_task_id) {
flatbuffers::FlatBufferBuilder fbb;
auto object_ids_message = to_flatbuf(fbb, object_ids);
auto message = ray::protocol::CreateFetchOrReconstruct(
fbb, object_ids_message, fetch_only, mark_worker_blocked,
to_flatbuf(fbb, current_task_id));
auto message = protocol::CreateFetchOrReconstruct(fbb, object_ids_message, fetch_only,
mark_worker_blocked,
to_flatbuf(fbb, current_task_id));
fbb.Finish(message);
auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb);
return status;
}
/// Builds a NotifyUnblocked flatbuffer carrying `current_task_id` and writes it
/// on the raylet connection. No reply is read here.
///
/// \param current_task_id Task ID embedded in the message.
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the signature and the `auto message` statement are
/// each shown in their pre-change and post-change forms.
ray::Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
Status raylet::RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
flatbuffers::FlatBufferBuilder fbb;
auto message =
ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id));
auto message = protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb);
}
/// Builds a NotifyDirectCallTaskBlocked flatbuffer (no fields are supplied beyond
/// the builder) and writes it on the raylet connection. No reply is read here.
///
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the signature and the `auto message` statement are
/// each shown in their pre-change and post-change forms.
ray::Status RayletClient::NotifyDirectCallTaskBlocked() {
Status raylet::RayletClient::NotifyDirectCallTaskBlocked() {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateNotifyDirectCallTaskBlocked(fbb);
auto message = protocol::CreateNotifyDirectCallTaskBlocked(fbb);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyDirectCallTaskBlocked, &fbb);
}
/// Builds a NotifyDirectCallTaskUnblocked flatbuffer (no fields are supplied
/// beyond the builder) and writes it on the raylet connection. No reply is read.
///
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the signature and the `auto message` statement are
/// each shown in their pre-change and post-change forms.
ray::Status RayletClient::NotifyDirectCallTaskUnblocked() {
Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateNotifyDirectCallTaskUnblocked(fbb);
auto message = protocol::CreateNotifyDirectCallTaskUnblocked(fbb);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyDirectCallTaskUnblocked, &fbb);
}
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
int64_t timeout_milliseconds, bool wait_local,
bool mark_worker_blocked, const TaskID &current_task_id,
WaitResultPair *result) {
Status raylet::RayletClient::Wait(const std::vector<ObjectID> &object_ids,
int num_returns, int64_t timeout_milliseconds,
bool wait_local, bool mark_worker_blocked,
const TaskID &current_task_id, WaitResultPair *result) {
// Write request.
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateWaitRequest(
auto message = protocol::CreateWaitRequest(
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local,
mark_worker_blocked, to_flatbuf(fbb, current_task_id));
fbb.Finish(message);
@ -297,7 +299,7 @@ ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_
MessageType::WaitReply, reply, &fbb);
if (!status.ok()) return status;
// Parse the flatbuffer object.
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply.get());
auto reply_message = flatbuffers::GetRoot<protocol::WaitReply>(reply.get());
auto found = reply_message->found();
for (size_t i = 0; i < found->size(); i++) {
ObjectID object_id = ObjectID::FromBinary(found->Get(i)->str());
@ -308,22 +310,23 @@ ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_
ObjectID object_id = ObjectID::FromBinary(remaining->Get(i)->str());
result->second.push_back(object_id);
}
return ray::Status::OK();
return Status::OK();
}
/// Builds a WaitForDirectActorCallArgsRequest flatbuffer from `object_ids` and
/// `tag` and writes it on the raylet connection. No reply is read here.
///
/// \param object_ids Object IDs placed in the request.
/// \param tag Integer copied into the request unchanged.
///        NOTE(review): presumably lets the caller match a later notification
///        to this request; confirm against the raylet side.
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the first signature line and the first
/// `auto message = ...` line are the pre-change forms of the lines that follow.
ray::Status RayletClient::WaitForDirectActorCallArgs(
Status raylet::RayletClient::WaitForDirectActorCallArgs(
const std::vector<ObjectID> &object_ids, int64_t tag) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateWaitForDirectActorCallArgsRequest(
auto message = protocol::CreateWaitForDirectActorCallArgsRequest(
fbb, to_flatbuf(fbb, object_ids), tag);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::WaitForDirectActorCallArgsRequest, &fbb);
}
ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type,
const std::string &error_message, double timestamp) {
Status raylet::RayletClient::PushError(const JobID &job_id, const std::string &type,
const std::string &error_message,
double timestamp) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreatePushErrorRequest(
auto message = protocol::CreatePushErrorRequest(
fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type),
fbb.CreateString(error_message), timestamp);
fbb.Finish(message);
@ -331,7 +334,7 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string
return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb);
}
ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
Status raylet::RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
flatbuffers::FlatBufferBuilder fbb;
auto message = fbb.CreateString(profile_events.SerializeAsString());
fbb.Finish(message);
@ -342,13 +345,13 @@ ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_even
RAY_LOG(ERROR) << status.ToString()
<< " [RayletClient] Failed to push profile events.";
}
return ray::Status::OK();
return Status::OK();
}
ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_ids,
bool local_only, bool delete_creating_tasks) {
Status raylet::RayletClient::FreeObjects(const std::vector<ObjectID> &object_ids,
bool local_only, bool delete_creating_tasks) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateFreeObjectsRequest(
auto message = protocol::CreateFreeObjectsRequest(
fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids));
fbb.Finish(message);
@ -356,11 +359,11 @@ ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_i
return status;
}
ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
ActorCheckpointID &checkpoint_id) {
Status raylet::RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
ActorCheckpointID &checkpoint_id) {
flatbuffers::FlatBufferBuilder fbb;
auto message =
ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id));
protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id));
fbb.Finish(message);
std::unique_ptr<uint8_t[]> reply;
@ -369,57 +372,58 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
MessageType::PrepareActorCheckpointReply, reply, &fbb);
if (!status.ok()) return status;
auto reply_message =
flatbuffers::GetRoot<ray::protocol::PrepareActorCheckpointReply>(reply.get());
flatbuffers::GetRoot<protocol::PrepareActorCheckpointReply>(reply.get());
checkpoint_id = ActorCheckpointID::FromBinary(reply_message->checkpoint_id()->str());
return ray::Status::OK();
return Status::OK();
}
/// Builds a NotifyActorResumedFromCheckpoint flatbuffer from the actor ID and
/// checkpoint ID and writes it on the raylet connection. No reply is read here.
///
/// \param actor_id Actor ID embedded in the message.
/// \param checkpoint_id Checkpoint ID embedded in the message.
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the first signature line and the first
/// `auto message = ...` line are the pre-change forms of the lines that follow.
ray::Status RayletClient::NotifyActorResumedFromCheckpoint(
Status raylet::RayletClient::NotifyActorResumedFromCheckpoint(
const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateNotifyActorResumedFromCheckpoint(
auto message = protocol::CreateNotifyActorResumedFromCheckpoint(
fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, checkpoint_id));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb);
}
/// Builds a SetResourceRequest flatbuffer from the resource name, capacity and
/// client ID and writes it on the raylet connection. No reply is read here.
///
/// \param resource_name Name serialized into the request as a flatbuffer string.
/// \param capacity Value copied into the request unchanged.
/// \param client_Id Client ID embedded in the request.
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the three-line signature and the
/// `auto message = ...` statement each appear twice (pre-change, then post-change).
ray::Status RayletClient::SetResource(const std::string &resource_name,
const double capacity,
const ray::ClientID &client_Id) {
Status raylet::RayletClient::SetResource(const std::string &resource_name,
const double capacity,
const ClientID &client_Id) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateSetResourceRequest(
fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id));
auto message = protocol::CreateSetResourceRequest(fbb, fbb.CreateString(resource_name),
capacity, to_flatbuf(fbb, client_Id));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb);
}
/// Builds a ReportActiveObjectIDs flatbuffer from the given set of object IDs and
/// writes it on the raylet connection. No reply is read here.
///
/// \param object_ids Set of object IDs serialized into the message.
/// \return The status returned by `conn_->WriteMessage`.
///
/// NOTE(review): diff view; the first signature line and the two-line
/// `auto message =` statement are the pre-change forms of the lines that follow.
ray::Status RayletClient::ReportActiveObjectIDs(
Status raylet::RayletClient::ReportActiveObjectIDs(
const std::unordered_set<ObjectID> &object_ids) {
flatbuffers::FlatBufferBuilder fbb;
auto message =
ray::protocol::CreateReportActiveObjectIDs(fbb, to_flatbuf(fbb, object_ids));
auto message = protocol::CreateReportActiveObjectIDs(fbb, to_flatbuf(fbb, object_ids));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &fbb);
}
/// Sends a worker lease request through the gRPC client (not through `conn_`).
/// The task specification's message is copied into the request's resource spec,
/// and the request is passed to `grpc_client_->RequestWorkerLease` together with
/// `callback`.
///
/// \param resource_spec Task specification whose message is copied into the request.
/// \param callback Callback forwarded to the gRPC client unchanged.
///        NOTE(review): presumably invoked with the WorkerLeaseReply; confirm.
/// \return The status returned by `grpc_client_->RequestWorkerLease`.
///
/// NOTE(review): diff view; the first four lines (signature plus `request`
/// declaration) are the pre-change forms of the four lines that follow them.
ray::Status RayletClient::RequestWorkerLease(
const ray::TaskSpecification &resource_spec,
const ray::rpc::ClientCallback<ray::rpc::WorkerLeaseReply> &callback) {
ray::rpc::WorkerLeaseRequest request;
Status raylet::RayletClient::RequestWorkerLease(
const TaskSpecification &resource_spec,
const rpc::ClientCallback<rpc::WorkerLeaseReply> &callback) {
rpc::WorkerLeaseRequest request;
request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage());
return grpc_client_->RequestWorkerLease(request, callback);
}
/// Sends a ReturnWorker request through the gRPC client (not through `conn_`).
/// The reply callback ignores the reply contents and only logs at INFO level
/// when the returned status is not OK.
///
/// \param worker_port Copied into the request via `set_worker_port`.
/// \param disconnect_worker Copied into the request via `set_disconnect_worker`.
/// \return The status returned by `grpc_client_->ReturnWorker`.
///
/// NOTE(review): diff view; the first two lines and the first `request, [](...)`
/// lambda line are the pre-change forms of the lines that follow each of them.
ray::Status RayletClient::ReturnWorker(int worker_port, bool disconnect_worker) {
ray::rpc::ReturnWorkerRequest request;
Status raylet::RayletClient::ReturnWorker(int worker_port, bool disconnect_worker) {
rpc::ReturnWorkerRequest request;
request.set_worker_port(worker_port);
request.set_disconnect_worker(disconnect_worker);
return grpc_client_->ReturnWorker(
request, [](const ray::Status &status, const ray::rpc::ReturnWorkerReply &reply) {
request, [](const Status &status, const rpc::ReturnWorkerReply &reply) {
if (!status.ok()) {
RAY_LOG(INFO) << "Error returning worker: " << status;
}
});
}
} // namespace ray

View file

@ -30,6 +30,29 @@ using ResourceMappingType =
using Socket = boost::asio::detail::socket_holder;
using WaitResultPair = std::pair<std::vector<ObjectID>, std::vector<ObjectID>>;
namespace ray {
/// Interface for leasing workers. Abstract for testing.
/// Interface for leasing workers. Abstract for testing.
/// Both operations are pure virtual; RayletClient is declared as deriving from
/// this interface.
class WorkerLeaseInterface {
public:
/// Requests a worker from the raylet. The callback will be sent via gRPC.
/// \param resource_spec Resources that should be allocated for the worker.
/// \param callback Callback typed on rpc::WorkerLeaseReply.
///        NOTE(review): presumably invoked when the lease reply arrives; confirm.
/// \return ray::Status
virtual ray::Status RequestWorkerLease(
const ray::TaskSpecification &resource_spec,
const ray::rpc::ClientCallback<ray::rpc::WorkerLeaseReply> &callback) = 0;
/// Returns a worker to the raylet.
/// \param worker_port The local port of the worker on the raylet node.
/// \param disconnect_worker Whether the raylet should disconnect the worker.
/// \return ray::Status
virtual ray::Status ReturnWorker(int worker_port, bool disconnect_worker) = 0;
/// Virtual (empty) destructor so implementations can be destroyed through a
/// pointer to this interface.
virtual ~WorkerLeaseInterface(){};
};
namespace raylet {
class RayletConnection {
public:
/// Connect to the raylet.
@ -49,9 +72,12 @@ class RayletConnection {
///
/// \return ray::Status.
ray::Status Disconnect();
ray::Status ReadMessage(MessageType type, std::unique_ptr<uint8_t[]> &message);
ray::Status WriteMessage(MessageType type,
flatbuffers::FlatBufferBuilder *fbb = nullptr);
ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type,
std::unique_ptr<uint8_t[]> &reply_message,
flatbuffers::FlatBufferBuilder *fbb = nullptr);
@ -65,25 +91,6 @@ class RayletConnection {
std::mutex write_mutex_;
};
/// Interface for leasing workers. Abstract for testing.
/// Interface for leasing workers. Abstract for testing.
/// Both operations are pure virtual; RayletClient is declared as deriving from
/// this interface.
/// NOTE(review): this text is a diff view and an identical copy of this class
/// also appears earlier in the same header hunk, which suggests the commit
/// relocated the definition; confirm against the actual header.
class WorkerLeaseInterface {
public:
/// Requests a worker from the raylet. The callback will be sent via gRPC.
/// \param resource_spec Resources that should be allocated for the worker.
/// \param callback Callback typed on rpc::WorkerLeaseReply.
///        NOTE(review): presumably invoked when the lease reply arrives; confirm.
/// \return ray::Status
virtual ray::Status RequestWorkerLease(
const ray::TaskSpecification &resource_spec,
const ray::rpc::ClientCallback<ray::rpc::WorkerLeaseReply> &callback) = 0;
/// Returns a worker to the raylet.
/// \param worker_port The local port of the worker on the raylet node.
/// \param disconnect_worker Whether the raylet should disconnect the worker.
/// \return ray::Status
virtual ray::Status ReturnWorker(int worker_port, bool disconnect_worker) = 0;
/// Virtual (empty) destructor so implementations can be destroyed through a
/// pointer to this interface.
virtual ~WorkerLeaseInterface(){};
};
class RayletClient : public WorkerLeaseInterface {
public:
/// Connect to the raylet.
@ -258,4 +265,8 @@ class RayletClient : public WorkerLeaseInterface {
std::unique_ptr<RayletConnection> conn_;
};
} // namespace raylet
} // namespace ray
#endif