Mark worker as blocked and trigger reconstruction in ray.wait. (#2864)

* Trigger reconstruction in ray.wait and mark worker as blocked.

* Add test.

* Linting.

* Don't run new test with legacy Ray.

* Only call HandleClientUnblocked if it actually blocked in ray.wait.

* Reduce the ray.wait timeout in the test.
Robert Nishihara, 2018-09-13 15:28:17 -07:00 (committed by Philipp Moritz)
parent a1b8e79c30
commit f16d33593b
4 changed files with 108 additions and 51 deletions
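In user-facing terms, the change closes a gap between ray.get and ray.wait: blocking in ray.wait now subscribes the waiting client to its missing objects, so evicted objects are reconstructed rather than simply reported as unready. A minimal Python sketch of the intended behavior (illustrative only, not part of the commit; the eviction step is elided, and the ray.wait signature is the 2018-era API):

import ray

ray.init()

@ray.remote
def f():
    return 1

x = f.remote()
ray.get(x)  # The value is now in the local object store.

# Suppose the object store later evicts x's value. After this commit,
# blocking in ray.wait triggers reconstruction (re-executing f), so
# the wait can eventually succeed instead of timing out.
ready, unready = ray.wait([x], num_returns=1)
assert ready == [x]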

PlasmaFreeTest.java

@@ -41,7 +41,9 @@ public class PlasmaFreeTest {
       Ray.call(PlasmaFreeTest::hello).get();
     }
 
-    waitResult = Ray.wait(waitFor, 1, 2 * 1000);
+    // Check if the object has been evicted. Don't give ray.wait enough
+    // time to reconstruct the object.
+    waitResult = Ray.wait(waitFor, 1, 0);
     readyOnes = waitResult.getReady();
     unreadyOnes = waitResult.getUnready();
     Assert.assertEquals(0, readyOnes.size());
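For comparison, the same zero-timeout probe in Python (a sketch, not part of the commit; it assumes the 2018-era API in which the timeout, like the third argument to Ray.wait above, is in milliseconds):

import ray

ray.init()

@ray.remote
def hello():
    return "hello"

x = hello.remote()
# A zero timeout makes ray.wait report current locality immediately,
# without giving fetching or reconstruction any time to run first.
ready, unready = ray.wait([x], num_returns=1, timeout=0)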

src/ray/raylet/node_manager.cc

@@ -697,57 +697,11 @@ void NodeManager::ProcessClientMessage(
     }
     if (!required_object_ids.empty()) {
-      std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
-      if (worker) {
-        // The client is a worker. Mark the worker as blocked. This
-        // temporarily releases any resources that the worker holds while it is
-        // blocked.
-        HandleWorkerBlocked(worker);
-      } else {
-        // The client is a driver. Drivers do not hold resources, so we simply
-        // mark the driver as blocked.
-        worker = worker_pool_.GetRegisteredDriver(client);
-        RAY_CHECK(worker);
-        worker->MarkBlocked();
-      }
-      const TaskID current_task_id = worker->GetAssignedTaskId();
-      RAY_CHECK(!current_task_id.is_nil());
-      // Subscribe to the objects required by the ray.get. These objects will
-      // be fetched and/or reconstructed as necessary, until the objects become
-      // local or are unsubscribed.
-      task_dependency_manager_.SubscribeDependencies(current_task_id,
-                                                     required_object_ids);
+      HandleClientBlocked(client, required_object_ids);
     }
   } break;
   case protocol::MessageType::NotifyUnblocked: {
-    std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
-    // Re-acquire the CPU resources for the task that was assigned to the
-    // unblocked worker.
-    // TODO(swang): Because the object dependencies are tracked in the task
-    // dependency manager, we could actually remove this message entirely and
-    // instead unblock the worker once all the objects become available.
-    bool was_blocked;
-    if (worker) {
-      was_blocked = worker->IsBlocked();
-      // Mark the worker as unblocked. This returns the temporarily released
-      // resources to the worker.
-      HandleWorkerUnblocked(worker);
-    } else {
-      // The client is a driver. Drivers do not hold resources, so we simply
-      // mark the driver as unblocked.
-      worker = worker_pool_.GetRegisteredDriver(client);
-      RAY_CHECK(worker);
-      was_blocked = worker->IsBlocked();
-      worker->MarkUnblocked();
-    }
-    // Unsubscribe to the objects. Any fetch or reconstruction operations to
-    // make the objects local are canceled.
-    if (was_blocked) {
-      const TaskID current_task_id = worker->GetAssignedTaskId();
-      RAY_CHECK(!current_task_id.is_nil());
-      task_dependency_manager_.UnsubscribeDependencies(current_task_id);
-    }
+    HandleClientUnblocked(client);
   } break;
   case protocol::MessageType::WaitRequest: {
     // Read the data.
@@ -757,9 +711,25 @@ void NodeManager::ProcessClientMessage(
     uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
     bool wait_local = message->wait_local();
 
+    std::vector<ObjectID> required_object_ids;
+    for (auto const &object_id : object_ids) {
+      if (!task_dependency_manager_.CheckObjectLocal(object_id)) {
+        // Add any missing objects to the list to subscribe to in the task
+        // dependency manager. These objects will be pulled from remote node
+        // managers and reconstructed if necessary.
+        required_object_ids.push_back(object_id);
+      }
+    }
+
+    bool client_blocked = !required_object_ids.empty();
+    if (client_blocked) {
+      HandleClientBlocked(client, required_object_ids);
+    }
+
     ray::Status status = object_manager_.Wait(
         object_ids, wait_ms, num_required_objects, wait_local,
-        [client](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
+        [this, client_blocked, client](std::vector<ObjectID> found,
+                                       std::vector<ObjectID> remaining) {
           // Write the data.
           flatbuffers::FlatBufferBuilder fbb;
           flatbuffers::Offset<protocol::WaitReply> wait_reply = protocol::CreateWaitReply(
@@ -768,6 +738,10 @@ void NodeManager::ProcessClientMessage(
           RAY_CHECK_OK(
               client->WriteMessage(static_cast<int64_t>(protocol::MessageType::WaitReply),
                                    fbb.GetSize(), fbb.GetBufferPointer()));
+          // The client is unblocked now because the wait call has returned.
+          if (client_blocked) {
+            HandleClientUnblocked(client);
+          }
         });
     RAY_CHECK_OK(status);
   } break;
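Marking the client blocked for the duration of the wait matters for more than reconstruction: HandleClientBlocked goes through HandleWorkerBlocked, which temporarily releases the worker's resources, so a task that blocks in ray.wait does not hold a CPU hostage while its children run. A sketch of the pattern this enables (essentially what the new Python test at the end of this diff exercises; illustrative only, not part of the commit):

import ray

ray.init(num_cpus=1)

@ray.remote
def child(i):
    return i

@ray.remote
def parent():
    # While this worker blocks in ray.wait, the node manager marks it
    # blocked and releases its CPU, so the child tasks can be scheduled
    # even though the node has only one CPU.
    object_ids = [child.remote(i) for i in range(2)]
    return ray.wait(object_ids, num_returns=len(object_ids))

ray.get(parent.remote())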
@@ -1117,6 +1091,62 @@ void NodeManager::HandleWorkerUnblocked(std::shared_ptr<Worker> worker) {
   worker->MarkUnblocked();
 }
 
+void NodeManager::HandleClientBlocked(
+    const std::shared_ptr<LocalClientConnection> &client,
+    const std::vector<ObjectID> &required_object_ids) {
+  std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
+  if (worker) {
+    // The client is a worker. Mark the worker as blocked. This
+    // temporarily releases any resources that the worker holds while it is
+    // blocked.
+    HandleWorkerBlocked(worker);
+  } else {
+    // The client is a driver. Drivers do not hold resources, so we simply
+    // mark the driver as blocked.
+    worker = worker_pool_.GetRegisteredDriver(client);
+    RAY_CHECK(worker);
+    worker->MarkBlocked();
+  }
+  const TaskID current_task_id = worker->GetAssignedTaskId();
+  RAY_CHECK(!current_task_id.is_nil());
+  // Subscribe to the objects required by the ray.get or ray.wait call. These
+  // objects will be fetched and/or reconstructed as necessary, until they
+  // become local or are unsubscribed.
+  task_dependency_manager_.SubscribeDependencies(current_task_id, required_object_ids);
+}
+
+void NodeManager::HandleClientUnblocked(
+    const std::shared_ptr<LocalClientConnection> &client) {
+  std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
+  // Re-acquire the CPU resources for the task that was assigned to the
+  // unblocked worker.
+  // TODO(swang): Because the object dependencies are tracked in the task
+  // dependency manager, we could actually remove this message entirely and
+  // instead unblock the worker once all the objects become available.
+  bool was_blocked;
+  if (worker) {
+    was_blocked = worker->IsBlocked();
+    // Mark the worker as unblocked. This returns the temporarily released
+    // resources to the worker.
+    HandleWorkerUnblocked(worker);
+  } else {
+    // The client is a driver. Drivers do not hold resources, so we simply
+    // mark the driver as unblocked.
+    worker = worker_pool_.GetRegisteredDriver(client);
+    RAY_CHECK(worker);
+    was_blocked = worker->IsBlocked();
+    worker->MarkUnblocked();
+  }
+  // Unsubscribe from the objects. Any fetch or reconstruction operations to
+  // make the objects local are canceled.
+  if (was_blocked) {
+    const TaskID current_task_id = worker->GetAssignedTaskId();
+    RAY_CHECK(!current_task_id.is_nil());
+    task_dependency_manager_.UnsubscribeDependencies(current_task_id);
+  }
+}
+
 void NodeManager::EnqueuePlaceableTask(const Task &task) {
   // TODO(atumanov): add task lookup hashmap and change EnqueuePlaceableTask to take
   // a vector of TaskIDs. Trigger MoveTask internally.
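Note the design choice in HandleClientUnblocked: clearing the blocked flag is unconditional, but unsubscribing from dependencies is guarded by was_blocked, which makes the function safe to call for a client that never actually blocked (the header comment below states this explicitly). A minimal self-contained sketch of that guard pattern (illustrative only, not Ray code):

# Illustrative sketch, not Ray code: cleanup runs only if the client
# actually blocked, so the unblock call is safe to issue unconditionally.
class Client:
    def __init__(self):
        self.blocked = False
        self.subscriptions = set()

def handle_unblocked(client):
    was_blocked = client.blocked
    client.blocked = False
    if was_blocked:
        # Cancel any fetch/reconstruction subscriptions exactly once.
        client.subscriptions.clear()

c = Client()
handle_unblocked(c)  # no-op for a client that never blocked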

src/ray/raylet/node_manager.h

@@ -192,6 +192,23 @@ class NodeManager {
   /// \return Void.
   void HandleWorkerUnblocked(std::shared_ptr<Worker> worker);
 
+  /// Handle a client that is blocked. This could be a worker or a driver. This
+  /// can be triggered when a client starts a get call or a wait call.
+  ///
+  /// \param client The client that is blocked.
+  /// \param required_object_ids The IDs that the client is blocked waiting for.
+  /// \return Void.
+  void HandleClientBlocked(const std::shared_ptr<LocalClientConnection> &client,
+                           const std::vector<ObjectID> &required_object_ids);
+
+  /// Handle a client that is unblocked. This could be a worker or a driver.
+  /// This can be triggered when a client is finished with a get call or a wait
+  /// call. It is ok to call this even if the client is not actually blocked.
+  ///
+  /// \param client The client that is unblocked.
+  /// \return Void.
+  void HandleClientUnblocked(const std::shared_ptr<LocalClientConnection> &client);
+
   /// Kill a worker.
   ///
   /// \param worker The worker to kill.


@@ -2063,7 +2063,15 @@ class WorkerPoolTests(unittest.TestCase):
             object_ids = [f.remote(i, j) for j in range(2)]
             return ray.get(object_ids)
 
         ray.get([g.remote(i) for i in range(4)])
 
+        @ray.remote
+        def h(i):
+            # Each instance of h submits and blocks on the result of another
+            # remote task using ray.wait.
+            object_ids = [f.remote(i, j) for j in range(2)]
+            return ray.wait(object_ids, num_returns=len(object_ids))
+        if os.environ.get("RAY_USE_XRAY") == "1":
+            ray.get([h.remote(i) for i in range(4)])
 
         @ray.remote
         def _sleep(i):