Implement wait_local for wait (#6524)

This commit is contained in:
Eric Liang 2019-12-28 17:40:49 -08:00 committed by GitHub
parent 677004ee3d
commit 7c1e0e5715
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 99 additions and 38 deletions

View file

@ -749,6 +749,33 @@ def test_local_mode(shutdown_only):
assert ray.get(indirect_dep.remote(["hello"])) == "hello"
def test_wait_makes_object_local(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
cluster.add_node(num_cpus=2)
ray.init(address=cluster.address)
@ray.remote
class Foo(object):
def method(self):
return np.zeros(1024 * 1024)
a = Foo.remote()
# Test get makes the object local.
x_id = a.method.remote()
assert not ray.worker.global_worker.core_worker.object_exists(x_id)
ray.get(x_id)
assert ray.worker.global_worker.core_worker.object_exists(x_id)
# Test wait makes the object local.
x_id = a.method.remote()
assert not ray.worker.global_worker.core_worker.object_exists(x_id)
ok, _ = ray.wait([x_id])
assert len(ok) == 1
assert ray.worker.global_worker.core_worker.object_exists(x_id)
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))

View file

@ -40,16 +40,12 @@ class TaskPool(object):
Assumes obj_id only is one id."""
for worker, obj_id in self.completed(blocking_wait=blocking_wait):
(ray.worker.global_worker.raylet_client.fetch_or_reconstruct(
[obj_id], True))
self._fetching.append((worker, obj_id))
remaining = []
num_yielded = 0
for worker, obj_id in self._fetching:
if (num_yielded < max_yield
and ray.worker.global_worker.core_worker.object_exists(
obj_id)):
if num_yielded < max_yield:
yield (worker, obj_id)
num_yielded += 1
else:

View file

@ -493,6 +493,28 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) {
return Status::OK();
}
// For any objects that are ErrorType::OBJECT_IN_PLASMA, we need to move them from
// the ready set into the plasma_object_ids set to wait on them there.
void RetryObjectInPlasmaErrors(std::shared_ptr<CoreWorkerMemoryStore> &memory_store,
WorkerContext &worker_context,
absl::flat_hash_set<ObjectID> &memory_object_ids,
absl::flat_hash_set<ObjectID> &plasma_object_ids,
absl::flat_hash_set<ObjectID> &ready) {
for (const auto &mem_id : memory_object_ids) {
if (ready.find(mem_id) != ready.end()) {
std::vector<std::shared_ptr<RayObject>> found;
RAY_CHECK_OK(memory_store->Get({mem_id}, /*num_objects=*/1, /*timeout=*/0,
worker_context,
/*remote_after_get=*/false, &found));
if (found.size() == 1 && found[0]->IsInPlasmaError()) {
memory_object_ids.erase(mem_id);
ready.erase(mem_id);
plasma_object_ids.insert(mem_id);
}
}
}
}
Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
int64_t timeout_ms, std::vector<bool> *results) {
results->resize(ids.size(), false);
@ -523,17 +545,21 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
// Wait from both store providers with timeout set to 0. This is to avoid the case
// where we might use up the entire timeout on trying to get objects from one store
// provider before even trying another (which might have all of the objects available).
if (plasma_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
plasma_object_ids, num_objects, /*timeout_ms=*/0, worker_context_, &ready));
if (memory_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(memory_store_->Wait(
memory_object_ids,
std::min(static_cast<int>(memory_object_ids.size()), num_objects),
/*timeout_ms=*/0, worker_context_, &ready));
RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
plasma_object_ids, ready);
}
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
// TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should
// consider waiting on them in plasma as well to ensure they are local.
RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids,
num_objects - static_cast<int>(ready.size()),
/*timeout_ms=*/0, worker_context_, &ready));
if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
plasma_object_ids,
std::min(static_cast<int>(plasma_object_ids.size()),
num_objects - static_cast<int>(ready.size())),
/*timeout_ms=*/0, worker_context_, &ready));
}
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
@ -543,19 +569,25 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
ready.clear();
int64_t start_time = current_time_ms();
if (plasma_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
plasma_object_ids, num_objects, timeout_ms, worker_context_, &ready));
if (memory_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(memory_store_->Wait(
memory_object_ids,
std::min(static_cast<int>(memory_object_ids.size()), num_objects), timeout_ms,
worker_context_, &ready));
RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
plasma_object_ids, ready);
}
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
if (timeout_ms > 0) {
timeout_ms =
std::max(0, static_cast<int>(timeout_ms - (current_time_ms() - start_time)));
}
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids,
num_objects - static_cast<int>(ready.size()),
timeout_ms, worker_context_, &ready));
if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
plasma_object_ids,
std::min(static_cast<int>(plasma_object_ids.size()),
num_objects - static_cast<int>(ready.size())),
timeout_ms, worker_context_, &ready));
}
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
}

View file

@ -257,7 +257,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait(
RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked());
}
RAY_RETURN_NOT_OK(
raylet_client_->Wait(id_vector, num_objects, call_timeout, false,
raylet_client_->Wait(id_vector, num_objects, call_timeout, /*wait_local*/ true,
/*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(),
ctx.GetCurrentTaskID(), &result_pair));

View file

@ -494,10 +494,6 @@ ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id,
int64_t timeout_ms,
uint64_t num_required_objects, bool wait_local,
const WaitCallback &callback) {
if (wait_local) {
return ray::Status::NotImplemented("Wait for local objects is not yet implemented.");
}
RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
RAY_CHECK(num_required_objects != 0);
RAY_CHECK(num_required_objects <= object_ids.size())
@ -512,6 +508,7 @@ ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id,
wait_state.object_id_order = object_ids;
wait_state.timeout_ms = timeout_ms;
wait_state.num_required_objects = num_required_objects;
wait_state.wait_local = wait_local;
for (const auto &object_id : object_ids) {
if (local_objects_.count(object_id) > 0) {
wait_state.found.insert(object_id);
@ -541,7 +538,10 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
object_id, [this, wait_id](const ObjectID &lookup_object_id,
const std::unordered_set<ClientID> &client_ids) {
auto &wait_state = active_wait_requests_.find(wait_id)->second;
if (!client_ids.empty()) {
// Note that the object is guaranteed to be added to local_objects_ before
// the notification is triggered.
if (local_objects_.count(lookup_object_id) > 0 ||
(!wait_state.wait_local && !client_ids.empty())) {
wait_state.remaining.erase(lookup_object_id);
wait_state.found.insert(lookup_object_id);
}
@ -578,19 +578,22 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
wait_id, object_id,
[this, wait_id](const ObjectID &subscribe_object_id,
const std::unordered_set<ClientID> &client_ids) {
if (!client_ids.empty()) {
auto object_id_wait_state = active_wait_requests_.find(wait_id);
if (object_id_wait_state == active_wait_requests_.end()) {
// Depending on the timing of calls to the object directory, we
// may get a subscription notification after the wait call has
// already completed. If so, then don't process the
// notification.
return;
}
auto &wait_state = object_id_wait_state->second;
// Note that the object is guaranteed to be added to local_objects_ before
// the notification is triggered.
if (local_objects_.count(subscribe_object_id) > 0 ||
(!wait_state.wait_local && !client_ids.empty())) {
RAY_LOG(DEBUG) << "Wait request " << wait_id
<< ": subscription notification received for object "
<< subscribe_object_id;
auto object_id_wait_state = active_wait_requests_.find(wait_id);
if (object_id_wait_state == active_wait_requests_.end()) {
// Depending on the timing of calls to the object directory, we
// may get a subscription notification after the wait call has
// already completed. If so, then don't process the
// notification.
return;
}
auto &wait_state = object_id_wait_state->second;
wait_state.remaining.erase(subscribe_object_id);
wait_state.found.insert(subscribe_object_id);
wait_state.requested_objects.erase(subscribe_object_id);

View file

@ -265,6 +265,8 @@ class ObjectManager : public ObjectManagerInterface,
callback(callback) {}
/// The period of time to wait before invoking the callback.
int64_t timeout_ms;
/// Whether to wait for objects to become local before returning.
bool wait_local;
/// The timer used whenever wait_ms > 0.
std::unique_ptr<boost::asio::deadline_timer> timeout_timer;
/// The callback invoked when WaitCallback is complete.
@ -273,7 +275,8 @@ class ObjectManager : public ObjectManagerInterface,
std::vector<ObjectID> object_id_order;
/// The objects that have not yet been found.
std::unordered_set<ObjectID> remaining;
/// The objects that have been found.
/// The objects that have been found. Note that if wait_local is true, then
/// this will only contain objects that are in local_objects_ too.
std::unordered_set<ObjectID> found;
/// Objects that have been requested either by Lookup or Subscribe.
std::unordered_set<ObjectID> requested_objects;