Mark worker as blocked and trigger reconstruction in ray.wait. (#2864)

* Trigger reconstruction in ray.wait and mark worker as blocked.
* Add test.
* Linting.
* Don't run new test with legacy Ray.
* Only call HandleClientUnblocked if it actually blocked in ray.wait.
* Reduce time to ray.wait in the test.
Parent: a1b8e79c30
Commit: f16d33593b
4 changed files with 108 additions and 51 deletions
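The user-visible point of this change is that a worker blocked in ray.wait now behaves like one blocked in ray.get: it is marked blocked, temporarily releasing its resources, and any missing objects it waits on are fetched or reconstructed. Below is a minimal sketch of the nested pattern this enables, assuming a single-CPU local cluster; the task names inner and outer are illustrative and not part of this commit.

import ray

# One CPU in total, so outer() occupies the only CPU slot while it runs.
ray.init(num_cpus=1)

@ray.remote
def inner():
    return 1

@ray.remote
def outer():
    # While this worker blocks in ray.wait, it is marked blocked and its
    # CPU is temporarily released, letting inner() get scheduled. Without
    # that release, this nested pattern could deadlock on one CPU.
    ready, _ = ray.wait([inner.remote()], num_returns=1)
    return ray.get(ready[0])

assert ray.get(outer.remote()) == 1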
@@ -41,7 +41,9 @@ public class PlasmaFreeTest {
       Ray.call(PlasmaFreeTest::hello).get();
     }
 
-    waitResult = Ray.wait(waitFor, 1, 2 * 1000);
+    // Check if the object has been evicted. Don't give ray.wait enough
+    // time to reconstruct the object.
+    waitResult = Ray.wait(waitFor, 1, 0);
     readyOnes = waitResult.getReady();
     unreadyOnes = waitResult.getUnready();
     Assert.assertEquals(0, readyOnes.size());
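The timeout drops to zero so the wait returns before reconstruction can kick in. The same semantics in Python, as a hypothetical standalone sketch (the slow task is illustrative, not from this commit): ray.wait with a zero timeout reports immediately on whatever is ready at that instant.

import time

import ray

ray.init()

@ray.remote
def slow():
    time.sleep(5)
    return "done"

ref = slow.remote()
# A zero timeout returns immediately with whatever is ready right now,
# leaving no window for a fetch or reconstruction to complete.
ready, not_ready = ray.wait([ref], num_returns=1, timeout=0)
assert ready == [] and not_ready == [ref]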
@@ -697,57 +697,11 @@ void NodeManager::ProcessClientMessage(
     }
 
     if (!required_object_ids.empty()) {
-      std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
-      if (worker) {
-        // The client is a worker. Mark the worker as blocked. This
-        // temporarily releases any resources that the worker holds while it is
-        // blocked.
-        HandleWorkerBlocked(worker);
-      } else {
-        // The client is a driver. Drivers do not hold resources, so we simply
-        // mark the driver as blocked.
-        worker = worker_pool_.GetRegisteredDriver(client);
-        RAY_CHECK(worker);
-        worker->MarkBlocked();
-      }
-      const TaskID current_task_id = worker->GetAssignedTaskId();
-      RAY_CHECK(!current_task_id.is_nil());
-      // Subscribe to the objects required by the ray.get. These objects will
-      // be fetched and/or reconstructed as necessary, until the objects become
-      // local or are unsubscribed.
-      task_dependency_manager_.SubscribeDependencies(current_task_id,
-                                                     required_object_ids);
+      HandleClientBlocked(client, required_object_ids);
    }
  } break;
  case protocol::MessageType::NotifyUnblocked: {
-    std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
-
-    // Re-acquire the CPU resources for the task that was assigned to the
-    // unblocked worker.
-    // TODO(swang): Because the object dependencies are tracked in the task
-    // dependency manager, we could actually remove this message entirely and
-    // instead unblock the worker once all the objects become available.
-    bool was_blocked;
-    if (worker) {
-      was_blocked = worker->IsBlocked();
-      // Mark the worker as unblocked. This returns the temporarily released
-      // resources to the worker.
-      HandleWorkerUnblocked(worker);
-    } else {
-      // The client is a driver. Drivers do not hold resources, so we simply
-      // mark the driver as unblocked.
-      worker = worker_pool_.GetRegisteredDriver(client);
-      RAY_CHECK(worker);
-      was_blocked = worker->IsBlocked();
-      worker->MarkUnblocked();
-    }
-    // Unsubscribe to the objects. Any fetch or reconstruction operations to
-    // make the objects local are canceled.
-    if (was_blocked) {
-      const TaskID current_task_id = worker->GetAssignedTaskId();
-      RAY_CHECK(!current_task_id.is_nil());
-      task_dependency_manager_.UnsubscribeDependencies(current_task_id);
-    }
+    HandleClientUnblocked(client);
  } break;
  case protocol::MessageType::WaitRequest: {
    // Read the data.
@@ -757,9 +711,25 @@ void NodeManager::ProcessClientMessage(
    uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
    bool wait_local = message->wait_local();
 
+    std::vector<ObjectID> required_object_ids;
+    for (auto const &object_id : object_ids) {
+      if (!task_dependency_manager_.CheckObjectLocal(object_id)) {
+        // Add any missing objects to the list to subscribe to in the task
+        // dependency manager. These objects will be pulled from remote node
+        // managers and reconstructed if necessary.
+        required_object_ids.push_back(object_id);
+      }
+    }
+
+    bool client_blocked = !required_object_ids.empty();
+    if (client_blocked) {
+      HandleClientBlocked(client, required_object_ids);
+    }
+
    ray::Status status = object_manager_.Wait(
        object_ids, wait_ms, num_required_objects, wait_local,
-        [client](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
+        [this, client_blocked, client](std::vector<ObjectID> found,
+                                       std::vector<ObjectID> remaining) {
          // Write the data.
          flatbuffers::FlatBufferBuilder fbb;
          flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
@@ -768,6 +738,10 @@ void NodeManager::ProcessClientMessage(
          RAY_CHECK_OK(
              client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
                                   fbb.GetSize(), fbb.GetBufferPointer()));
+          // The client is unblocked now because the wait call has returned.
+          if (client_blocked) {
+            HandleClientUnblocked(client);
+          }
        });
    RAY_CHECK_OK(status);
  } break;
@@ -1117,6 +1091,62 @@ void NodeManager::HandleWorkerUnblocked(std::shared_ptr<Worker> worker) {
  worker->MarkUnblocked();
}
 
+void NodeManager::HandleClientBlocked(
+    const std::shared_ptr<LocalClientConnection> &client,
+    const std::vector<ObjectID> &required_object_ids) {
+  std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
+  if (worker) {
+    // The client is a worker. Mark the worker as blocked. This
+    // temporarily releases any resources that the worker holds while it is
+    // blocked.
+    HandleWorkerBlocked(worker);
+  } else {
+    // The client is a driver. Drivers do not hold resources, so we simply
+    // mark the driver as blocked.
+    worker = worker_pool_.GetRegisteredDriver(client);
+    RAY_CHECK(worker);
+    worker->MarkBlocked();
+  }
+  const TaskID current_task_id = worker->GetAssignedTaskId();
+  RAY_CHECK(!current_task_id.is_nil());
+  // Subscribe to the objects required by the ray.get. These objects will
+  // be fetched and/or reconstructed as necessary, until the objects become
+  // local or are unsubscribed.
+  task_dependency_manager_.SubscribeDependencies(current_task_id, required_object_ids);
+}
+
+void NodeManager::HandleClientUnblocked(
+    const std::shared_ptr<LocalClientConnection> &client) {
+  std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
+
+  // Re-acquire the CPU resources for the task that was assigned to the
+  // unblocked worker.
+  // TODO(swang): Because the object dependencies are tracked in the task
+  // dependency manager, we could actually remove this message entirely and
+  // instead unblock the worker once all the objects become available.
+  bool was_blocked;
+  if (worker) {
+    was_blocked = worker->IsBlocked();
+    // Mark the worker as unblocked. This returns the temporarily released
+    // resources to the worker.
+    HandleWorkerUnblocked(worker);
+  } else {
+    // The client is a driver. Drivers do not hold resources, so we simply
+    // mark the driver as unblocked.
+    worker = worker_pool_.GetRegisteredDriver(client);
+    RAY_CHECK(worker);
+    was_blocked = worker->IsBlocked();
+    worker->MarkUnblocked();
+  }
+  // Unsubscribe to the objects. Any fetch or reconstruction operations to
+  // make the objects local are canceled.
+  if (was_blocked) {
+    const TaskID current_task_id = worker->GetAssignedTaskId();
+    RAY_CHECK(!current_task_id.is_nil());
+    task_dependency_manager_.UnsubscribeDependencies(current_task_id);
+  }
+}
+
void NodeManager::EnqueuePlaceableTask(const Task &task) {
  // TODO(atumanov): add task lookup hashmap and change EnqueuePlaceableTask to take
  // a vector of TaskIDs. Trigger MoveTask internally.
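Putting the pieces together, here is a toy, purely illustrative Python model of the raylet-side bookkeeping in this commit (none of these names are Ray's real API): a client is marked blocked only when some requested objects are missing, those objects are subscribed for fetch or reconstruction, and the unblock path is safe to call even for a client that never actually blocked, mirroring the was_blocked check above.

from typing import Callable, Dict, List, Set

class ToyRaylet:
    """Illustrative model of HandleClientBlocked / HandleClientUnblocked."""

    def __init__(self, local_objects: Set[str]):
        self.local_objects = local_objects
        self.blocked: Set[str] = set()                # currently blocked clients
        self.subscriptions: Dict[str, Set[str]] = {}  # client -> objects being fetched/reconstructed

    def handle_client_blocked(self, client: str, required: List[str]) -> None:
        # Subscribe to the missing objects; they would be fetched and/or
        # reconstructed until they become local or are unsubscribed.
        self.blocked.add(client)
        self.subscriptions[client] = set(required)

    def handle_client_unblocked(self, client: str) -> None:
        # Like the C++ was_blocked check, this is a no-op for a client
        # that never blocked, so it is safe to call unconditionally.
        if client in self.blocked:
            self.blocked.discard(client)
            self.subscriptions.pop(client, None)

    def process_wait_request(self, client: str, object_ids: List[str],
                             reply: Callable[[List[str], List[str]], None]) -> None:
        # Block the client only if some requested objects are missing.
        required = [o for o in object_ids if o not in self.local_objects]
        client_blocked = bool(required)
        if client_blocked:
            self.handle_client_blocked(client, required)

        # A real raylet waits asynchronously; this toy replies at once.
        found = [o for o in object_ids if o in self.local_objects]
        reply(found, required)
        # The client is unblocked once the wait has returned.
        if client_blocked:
            self.handle_client_unblocked(client)

For example, ToyRaylet({"a"}).process_wait_request("w1", ["a", "b"], print) blocks the client on "b", replies with (["a"], ["b"]), and then unblocks it.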
@@ -192,6 +192,23 @@ class NodeManager {
  /// \return Void.
  void HandleWorkerUnblocked(std::shared_ptr<Worker> worker);
 
+  /// Handle a client that is blocked. This could be a worker or a driver. This
+  /// can be triggered when a client starts a get call or a wait call.
+  ///
+  /// \param client The client that is blocked.
+  /// \param required_object_ids The IDs that the client is blocked waiting for.
+  /// \return Void.
+  void HandleClientBlocked(const std::shared_ptr<LocalClientConnection> &client,
+                           const std::vector<ObjectID> &required_object_ids);
+
+  /// Handle a client that is unblocked. This could be a worker or a driver.
+  /// This can be triggered when a client is finished with a get call or a wait
+  /// call. It is ok to call this even if the client is not actually blocked.
+  ///
+  /// \param client The client that is unblocked.
+  /// \return Void.
+  void HandleClientUnblocked(const std::shared_ptr<LocalClientConnection> &client);
+
  /// Kill a worker.
  ///
  /// \param worker The worker to kill.
@@ -2063,7 +2063,15 @@ class WorkerPoolTests(unittest.TestCase):
            object_ids = [f.remote(i, j) for j in range(2)]
            return ray.get(object_ids)
 
        ray.get([g.remote(i) for i in range(4)])
 
+        @ray.remote
+        def h(i):
+            # Each instance of h submits and blocks on the result of another
+            # remote task using ray.wait.
+            object_ids = [f.remote(i, j) for j in range(2)]
+            return ray.wait(object_ids, num_returns=len(object_ids))
+
+        if os.environ.get("RAY_USE_XRAY") == "1":
+            ray.get([h.remote(i) for i in range(4)])
 
        @ray.remote
        def _sleep(i):
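The new h task waits for all of its children with num_returns=len(object_ids), which blocks until every result exists without fetching the values. A standalone sketch of that idiom (the work task is hypothetical, not from this commit):

import ray

ray.init()

@ray.remote
def work(i):
    return i * i

refs = [work.remote(i) for i in range(4)]
# Block until all four results exist in the object store. Unlike ray.get,
# this does not copy the values back; it only partitions the refs into
# (ready, not_ready).
ready, not_ready = ray.wait(refs, num_returns=len(refs))
assert len(ready) == len(refs) and not not_ready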