mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
Implement wait_local for wait (#6524)
This commit is contained in:
parent
677004ee3d
commit
7c1e0e5715
6 changed files with 99 additions and 38 deletions
|
@ -749,6 +749,33 @@ def test_local_mode(shutdown_only):
|
|||
assert ray.get(indirect_dep.remote(["hello"])) == "hello"
|
||||
|
||||
|
||||
def test_wait_makes_object_local(ray_start_cluster):
|
||||
cluster = ray_start_cluster
|
||||
cluster.add_node(num_cpus=0)
|
||||
cluster.add_node(num_cpus=2)
|
||||
ray.init(address=cluster.address)
|
||||
|
||||
@ray.remote
|
||||
class Foo(object):
|
||||
def method(self):
|
||||
return np.zeros(1024 * 1024)
|
||||
|
||||
a = Foo.remote()
|
||||
|
||||
# Test get makes the object local.
|
||||
x_id = a.method.remote()
|
||||
assert not ray.worker.global_worker.core_worker.object_exists(x_id)
|
||||
ray.get(x_id)
|
||||
assert ray.worker.global_worker.core_worker.object_exists(x_id)
|
||||
|
||||
# Test wait makes the object local.
|
||||
x_id = a.method.remote()
|
||||
assert not ray.worker.global_worker.core_worker.object_exists(x_id)
|
||||
ok, _ = ray.wait([x_id])
|
||||
assert len(ok) == 1
|
||||
assert ray.worker.global_worker.core_worker.object_exists(x_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
|
@ -40,16 +40,12 @@ class TaskPool(object):
|
|||
Assumes obj_id only is one id."""
|
||||
|
||||
for worker, obj_id in self.completed(blocking_wait=blocking_wait):
|
||||
(ray.worker.global_worker.raylet_client.fetch_or_reconstruct(
|
||||
[obj_id], True))
|
||||
self._fetching.append((worker, obj_id))
|
||||
|
||||
remaining = []
|
||||
num_yielded = 0
|
||||
for worker, obj_id in self._fetching:
|
||||
if (num_yielded < max_yield
|
||||
and ray.worker.global_worker.core_worker.object_exists(
|
||||
obj_id)):
|
||||
if num_yielded < max_yield:
|
||||
yield (worker, obj_id)
|
||||
num_yielded += 1
|
||||
else:
|
||||
|
|
|
@ -493,6 +493,28 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// For any objects that are ErrorType::OBJECT_IN_PLASMA, we need to move them from
|
||||
// the ready set into the plasma_object_ids set to wait on them there.
|
||||
void RetryObjectInPlasmaErrors(std::shared_ptr<CoreWorkerMemoryStore> &memory_store,
|
||||
WorkerContext &worker_context,
|
||||
absl::flat_hash_set<ObjectID> &memory_object_ids,
|
||||
absl::flat_hash_set<ObjectID> &plasma_object_ids,
|
||||
absl::flat_hash_set<ObjectID> &ready) {
|
||||
for (const auto &mem_id : memory_object_ids) {
|
||||
if (ready.find(mem_id) != ready.end()) {
|
||||
std::vector<std::shared_ptr<RayObject>> found;
|
||||
RAY_CHECK_OK(memory_store->Get({mem_id}, /*num_objects=*/1, /*timeout=*/0,
|
||||
worker_context,
|
||||
/*remote_after_get=*/false, &found));
|
||||
if (found.size() == 1 && found[0]->IsInPlasmaError()) {
|
||||
memory_object_ids.erase(mem_id);
|
||||
ready.erase(mem_id);
|
||||
plasma_object_ids.insert(mem_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
int64_t timeout_ms, std::vector<bool> *results) {
|
||||
results->resize(ids.size(), false);
|
||||
|
@ -523,17 +545,21 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
|||
// Wait from both store providers with timeout set to 0. This is to avoid the case
|
||||
// where we might use up the entire timeout on trying to get objects from one store
|
||||
// provider before even trying another (which might have all of the objects available).
|
||||
if (plasma_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
|
||||
plasma_object_ids, num_objects, /*timeout_ms=*/0, worker_context_, &ready));
|
||||
if (memory_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(memory_store_->Wait(
|
||||
memory_object_ids,
|
||||
std::min(static_cast<int>(memory_object_ids.size()), num_objects),
|
||||
/*timeout_ms=*/0, worker_context_, &ready));
|
||||
RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
|
||||
plasma_object_ids, ready);
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
|
||||
// TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should
|
||||
// consider waiting on them in plasma as well to ensure they are local.
|
||||
RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids,
|
||||
num_objects - static_cast<int>(ready.size()),
|
||||
/*timeout_ms=*/0, worker_context_, &ready));
|
||||
if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
|
||||
plasma_object_ids,
|
||||
std::min(static_cast<int>(plasma_object_ids.size()),
|
||||
num_objects - static_cast<int>(ready.size())),
|
||||
/*timeout_ms=*/0, worker_context_, &ready));
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
|
||||
|
@ -543,19 +569,25 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
|||
ready.clear();
|
||||
|
||||
int64_t start_time = current_time_ms();
|
||||
if (plasma_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
|
||||
plasma_object_ids, num_objects, timeout_ms, worker_context_, &ready));
|
||||
if (memory_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(memory_store_->Wait(
|
||||
memory_object_ids,
|
||||
std::min(static_cast<int>(memory_object_ids.size()), num_objects), timeout_ms,
|
||||
worker_context_, &ready));
|
||||
RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids,
|
||||
plasma_object_ids, ready);
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
if (timeout_ms > 0) {
|
||||
timeout_ms =
|
||||
std::max(0, static_cast<int>(timeout_ms - (current_time_ms() - start_time)));
|
||||
}
|
||||
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(memory_store_->Wait(memory_object_ids,
|
||||
num_objects - static_cast<int>(ready.size()),
|
||||
timeout_ms, worker_context_, &ready));
|
||||
if (static_cast<int>(ready.size()) < num_objects && plasma_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Wait(
|
||||
plasma_object_ids,
|
||||
std::min(static_cast<int>(plasma_object_ids.size()),
|
||||
num_objects - static_cast<int>(ready.size())),
|
||||
timeout_ms, worker_context_, &ready));
|
||||
}
|
||||
RAY_CHECK(static_cast<int>(ready.size()) <= num_objects);
|
||||
}
|
||||
|
|
|
@ -257,7 +257,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait(
|
|||
RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked());
|
||||
}
|
||||
RAY_RETURN_NOT_OK(
|
||||
raylet_client_->Wait(id_vector, num_objects, call_timeout, false,
|
||||
raylet_client_->Wait(id_vector, num_objects, call_timeout, /*wait_local*/ true,
|
||||
/*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(),
|
||||
ctx.GetCurrentTaskID(), &result_pair));
|
||||
|
||||
|
|
|
@ -494,10 +494,6 @@ ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id,
|
|||
int64_t timeout_ms,
|
||||
uint64_t num_required_objects, bool wait_local,
|
||||
const WaitCallback &callback) {
|
||||
if (wait_local) {
|
||||
return ray::Status::NotImplemented("Wait for local objects is not yet implemented.");
|
||||
}
|
||||
|
||||
RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
|
||||
RAY_CHECK(num_required_objects != 0);
|
||||
RAY_CHECK(num_required_objects <= object_ids.size())
|
||||
|
@ -512,6 +508,7 @@ ray::Status ObjectManager::AddWaitRequest(const UniqueID &wait_id,
|
|||
wait_state.object_id_order = object_ids;
|
||||
wait_state.timeout_ms = timeout_ms;
|
||||
wait_state.num_required_objects = num_required_objects;
|
||||
wait_state.wait_local = wait_local;
|
||||
for (const auto &object_id : object_ids) {
|
||||
if (local_objects_.count(object_id) > 0) {
|
||||
wait_state.found.insert(object_id);
|
||||
|
@ -541,7 +538,10 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) {
|
|||
object_id, [this, wait_id](const ObjectID &lookup_object_id,
|
||||
const std::unordered_set<ClientID> &client_ids) {
|
||||
auto &wait_state = active_wait_requests_.find(wait_id)->second;
|
||||
if (!client_ids.empty()) {
|
||||
// Note that the object is guaranteed to be added to local_objects_ before
|
||||
// the notification is triggered.
|
||||
if (local_objects_.count(lookup_object_id) > 0 ||
|
||||
(!wait_state.wait_local && !client_ids.empty())) {
|
||||
wait_state.remaining.erase(lookup_object_id);
|
||||
wait_state.found.insert(lookup_object_id);
|
||||
}
|
||||
|
@ -578,19 +578,22 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) {
|
|||
wait_id, object_id,
|
||||
[this, wait_id](const ObjectID &subscribe_object_id,
|
||||
const std::unordered_set<ClientID> &client_ids) {
|
||||
if (!client_ids.empty()) {
|
||||
auto object_id_wait_state = active_wait_requests_.find(wait_id);
|
||||
if (object_id_wait_state == active_wait_requests_.end()) {
|
||||
// Depending on the timing of calls to the object directory, we
|
||||
// may get a subscription notification after the wait call has
|
||||
// already completed. If so, then don't process the
|
||||
// notification.
|
||||
return;
|
||||
}
|
||||
auto &wait_state = object_id_wait_state->second;
|
||||
// Note that the object is guaranteed to be added to local_objects_ before
|
||||
// the notification is triggered.
|
||||
if (local_objects_.count(subscribe_object_id) > 0 ||
|
||||
(!wait_state.wait_local && !client_ids.empty())) {
|
||||
RAY_LOG(DEBUG) << "Wait request " << wait_id
|
||||
<< ": subscription notification received for object "
|
||||
<< subscribe_object_id;
|
||||
auto object_id_wait_state = active_wait_requests_.find(wait_id);
|
||||
if (object_id_wait_state == active_wait_requests_.end()) {
|
||||
// Depending on the timing of calls to the object directory, we
|
||||
// may get a subscription notification after the wait call has
|
||||
// already completed. If so, then don't process the
|
||||
// notification.
|
||||
return;
|
||||
}
|
||||
auto &wait_state = object_id_wait_state->second;
|
||||
wait_state.remaining.erase(subscribe_object_id);
|
||||
wait_state.found.insert(subscribe_object_id);
|
||||
wait_state.requested_objects.erase(subscribe_object_id);
|
||||
|
|
|
@ -265,6 +265,8 @@ class ObjectManager : public ObjectManagerInterface,
|
|||
callback(callback) {}
|
||||
/// The period of time to wait before invoking the callback.
|
||||
int64_t timeout_ms;
|
||||
/// Whether to wait for objects to become local before returning.
|
||||
bool wait_local;
|
||||
/// The timer used whenever wait_ms > 0.
|
||||
std::unique_ptr<boost::asio::deadline_timer> timeout_timer;
|
||||
/// The callback invoked when WaitCallback is complete.
|
||||
|
@ -273,7 +275,8 @@ class ObjectManager : public ObjectManagerInterface,
|
|||
std::vector<ObjectID> object_id_order;
|
||||
/// The objects that have not yet been found.
|
||||
std::unordered_set<ObjectID> remaining;
|
||||
/// The objects that have been found.
|
||||
/// The objects that have been found. Note that if wait_local is true, then
|
||||
/// this will only contain objects that are in local_objects_ too.
|
||||
std::unordered_set<ObjectID> found;
|
||||
/// Objects that have been requested either by Lookup or Subscribe.
|
||||
std::unordered_set<ObjectID> requested_objects;
|
||||
|
|
Loading…
Add table
Reference in a new issue