From 7c1e0e5715109bfa1b9d7a4a5d8198ef1535f5d3 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 28 Dec 2019 17:40:49 -0800 Subject: [PATCH] Implement wait_local for wait (#6524) --- python/ray/tests/test_advanced.py | 27 ++++++++ rllib/utils/actors.py | 6 +- src/ray/core_worker/core_worker.cc | 64 ++++++++++++++----- .../store_provider/plasma_store_provider.cc | 2 +- src/ray/object_manager/object_manager.cc | 33 +++++----- src/ray/object_manager/object_manager.h | 5 +- 6 files changed, 99 insertions(+), 38 deletions(-) diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index 36a667254..43074c60d 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -749,6 +749,33 @@ def test_local_mode(shutdown_only): assert ray.get(indirect_dep.remote(["hello"])) == "hello" +def test_wait_makes_object_local(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0) + cluster.add_node(num_cpus=2) + ray.init(address=cluster.address) + + @ray.remote + class Foo(object): + def method(self): + return np.zeros(1024 * 1024) + + a = Foo.remote() + + # Test get makes the object local. + x_id = a.method.remote() + assert not ray.worker.global_worker.core_worker.object_exists(x_id) + ray.get(x_id) + assert ray.worker.global_worker.core_worker.object_exists(x_id) + + # Test wait makes the object local. + x_id = a.method.remote() + assert not ray.worker.global_worker.core_worker.object_exists(x_id) + ok, _ = ray.wait([x_id]) + assert len(ok) == 1 + assert ray.worker.global_worker.core_worker.object_exists(x_id) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/utils/actors.py b/rllib/utils/actors.py index 8c4929949..70026d38f 100644 --- a/rllib/utils/actors.py +++ b/rllib/utils/actors.py @@ -40,16 +40,12 @@ class TaskPool(object): Assumes obj_id only is one id.""" for worker, obj_id in self.completed(blocking_wait=blocking_wait): - (ray.worker.global_worker.raylet_client.fetch_or_reconstruct( - [obj_id], True)) self._fetching.append((worker, obj_id)) remaining = [] num_yielded = 0 for worker, obj_id in self._fetching: - if (num_yielded < max_yield - and ray.worker.global_worker.core_worker.object_exists( - obj_id)): + if num_yielded < max_yield: yield (worker, obj_id) num_yielded += 1 else: diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index e4a2ffa66..0156b472e 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -493,6 +493,28 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) { return Status::OK(); } +// For any objects that are ErrorType::OBJECT_IN_PLASMA, we need to move them from +// the ready set into the plasma_object_ids set to wait on them there. +void RetryObjectInPlasmaErrors(std::shared_ptr &memory_store, + WorkerContext &worker_context, + absl::flat_hash_set &memory_object_ids, + absl::flat_hash_set &plasma_object_ids, + absl::flat_hash_set &ready) { + for (const auto &mem_id : memory_object_ids) { + if (ready.find(mem_id) != ready.end()) { + std::vector> found; + RAY_CHECK_OK(memory_store->Get({mem_id}, /*num_objects=*/1, /*timeout=*/0, + worker_context, + /*remote_after_get=*/false, &found)); + if (found.size() == 1 && found[0]->IsInPlasmaError()) { + memory_object_ids.erase(mem_id); + ready.erase(mem_id); + plasma_object_ids.insert(mem_id); + } + } + } +} + Status CoreWorker::Wait(const std::vector &ids, int num_objects, int64_t timeout_ms, std::vector *results) { results->resize(ids.size(), false); @@ -523,17 +545,21 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, // Wait from both store providers with timeout set to 0. This is to avoid the case // where we might use up the entire timeout on trying to get objects from one store // provider before even trying another (which might have all of the objects available). - if (plasma_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( - plasma_object_ids, num_objects, /*timeout_ms=*/0, worker_context_, &ready)); + if (memory_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(memory_store_->Wait( + memory_object_ids, + std::min(static_cast(memory_object_ids.size()), num_objects), + /*timeout_ms=*/0, worker_context_, &ready)); + RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids, + plasma_object_ids, ready); } RAY_CHECK(static_cast(ready.size()) <= num_objects); - if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { - // TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should - // consider waiting on them in plasma as well to ensure they are local. - RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids, - num_objects - static_cast(ready.size()), - /*timeout_ms=*/0, worker_context_, &ready)); + if (static_cast(ready.size()) < num_objects && plasma_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( + plasma_object_ids, + std::min(static_cast(plasma_object_ids.size()), + num_objects - static_cast(ready.size())), + /*timeout_ms=*/0, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); @@ -543,19 +569,25 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, ready.clear(); int64_t start_time = current_time_ms(); - if (plasma_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( - plasma_object_ids, num_objects, timeout_ms, worker_context_, &ready)); + if (memory_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(memory_store_->Wait( + memory_object_ids, + std::min(static_cast(memory_object_ids.size()), num_objects), timeout_ms, + worker_context_, &ready)); + RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids, + plasma_object_ids, ready); } RAY_CHECK(static_cast(ready.size()) <= num_objects); if (timeout_ms > 0) { timeout_ms = std::max(0, static_cast(timeout_ms - (current_time_ms() - start_time))); } - if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids, - num_objects - static_cast(ready.size()), - timeout_ms, worker_context_, &ready)); + if (static_cast(ready.size()) < num_objects && plasma_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( + plasma_object_ids, + std::min(static_cast(plasma_object_ids.size()), + num_objects - static_cast(ready.size())), + timeout_ms, worker_context_, &ready)); } RAY_CHECK(static_cast(ready.size()) <= num_objects); } diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 482863d7c..ff110369f 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -257,7 +257,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait( RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); } RAY_RETURN_NOT_OK( - raylet_client_->Wait(id_vector, num_objects, call_timeout, false, + raylet_client_->Wait(id_vector, num_objects, call_timeout, /*wait_local*/ true, /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), ctx.GetCurrentTaskID(), &result_pair)); diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index b5a279c68..31e9ce69d 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -494,10 +494,6 @@ ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id, int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { - if (wait_local) { - return ray::Status::NotImplemented("Wait for local objects is not yet implemented."); - } - RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); RAY_CHECK(num_required_objects != 0); RAY_CHECK(num_required_objects <= object_ids.size()) @@ -512,6 +508,7 @@ ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id, wait_state.object_id_order = object_ids; wait_state.timeout_ms = timeout_ms; wait_state.num_required_objects = num_required_objects; + wait_state.wait_local = wait_local; for (const auto &object_id : object_ids) { if (local_objects_.count(object_id) > 0) { wait_state.found.insert(object_id); @@ -541,7 +538,10 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) { object_id, [this, wait_id](const ObjectID &lookup_object_id, const std::unordered_set &client_ids) { auto &wait_state = active_wait_requests_.find(wait_id)->second; - if (!client_ids.empty()) { + // Note that the object is guaranteed to be added to local_objects_ before + // the notification is triggered. + if (local_objects_.count(lookup_object_id) > 0 || + (!wait_state.wait_local && !client_ids.empty())) { wait_state.remaining.erase(lookup_object_id); wait_state.found.insert(lookup_object_id); } @@ -578,19 +578,22 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) { wait_id, object_id, [this, wait_id](const ObjectID &subscribe_object_id, const std::unordered_set &client_ids) { - if (!client_ids.empty()) { + auto object_id_wait_state = active_wait_requests_.find(wait_id); + if (object_id_wait_state == active_wait_requests_.end()) { + // Depending on the timing of calls to the object directory, we + // may get a subscription notification after the wait call has + // already completed. If so, then don't process the + // notification. + return; + } + auto &wait_state = object_id_wait_state->second; + // Note that the object is guaranteed to be added to local_objects_ before + // the notification is triggered. + if (local_objects_.count(subscribe_object_id) > 0 || + (!wait_state.wait_local && !client_ids.empty())) { RAY_LOG(DEBUG) << "Wait request " << wait_id << ": subscription notification received for object " << subscribe_object_id; - auto object_id_wait_state = active_wait_requests_.find(wait_id); - if (object_id_wait_state == active_wait_requests_.end()) { - // Depending on the timing of calls to the object directory, we - // may get a subscription notification after the wait call has - // already completed. If so, then don't process the - // notification. - return; - } - auto &wait_state = object_id_wait_state->second; wait_state.remaining.erase(subscribe_object_id); wait_state.found.insert(subscribe_object_id); wait_state.requested_objects.erase(subscribe_object_id); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index cf8ddd7f1..917768fdf 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -265,6 +265,8 @@ class ObjectManager : public ObjectManagerInterface, callback(callback) {} /// The period of time to wait before invoking the callback. int64_t timeout_ms; + /// Whether to wait for objects to become local before returning. + bool wait_local; /// The timer used whenever wait_ms > 0. std::unique_ptr timeout_timer; /// The callback invoked when WaitCallback is complete. @@ -273,7 +275,8 @@ class ObjectManager : public ObjectManagerInterface, std::vector object_id_order; /// The objects that have not yet been found. std::unordered_set remaining; - /// The objects that have been found. + /// The objects that have been found. Note that if wait_local is true, then + /// this will only contain objects that are in local_objects_ too. std::unordered_set found; /// Objects that have been requested either by Lookup or Subscribe. std::unordered_set requested_objects;