[Core] WorkerInterface refactor (#9655)

* .

* .

* refactor WorkerInterface

* .

* Basic unit test structure complete?

* .

* .

* .

* .

* Fixed tests

* Fixed tests

* .
This commit is contained in:
Alex Wu 2020-07-23 21:13:29 -07:00 committed by GitHub
parent 06c3518aa1
commit 239196fffc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 213 additions and 98 deletions

View file

@ -105,7 +105,8 @@ namespace raylet {
// A helper function to print the leased workers.
std::string LeasedWorkersSring(
const std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers) {
const std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>>
&leased_workers) {
std::stringstream buffer;
buffer << " @leased_workers: (";
for (const auto &pair : leased_workers) {
@ -117,7 +118,8 @@ std::string LeasedWorkersSring(
}
// A helper function to print the workers in worker_pool_.
std::string WorkerPoolString(const std::vector<std::shared_ptr<Worker>> &worker_pool) {
std::string WorkerPoolString(
const std::vector<std::shared_ptr<WorkerInterface>> &worker_pool) {
std::stringstream buffer;
buffer << " @worker_pool: (";
for (const auto &worker : worker_pool) {
@ -128,7 +130,7 @@ std::string WorkerPoolString(const std::vector<std::shared_ptr<Worker>> &worker_
}
// Helper function to print the worker's owner worker and and node owner.
std::string WorkerOwnerString(std::shared_ptr<Worker> &worker) {
std::string WorkerOwnerString(std::shared_ptr<WorkerInterface> &worker) {
std::stringstream buffer;
const auto owner_worker_id =
WorkerID::FromBinary(worker->GetOwnerAddress().worker_id());
@ -320,7 +322,7 @@ ray::Status NodeManager::RegisterGcs() {
return ray::Status::OK();
}
void NodeManager::KillWorker(std::shared_ptr<Worker> worker) {
void NodeManager::KillWorker(std::shared_ptr<WorkerInterface> worker) {
#ifdef _WIN32
// TODO(mehrdadn): implement graceful process termination mechanism
#else
@ -1072,7 +1074,7 @@ void NodeManager::DispatchTasks(
// Try to get an idle worker to execute this task. If nullptr, there
// aren't any available workers so we can't assign the task.
std::shared_ptr<Worker> worker =
std::shared_ptr<WorkerInterface> worker =
worker_pool_.PopWorker(task.GetTaskSpecification());
if (worker != nullptr) {
AssignTask(worker, task, &post_assign_callbacks);
@ -1145,11 +1147,11 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
ProcessFetchOrReconstructMessage(client, message_data);
} break;
case protocol::MessageType::NotifyDirectCallTaskBlocked: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
HandleDirectCallTaskBlocked(worker);
} break;
case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
HandleDirectCallTaskUnblocked(worker);
} break;
case protocol::MessageType::NotifyUnblocked: {
@ -1214,8 +1216,8 @@ void NodeManager::ProcessRegisterClientRequestMessage(
WorkerID worker_id = from_flatbuf<WorkerID>(*message->worker_id());
pid_t pid = message->worker_pid();
std::string worker_ip_address = string_from_flatbuf(*message->ip_address());
auto worker = std::make_shared<Worker>(worker_id, language, worker_ip_address, client,
client_call_manager_);
auto worker = std::dynamic_pointer_cast<WorkerInterface>(std::make_shared<Worker>(
worker_id, language, worker_ip_address, client, client_call_manager_));
int assigned_port;
if (message->is_worker()) {
@ -1269,7 +1271,7 @@ void NodeManager::ProcessRegisterClientRequestMessage(
void NodeManager::ProcessAnnounceWorkerPortMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
bool is_worker = true;
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (worker == nullptr) {
is_worker = false;
worker = worker_pool_.GetRegisteredDriver(client);
@ -1345,11 +1347,11 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca
}
void NodeManager::HandleWorkerAvailable(const std::shared_ptr<ClientConnection> &client) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
HandleWorkerAvailable(worker);
}
void NodeManager::HandleWorkerAvailable(const std::shared_ptr<Worker> &worker) {
void NodeManager::HandleWorkerAvailable(const std::shared_ptr<WorkerInterface> &worker) {
RAY_CHECK(worker);
bool worker_idle = true;
@ -1376,7 +1378,7 @@ void NodeManager::HandleWorkerAvailable(const std::shared_ptr<Worker> &worker) {
void NodeManager::ProcessDisconnectClientMessage(
const std::shared_ptr<ClientConnection> &client, bool intentional_disconnect) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
bool is_worker = false, is_driver = false;
if (worker) {
// The client is a worker.
@ -1617,7 +1619,8 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage(
object_ids, -1, object_ids.size(), false,
[this, client, tag](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
RAY_CHECK(remaining.empty());
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker =
worker_pool_.GetRegisteredWorker(client);
if (!worker) {
RAY_LOG(ERROR) << "Lost worker for wait request " << client;
} else {
@ -1647,7 +1650,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
const auto &actor_entry = actor_registry_.find(actor_id);
RAY_CHECK(actor_entry != actor_registry_.end());
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
RAY_CHECK(worker && worker->GetActorId() == actor_id);
std::shared_ptr<ActorCheckpointData> checkpoint_data =
@ -1822,7 +1825,7 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request,
rpc::SendReplyCallback send_reply_callback) {
// Read the resource spec submitted by the client.
auto worker_id = WorkerID::FromBinary(request.worker_id());
std::shared_ptr<Worker> worker = leased_workers_[worker_id];
std::shared_ptr<WorkerInterface> worker = leased_workers_[worker_id];
Status status;
leased_workers_.erase(worker_id);
@ -2320,7 +2323,8 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
}
}
void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker) {
void NodeManager::HandleDirectCallTaskBlocked(
const std::shared_ptr<WorkerInterface> &worker) {
if (new_scheduler_enabled_) {
if (!worker) {
return;
@ -2349,7 +2353,8 @@ void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &wor
DispatchTasks(local_queues_.GetReadyTasksByClass());
}
void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker) {
void NodeManager::HandleDirectCallTaskUnblocked(
const std::shared_ptr<WorkerInterface> &worker) {
if (new_scheduler_enabled_) {
if (!worker) {
return;
@ -2406,7 +2411,7 @@ void NodeManager::AsyncResolveObjects(
const std::shared_ptr<ClientConnection> &client,
const std::vector<rpc::ObjectReference> &required_object_refs,
const TaskID &current_task_id, bool ray_get, bool mark_worker_blocked) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
// The client is a worker. If the worker is not already blocked and the
// blocked task matches the one assigned to the worker, then mark the
@ -2460,7 +2465,7 @@ void NodeManager::AsyncResolveObjects(
void NodeManager::AsyncResolveObjectsFinish(
const std::shared_ptr<ClientConnection> &client, const TaskID &current_task_id,
bool was_blocked) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
// TODO(swang): Because the object dependencies are tracked in the task
// dependency manager, we could actually remove this message entirely and
@ -2540,7 +2545,8 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) {
task_dependency_manager_.TaskPending(task);
}
void NodeManager::AssignTask(const std::shared_ptr<Worker> &worker, const Task &task,
void NodeManager::AssignTask(const std::shared_ptr<WorkerInterface> &worker,
const Task &task,
std::vector<std::function<void()>> *post_assign_callbacks) {
const TaskSpecification &spec = task.GetTaskSpecification();
RAY_CHECK(post_assign_callbacks);
@ -2626,7 +2632,7 @@ void NodeManager::AssignTask(const std::shared_ptr<Worker> &worker, const Task &
}
}
bool NodeManager::FinishAssignedTask(Worker &worker) {
bool NodeManager::FinishAssignedTask(WorkerInterface &worker) {
TaskID task_id = worker.GetAssignedTaskId();
RAY_LOG(DEBUG) << "Finished task " << task_id;
@ -2735,7 +2741,7 @@ std::shared_ptr<ActorTableData> NodeManager::CreateActorTableDataFromCreationTas
return actor_info_ptr;
}
void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
void NodeManager::FinishAssignedActorTask(WorkerInterface &worker, const Task &task) {
RAY_LOG(DEBUG) << "Finishing assigned actor task";
ActorID actor_id;
TaskID caller_id;
@ -3303,7 +3309,7 @@ void NodeManager::ForwardTask(
});
}
void NodeManager::FinishAssignTask(const std::shared_ptr<Worker> &worker,
void NodeManager::FinishAssignTask(const std::shared_ptr<WorkerInterface> &worker,
const TaskID &task_id, bool success) {
RAY_LOG(DEBUG) << "FinishAssignTask: " << task_id;
// Remove the ASSIGNED task from the READY queue.
@ -3348,7 +3354,8 @@ void NodeManager::FinishAssignTask(const std::shared_ptr<Worker> &worker,
void NodeManager::ProcessSubscribePlasmaReady(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
std::shared_ptr<Worker> associated_worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> associated_worker =
worker_pool_.GetRegisteredWorker(client);
if (associated_worker == nullptr) {
associated_worker = worker_pool_.GetRegisteredDriver(client);
}
@ -3361,7 +3368,7 @@ void NodeManager::ProcessSubscribePlasmaReady(
absl::MutexLock guard(&plasma_object_notification_lock_);
if (!async_plasma_objects_notification_.contains(id)) {
async_plasma_objects_notification_.emplace(
id, absl::flat_hash_set<std::shared_ptr<Worker>>());
id, absl::flat_hash_set<std::shared_ptr<WorkerInterface>>());
}
// Only insert a worker once
@ -3375,7 +3382,7 @@ ray::Status NodeManager::SetupPlasmaSubscription() {
return object_manager_.SubscribeObjAdded(
[this](const object_manager::protocol::ObjectInfoT &object_info) {
ObjectID object_id = ObjectID::FromBinary(object_info.object_id);
auto waiting_workers = absl::flat_hash_set<std::shared_ptr<Worker>>();
auto waiting_workers = absl::flat_hash_set<std::shared_ptr<WorkerInterface>>();
{
absl::MutexLock guard(&plasma_object_notification_lock_);
auto waiting = this->async_plasma_objects_notification_.extract(object_id);

View file

@ -256,7 +256,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param[in] task The task in question.
/// \param[out] post_assign_callbacks Vector of callbacks that will be appended
/// to with any logic that should run after the DispatchTasks loop runs.
void AssignTask(const std::shared_ptr<Worker> &worker, const Task &task,
void AssignTask(const std::shared_ptr<WorkerInterface> &worker, const Task &task,
std::vector<std::function<void()>> *post_assign_callbacks);
/// Handle a worker finishing its assigned task.
///
@ -264,7 +264,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \return Whether the worker should be returned to the idle pool. This is
/// only false for direct actor creation calls, which should never be
/// returned to idle.
bool FinishAssignedTask(Worker &worker);
bool FinishAssignedTask(WorkerInterface &worker);
/// Helper function to produce actor table data for a newly created actor.
///
/// \param task_spec Task specification of the actor creation task that created the
@ -276,7 +276,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param worker The worker that finished the task.
/// \param task The actor task or actor creation task.
/// \return Void.
void FinishAssignedActorTask(Worker &worker, const Task &task);
void FinishAssignedActorTask(WorkerInterface &worker, const Task &task);
/// Helper function for handling worker to finish its assigned actor task
/// or actor creation task. Gets invoked when tasks's parent actor is known.
///
@ -395,20 +395,20 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// arrive after the worker lease has been returned to the node manager.
///
/// \param worker Shared ptr to the worker, or nullptr if lost.
void HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker);
void HandleDirectCallTaskBlocked(const std::shared_ptr<WorkerInterface> &worker);
/// Handle a direct call task that is unblocked. Note that this callback may
/// arrive after the worker lease has been returned to the node manager.
/// However, it is guaranteed to arrive after DirectCallTaskBlocked.
///
/// \param worker Shared ptr to the worker, or nullptr if lost.
void HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker);
void HandleDirectCallTaskUnblocked(const std::shared_ptr<WorkerInterface> &worker);
/// Kill a worker.
///
/// \param worker The worker to kill.
/// \return Void.
void KillWorker(std::shared_ptr<Worker> worker);
void KillWorker(std::shared_ptr<WorkerInterface> worker);
/// The callback for handling an actor state transition (e.g., from ALIVE to
/// DEAD), whether as a notification from the actor table or as a handler for
@ -495,7 +495,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
///
/// \param worker The pointer to the worker
/// \return Void.
void HandleWorkerAvailable(const std::shared_ptr<Worker> &worker);
void HandleWorkerAvailable(const std::shared_ptr<WorkerInterface> &worker);
/// Handle a client that has disconnected. This can be called multiple times
/// on the same client because this is triggered both when a client
@ -582,8 +582,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param task_id Id of the task.
/// \param success Whether or not assigning the task was successful.
/// \return void.
void FinishAssignTask(const std::shared_ptr<Worker> &worker, const TaskID &task_id,
bool success);
void FinishAssignTask(const std::shared_ptr<WorkerInterface> &worker,
const TaskID &task_id, bool success);
/// Process worker subscribing to plasma.
///
@ -762,7 +762,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
remote_node_manager_clients_;
/// Map of workers leased out to direct call clients.
std::unordered_map<WorkerID, std::shared_ptr<Worker>> leased_workers_;
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> leased_workers_;
/// Map from owner worker ID to a list of worker IDs that the owner has a
/// lease on.
@ -805,7 +805,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
mutable absl::Mutex plasma_object_notification_lock_;
/// Keeps track of workers waiting for objects
absl::flat_hash_map<ObjectID, absl::flat_hash_set<std::shared_ptr<Worker>>>
absl::flat_hash_map<ObjectID, absl::flat_hash_set<std::shared_ptr<WorkerInterface>>>
async_plasma_objects_notification_ GUARDED_BY(plasma_object_notification_lock_);
/// Objects that are out of scope in the application and that should be freed

View file

@ -81,8 +81,8 @@ bool ClusterTaskManager::WaitForTaskArgsRequests(Work work) {
}
void ClusterTaskManager::DispatchScheduledTasksToWorkers(
WorkerPool &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers) {
WorkerPoolInterface &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers) {
// Check every task in task_to_dispatch queue to see
// whether it can be dispatched and ran. This avoids head-of-line
// blocking where a task which cannot be dispatched because
@ -94,7 +94,7 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers(
auto spec = task.GetTaskSpecification();
tasks_to_dispatch_.pop_front();
std::shared_ptr<Worker> worker = worker_pool.PopWorker(spec);
std::shared_ptr<WorkerInterface> worker = worker_pool.PopWorker(spec);
if (!worker) {
// No worker available to schedule this task.
// Put the task back in the dispatch queue.
@ -148,8 +148,8 @@ void ClusterTaskManager::TasksUnblocked(const std::vector<TaskID> ready_ids) {
}
void ClusterTaskManager::Dispatch(
std::shared_ptr<Worker> worker,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers_,
std::shared_ptr<WorkerInterface> worker,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers_,
const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback) {
reply->mutable_worker_address()->set_ip_address(worker->IpAddress());

View file

@ -61,14 +61,14 @@ class ClusterTaskManager {
/// `worker_pool` state will be modified (idle workers will be popped) during
/// dispatching.
void DispatchScheduledTasksToWorkers(
WorkerPool &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers);
WorkerPoolInterface &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers);
/// (Step 1) Queue tasks for scheduling.
/// \param fn: The function used during dispatching.
/// \param task: The incoming task to schedule.
void QueueTask(const Task &task, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback);
rpc::SendReplyCallback);
/// Move tasks from waiting to ready for dispatch. Called when a task's
/// dependencies are resolved.
@ -96,10 +96,11 @@ class ClusterTaskManager {
/// \return True if the work can be immediately dispatched.
bool WaitForTaskArgsRequests(Work work);
void Dispatch(std::shared_ptr<Worker> worker,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers_,
const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback);
void Dispatch(
std::shared_ptr<WorkerInterface> worker,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers_,
const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback);
void Spillback(ClientID spillback_to, std::string address, int port,
rpc::RequestWorkerLeaseReply *reply,

View file

@ -30,10 +30,91 @@ namespace ray {
namespace raylet {
/// \class WorkerPoolInterface
///
/// Used for new scheduler unit tests.
class WorkerInterface {
public:
/// A destructor responsible for freeing all worker state.
virtual ~WorkerInterface() {}
virtual void MarkDead() = 0;
virtual bool IsDead() const = 0;
virtual void MarkBlocked() = 0;
virtual void MarkUnblocked() = 0;
virtual bool IsBlocked() const = 0;
/// Return the worker's ID.
virtual WorkerID WorkerId() const = 0;
/// Return the worker process.
virtual Process GetProcess() const = 0;
virtual void SetProcess(Process proc) = 0;
virtual Language GetLanguage() const = 0;
virtual const std::string IpAddress() const = 0;
/// Connect this worker's gRPC client.
virtual void Connect(int port) = 0;
virtual int Port() const = 0;
virtual int AssignedPort() const = 0;
virtual void SetAssignedPort(int port) = 0;
virtual void AssignTaskId(const TaskID &task_id) = 0;
virtual const TaskID &GetAssignedTaskId() const = 0;
virtual bool AddBlockedTaskId(const TaskID &task_id) = 0;
virtual bool RemoveBlockedTaskId(const TaskID &task_id) = 0;
virtual const std::unordered_set<TaskID> &GetBlockedTaskIds() const = 0;
virtual void AssignJobId(const JobID &job_id) = 0;
virtual const JobID &GetAssignedJobId() const = 0;
virtual void AssignActorId(const ActorID &actor_id) = 0;
virtual const ActorID &GetActorId() const = 0;
virtual void MarkDetachedActor() = 0;
virtual bool IsDetachedActor() const = 0;
virtual const std::shared_ptr<ClientConnection> Connection() const = 0;
virtual void SetOwnerAddress(const rpc::Address &address) = 0;
virtual const rpc::Address &GetOwnerAddress() const = 0;
virtual const ResourceIdSet &GetLifetimeResourceIds() const = 0;
virtual void SetLifetimeResourceIds(ResourceIdSet &resource_ids) = 0;
virtual void ResetLifetimeResourceIds() = 0;
virtual const ResourceIdSet &GetTaskResourceIds() const = 0;
virtual void SetTaskResourceIds(ResourceIdSet &resource_ids) = 0;
virtual void ResetTaskResourceIds() = 0;
virtual ResourceIdSet ReleaseTaskCpuResources() = 0;
virtual void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) = 0;
virtual Status AssignTask(const Task &task, const ResourceIdSet &resource_id_set) = 0;
virtual void DirectActorCallArgWaitComplete(int64_t tag) = 0;
// Setter, geter, and clear methods for allocated_instances_.
virtual void SetAllocatedInstances(
std::shared_ptr<TaskResourceInstances> &allocated_instances) = 0;
virtual std::shared_ptr<TaskResourceInstances> GetAllocatedInstances() = 0;
virtual void ClearAllocatedInstances() = 0;
virtual void SetLifetimeAllocatedInstances(
std::shared_ptr<TaskResourceInstances> &allocated_instances) = 0;
virtual std::shared_ptr<TaskResourceInstances> GetLifetimeAllocatedInstances() = 0;
virtual void ClearLifetimeAllocatedInstances() = 0;
virtual void SetBorrowedCPUInstances(std::vector<double> &cpu_instances) = 0;
virtual std::vector<double> &GetBorrowedCPUInstances() = 0;
virtual void ClearBorrowedCPUInstances() = 0;
virtual Task &GetAssignedTask() = 0;
virtual void SetAssignedTask(Task &assigned_task) = 0;
virtual bool IsRegistered() = 0;
virtual rpc::CoreWorkerClient *rpc_client() = 0;
};
/// Worker class encapsulates the implementation details of a worker. A worker
/// is the execution container around a unit of Ray work, such as a task or an
/// actor. Ray units of work execute in the context of a Worker.
class Worker {
class Worker : public WorkerInterface {
public:
/// A constructor that initializes a worker object.
/// NOTE: You MUST manually set the worker process.
@ -84,12 +165,8 @@ class Worker {
ResourceIdSet ReleaseTaskCpuResources();
void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources);
const std::unordered_set<ObjectID> &GetActiveObjectIds() const;
void SetActiveObjectIds(const std::unordered_set<ObjectID> &&object_ids);
Status AssignTask(const Task &task, const ResourceIdSet &resource_id_set);
void DirectActorCallArgWaitComplete(int64_t tag);
void WorkerLeaseGranted(const std::string &address, int port);
// Setter, geter, and clear methods for allocated_instances_.
void SetAllocatedInstances(

View file

@ -29,8 +29,8 @@
namespace {
// A helper function to get a worker from a list.
std::shared_ptr<ray::raylet::Worker> GetWorker(
const std::unordered_set<std::shared_ptr<ray::raylet::Worker>> &worker_pool,
std::shared_ptr<ray::raylet::WorkerInterface> GetWorker(
const std::unordered_set<std::shared_ptr<ray::raylet::WorkerInterface>> &worker_pool,
const std::shared_ptr<ray::ClientConnection> &connection) {
for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) {
if ((*it)->Connection() == connection) {
@ -42,8 +42,9 @@ std::shared_ptr<ray::raylet::Worker> GetWorker(
// A helper function to remove a worker from a list. Returns true if the worker
// was found and removed.
bool RemoveWorker(std::unordered_set<std::shared_ptr<ray::raylet::Worker>> &worker_pool,
const std::shared_ptr<ray::raylet::Worker> &worker) {
bool RemoveWorker(
std::unordered_set<std::shared_ptr<ray::raylet::WorkerInterface>> &worker_pool,
const std::shared_ptr<ray::raylet::WorkerInterface> &worker) {
return worker_pool.erase(worker) > 0;
}
@ -326,8 +327,8 @@ void WorkerPool::MarkPortAsFree(int port) {
}
}
Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t pid,
int *port) {
Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker,
pid_t pid, int *port) {
auto &state = GetStateForLanguage(worker->GetLanguage());
auto it = state.starting_worker_processes.find(Process::FromPid(pid));
if (it == state.starting_worker_processes.end()) {
@ -347,7 +348,8 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t p
return Status::OK();
}
Status WorkerPool::RegisterDriver(const std::shared_ptr<Worker> &driver, int *port) {
Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver,
int *port) {
RAY_CHECK(!driver->GetAssignedTaskId().IsNil());
RAY_RETURN_NOT_OK(GetNextFreePort(port));
driver->SetAssignedPort(*port);
@ -356,7 +358,7 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr<Worker> &driver, int *po
return Status::OK();
}
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
std::shared_ptr<WorkerInterface> WorkerPool::GetRegisteredWorker(
const std::shared_ptr<ClientConnection> &connection) const {
for (const auto &entry : states_by_lang_) {
auto worker = GetWorker(entry.second.registered_workers, connection);
@ -367,7 +369,7 @@ std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
return nullptr;
}
std::shared_ptr<Worker> WorkerPool::GetRegisteredDriver(
std::shared_ptr<WorkerInterface> WorkerPool::GetRegisteredDriver(
const std::shared_ptr<ClientConnection> &connection) const {
for (const auto &entry : states_by_lang_) {
auto driver = GetWorker(entry.second.registered_drivers, connection);
@ -378,7 +380,7 @@ std::shared_ptr<Worker> WorkerPool::GetRegisteredDriver(
return nullptr;
}
void WorkerPool::PushWorker(const std::shared_ptr<Worker> &worker) {
void WorkerPool::PushWorker(const std::shared_ptr<WorkerInterface> &worker) {
// Since the worker is now idle, unset its assigned task ID.
RAY_CHECK(worker->GetAssignedTaskId().IsNil())
<< "Idle workers cannot have an assigned task ID";
@ -401,10 +403,11 @@ void WorkerPool::PushWorker(const std::shared_ptr<Worker> &worker) {
}
}
std::shared_ptr<Worker> WorkerPool::PopWorker(const TaskSpecification &task_spec) {
std::shared_ptr<WorkerInterface> WorkerPool::PopWorker(
const TaskSpecification &task_spec) {
auto &state = GetStateForLanguage(task_spec.GetLanguage());
std::shared_ptr<Worker> worker = nullptr;
std::shared_ptr<WorkerInterface> worker = nullptr;
Process proc;
if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) {
// Code path of actor creation task with dynamic worker options.
@ -455,7 +458,7 @@ std::shared_ptr<Worker> WorkerPool::PopWorker(const TaskSpecification &task_spec
return worker;
}
bool WorkerPool::DisconnectWorker(const std::shared_ptr<Worker> &worker) {
bool WorkerPool::DisconnectWorker(const std::shared_ptr<WorkerInterface> &worker) {
auto &state = GetStateForLanguage(worker->GetLanguage());
RAY_CHECK(RemoveWorker(state.registered_workers, worker));
@ -467,7 +470,7 @@ bool WorkerPool::DisconnectWorker(const std::shared_ptr<Worker> &worker) {
return RemoveWorker(state.idle, worker);
}
void WorkerPool::DisconnectDriver(const std::shared_ptr<Worker> &driver) {
void WorkerPool::DisconnectDriver(const std::shared_ptr<WorkerInterface> &driver) {
auto &state = GetStateForLanguage(driver->GetLanguage());
RAY_CHECK(RemoveWorker(state.registered_drivers, driver));
stats::CurrentDriver().Record(
@ -482,9 +485,9 @@ inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &langua
return state->second;
}
std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForJob(
std::vector<std::shared_ptr<WorkerInterface>> WorkerPool::GetWorkersRunningTasksForJob(
const JobID &job_id) const {
std::vector<std::shared_ptr<Worker>> workers;
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (const auto &entry : states_by_lang_) {
for (const auto &worker : entry.second.registered_workers) {
@ -497,8 +500,9 @@ std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForJob(
return workers;
}
const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllRegisteredWorkers() const {
std::vector<std::shared_ptr<Worker>> workers;
const std::vector<std::shared_ptr<WorkerInterface>> WorkerPool::GetAllRegisteredWorkers()
const {
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (const auto &entry : states_by_lang_) {
for (const auto &worker : entry.second.registered_workers) {
@ -511,8 +515,9 @@ const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllRegisteredWorkers()
return workers;
}
const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllRegisteredDrivers() const {
std::vector<std::shared_ptr<Worker>> drivers;
const std::vector<std::shared_ptr<WorkerInterface>> WorkerPool::GetAllRegisteredDrivers()
const {
std::vector<std::shared_ptr<WorkerInterface>> drivers;
for (const auto &entry : states_by_lang_) {
for (const auto &driver : entry.second.registered_drivers) {

View file

@ -36,13 +36,35 @@ namespace raylet {
using WorkerCommandMap =
std::unordered_map<Language, std::vector<std::string>, std::hash<int>>;
/// \class WorkerPoolInterface
///
/// Used for new scheduler unit tests.
class WorkerPoolInterface {
public:
/// Pop an idle worker from the pool. The caller is responsible for pushing
/// the worker back onto the pool once the worker has completed its work.
///
/// \param task_spec The returned worker must be able to execute this task.
/// \return An idle worker with the requested task spec. Returns nullptr if no
/// such worker exists.
virtual std::shared_ptr<WorkerInterface> PopWorker(
const TaskSpecification &task_spec) = 0;
/// Add an idle worker to the pool.
///
/// \param The idle worker to add.
virtual void PushWorker(const std::shared_ptr<WorkerInterface> &worker) = 0;
virtual ~WorkerPoolInterface(){};
};
class WorkerInterface;
class Worker;
/// \class WorkerPool
///
/// The WorkerPool is responsible for managing a pool of Workers. Each Worker
/// is a container for a unit of work.
class WorkerPool {
class WorkerPool : public WorkerPoolInterface {
public:
/// Create a pool and asynchronously start at least the specified number of workers per
/// language.
@ -81,7 +103,8 @@ class WorkerPool {
/// \param[out] port The port that this worker's gRPC server should listen on.
/// Returns 0 if the worker should bind on a random port.
/// \return If the registration is successful.
Status RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t pid, int *port);
Status RegisterWorker(const std::shared_ptr<WorkerInterface> &worker, pid_t pid,
int *port);
/// Register a new driver.
///
@ -89,14 +112,14 @@ class WorkerPool {
/// \param[out] port The port that this driver's gRPC server should listen on.
/// Returns 0 if the driver should bind on a random port.
/// \return If the registration is successful.
Status RegisterDriver(const std::shared_ptr<Worker> &worker, int *port);
Status RegisterDriver(const std::shared_ptr<WorkerInterface> &worker, int *port);
/// Get the client connection's registered worker.
///
/// \param The client connection owned by a registered worker.
/// \return The Worker that owns the given client connection. Returns nullptr
/// if the client has not registered a worker yet.
std::shared_ptr<Worker> GetRegisteredWorker(
std::shared_ptr<WorkerInterface> GetRegisteredWorker(
const std::shared_ptr<ClientConnection> &connection) const;
/// Get the client connection's registered driver.
@ -104,24 +127,24 @@ class WorkerPool {
/// \param The client connection owned by a registered driver.
/// \return The Worker that owns the given client connection. Returns nullptr
/// if the client has not registered a driver.
std::shared_ptr<Worker> GetRegisteredDriver(
std::shared_ptr<WorkerInterface> GetRegisteredDriver(
const std::shared_ptr<ClientConnection> &connection) const;
/// Disconnect a registered worker.
///
/// \param The worker to disconnect. The worker must be registered.
/// \return Whether the given worker was in the pool of idle workers.
bool DisconnectWorker(const std::shared_ptr<Worker> &worker);
bool DisconnectWorker(const std::shared_ptr<WorkerInterface> &worker);
/// Disconnect a registered driver.
///
/// \param The driver to disconnect. The driver must be registered.
void DisconnectDriver(const std::shared_ptr<Worker> &driver);
void DisconnectDriver(const std::shared_ptr<WorkerInterface> &driver);
/// Add an idle worker to the pool.
///
/// \param The idle worker to add.
void PushWorker(const std::shared_ptr<Worker> &worker);
void PushWorker(const std::shared_ptr<WorkerInterface> &worker);
/// Pop an idle worker from the pool. The caller is responsible for pushing
/// the worker back onto the pool once the worker has completed its work.
@ -129,7 +152,7 @@ class WorkerPool {
/// \param task_spec The returned worker must be able to execute this task.
/// \return An idle worker with the requested task spec. Returns nullptr if no
/// such worker exists.
std::shared_ptr<Worker> PopWorker(const TaskSpecification &task_spec);
std::shared_ptr<WorkerInterface> PopWorker(const TaskSpecification &task_spec);
/// Return the current size of the worker pool for the requested language. Counts only
/// idle workers.
@ -142,18 +165,18 @@ class WorkerPool {
///
/// \param job_id The job ID.
/// \return A list containing all the workers which are running tasks for the job.
std::vector<std::shared_ptr<Worker>> GetWorkersRunningTasksForJob(
std::vector<std::shared_ptr<WorkerInterface>> GetWorkersRunningTasksForJob(
const JobID &job_id) const;
/// Get all the registered workers.
///
/// \return A list containing all the workers.
const std::vector<std::shared_ptr<Worker>> GetAllRegisteredWorkers() const;
const std::vector<std::shared_ptr<WorkerInterface>> GetAllRegisteredWorkers() const;
/// Get all the registered drivers.
///
/// \return A list containing all the drivers.
const std::vector<std::shared_ptr<Worker>> GetAllRegisteredDrivers() const;
const std::vector<std::shared_ptr<WorkerInterface>> GetAllRegisteredDrivers() const;
/// Whether there is a pending worker for the given task.
/// Note that, this is only used for actor creation task with dynamic options.
@ -210,16 +233,16 @@ class WorkerPool {
int num_workers_per_process;
/// The pool of dedicated workers for actor creation tasks
/// with prefix or suffix worker command.
std::unordered_map<TaskID, std::shared_ptr<Worker>> idle_dedicated_workers;
std::unordered_map<TaskID, std::shared_ptr<WorkerInterface>> idle_dedicated_workers;
/// The pool of idle non-actor workers.
std::unordered_set<std::shared_ptr<Worker>> idle;
std::unordered_set<std::shared_ptr<WorkerInterface>> idle;
/// The pool of idle actor workers.
std::unordered_map<ActorID, std::shared_ptr<Worker>> idle_actor;
std::unordered_map<ActorID, std::shared_ptr<WorkerInterface>> idle_actor;
/// All workers that have registered and are still connected, including both
/// idle and executing.
std::unordered_set<std::shared_ptr<Worker>> registered_workers;
std::unordered_set<std::shared_ptr<WorkerInterface>> registered_workers;
/// All drivers that have registered and are still connected.
std::unordered_set<std::shared_ptr<Worker>> registered_drivers;
std::unordered_set<std::shared_ptr<WorkerInterface>> registered_drivers;
/// A map from the pids of starting worker processes
/// to the number of their unregistered workers.
std::unordered_map<Process, int> starting_worker_processes;

View file

@ -100,8 +100,8 @@ class WorkerPoolTest : public ::testing::Test {
worker_pool_ = std::unique_ptr<WorkerPoolMock>(new WorkerPoolMock(io_service_));
}
std::shared_ptr<Worker> CreateWorker(Process proc,
const Language &language = Language::PYTHON) {
std::shared_ptr<WorkerInterface> CreateWorker(
Process proc, const Language &language = Language::PYTHON) {
std::function<void(ClientConnection &)> client_handler =
[this](ClientConnection &client) { HandleNewClient(client); };
std::function<void(std::shared_ptr<ClientConnection>, int64_t,
@ -115,8 +115,10 @@ class WorkerPoolTest : public ::testing::Test {
auto client =
ClientConnection::Create(client_handler, message_handler, std::move(socket),
"worker", {}, error_message_type_);
std::shared_ptr<Worker> worker = std::make_shared<Worker>(
std::shared_ptr<Worker> worker_ = std::make_shared<Worker>(
WorkerID::FromRandom(), language, "127.0.0.1", client, client_call_manager_);
std::shared_ptr<WorkerInterface> worker =
std::dynamic_pointer_cast<WorkerInterface>(worker_);
if (!proc.IsNull()) {
worker->SetProcess(proc);
}
@ -205,7 +207,7 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) {
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
Process proc = worker_pool_->StartWorkerProcess(Language::JAVA);
std::vector<std::shared_ptr<Worker>> workers;
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
workers.push_back(CreateWorker(Process(), Language::JAVA));
}
@ -254,13 +256,13 @@ TEST_F(WorkerPoolTest, InitialWorkerProcessCount) {
TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
// Try to pop a worker from the empty pool and make sure we don't get one.
std::shared_ptr<Worker> popped_worker;
std::shared_ptr<WorkerInterface> popped_worker;
const auto task_spec = ExampleTaskSpec();
popped_worker = worker_pool_->PopWorker(task_spec);
ASSERT_EQ(popped_worker, nullptr);
// Create some workers.
std::unordered_set<std::shared_ptr<Worker>> workers;
std::unordered_set<std::shared_ptr<WorkerInterface>> workers;
workers.insert(CreateWorker(Process::CreateNewDummy()));
workers.insert(CreateWorker(Process::CreateNewDummy()));
// Add the workers to the pool.