From e7b752cf3323f8f82b69104b2c7412467e509538 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 21 Jun 2021 22:32:04 -0700 Subject: [PATCH] [core] Fix bug in task dependency management for duplicate args (#16365) * Pytest * Skip on windows * C++ --- python/ray/tests/test_scheduling.py | 32 +++++++++++++++++++++++ src/ray/object_manager/pull_manager.cc | 2 +- src/ray/raylet/dependency_manager.cc | 7 ++++- src/ray/raylet/dependency_manager.h | 7 ++--- src/ray/raylet/dependency_manager_test.cc | 31 ++++++++++++++++++++++ 5 files changed, 72 insertions(+), 7 deletions(-) diff --git a/python/ray/tests/test_scheduling.py b/python/ray/tests/test_scheduling.py index f593c7e57..430a0d199 100644 --- a/python/ray/tests/test_scheduling.py +++ b/python/ray/tests/test_scheduling.py @@ -446,6 +446,38 @@ def test_lease_request_leak(shutdown_only): assert object_memory_usage() == 0 +@pytest.mark.skipif(sys.platform == "win32", reason="Fails on windows") +def test_many_args(ray_start_cluster): + # This test ensures that tasks taking many object arguments, possibly with + # duplicates among them, all finish before the ray.get() timeout below. + cluster = ray_start_cluster + object_size = int(1e6) + + # Start four single-CPU nodes; the add_node() calls that follow set each + # node's object_store_memory to 4 * 25 * object_size.
+ for _ in range(4): + cluster.add_node( + num_cpus=1, object_store_memory=(4 * object_size * 25)) + ray.init(address=cluster.address) + + @ray.remote + def f(i, *args): + print(i) + return + + @ray.remote + def put(): + return np.zeros(object_size, dtype=np.uint8) + + xs = [put.remote() for _ in range(100)] + ray.wait(xs, num_returns=len(xs), fetch_local=False) + tasks = [] + for i in range(100): + args = [np.random.choice(xs) for _ in range(25)] + tasks.append(f.remote(i, *args)) + ray.get(tasks, timeout=30) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/src/ray/object_manager/pull_manager.cc b/src/ray/object_manager/pull_manager.cc index 9fb473873..b9fbc7bc1 100644 --- a/src/ray/object_manager/pull_manager.cc +++ b/src/ray/object_manager/pull_manager.cc @@ -347,9 +347,9 @@ std::vector PullManager::CancelPull(uint64_t request_id) { std::vector object_ids_to_cancel_subscription; for (const auto &ref : bundle_it->second.objects) { auto obj_id = ObjectRefToId(ref); - RAY_LOG(DEBUG) << "Removing an object pull request of id: " << obj_id; auto it = object_pull_requests_.find(obj_id); if (it != object_pull_requests_.end()) { + RAY_LOG(DEBUG) << "Removing an object pull request of id: " << obj_id; it->second.bundle_request_ids.erase(bundle_it->first); if (it->second.bundle_request_ids.empty()) { object_pull_requests_.erase(it); diff --git a/src/ray/raylet/dependency_manager.cc b/src/ray/raylet/dependency_manager.cc index f0b7e8824..0d2ce1c27 100644 --- a/src/ray/raylet/dependency_manager.cc +++ b/src/ray/raylet/dependency_manager.cc @@ -156,7 +156,10 @@ bool DependencyManager::RequestTaskDependencies( const TaskID &task_id, const std::vector &required_objects) { RAY_LOG(DEBUG) << "Adding dependencies for task " << task_id << ". 
Required objects length: " << required_objects.size(); - auto inserted = queued_task_requests_.emplace(task_id, required_objects); + + const auto required_ids = ObjectRefsToIds(required_objects); + absl::flat_hash_set deduped_ids(required_ids.begin(), required_ids.end()); + auto inserted = queued_task_requests_.emplace(task_id, std::move(deduped_ids)); RAY_CHECK(inserted.second) << "Task depedencies can be requested only once per task. " << task_id; auto &task_entry = inserted.first->second; @@ -167,7 +170,9 @@ bool DependencyManager::RequestTaskDependencies( auto it = GetOrInsertRequiredObject(obj_id, ref); it->second.dependent_tasks.insert(task_id); + } + for (const auto &obj_id : task_entry.dependencies) { if (local_objects_.count(obj_id)) { task_entry.num_missing_dependencies--; } diff --git a/src/ray/raylet/dependency_manager.h b/src/ray/raylet/dependency_manager.h index 471de1ae4..5077db8c6 100644 --- a/src/ray/raylet/dependency_manager.h +++ b/src/ray/raylet/dependency_manager.h @@ -192,11 +192,8 @@ class DependencyManager : public TaskDependencyManagerInterface { /// A struct to represent the object dependencies of a task. struct TaskDependencies { - TaskDependencies(const std::vector &deps) - : num_missing_dependencies(deps.size()) { - const auto dep_ids = ObjectRefsToIds(deps); - dependencies.insert(dep_ids.begin(), dep_ids.end()); - } + TaskDependencies(const absl::flat_hash_set &deps) + : dependencies(std::move(deps)), num_missing_dependencies(dependencies.size()) {} /// The objects that the task depends on. These are the arguments to the /// task. These must all be simultaneously local before the task is ready /// to execute. 
Objects are removed from this set once diff --git a/src/ray/raylet/dependency_manager_test.cc b/src/ray/raylet/dependency_manager_test.cc index 7283c5a67..cce6db942 100644 --- a/src/ray/raylet/dependency_manager_test.cc +++ b/src/ray/raylet/dependency_manager_test.cc @@ -314,6 +314,37 @@ TEST_F(DependencyManagerTest, TestWaitObjectLocal) { AssertNoLeaks(); } +/// Test requesting the dependencies for a task. The dependency manager should +/// return the task ID as ready once all of its unique arguments are local. +TEST_F(DependencyManagerTest, TestDuplicateTaskArgs) { + // Create a task with 3 arguments. + int num_arguments = 3; + auto obj_id = ObjectID::FromRandom(); + std::vector arguments; + for (int i = 0; i < num_arguments; i++) { + arguments.push_back(obj_id); + } + TaskID task_id = RandomTaskId(); + bool ready = + dependency_manager_.RequestTaskDependencies(task_id, ObjectIdsToRefs(arguments)); + ASSERT_FALSE(ready); + ASSERT_EQ(object_manager_mock_.active_task_requests.size(), 1); + + auto ready_task_ids = dependency_manager_.HandleObjectLocal(obj_id); + ASSERT_EQ(ready_task_ids.size(), 1); + ASSERT_EQ(ready_task_ids.front(), task_id); + dependency_manager_.RemoveTaskDependencies(task_id); + + TaskID task_id2 = RandomTaskId(); + ready = + dependency_manager_.RequestTaskDependencies(task_id2, ObjectIdsToRefs(arguments)); + ASSERT_TRUE(ready); + ASSERT_EQ(object_manager_mock_.active_task_requests.size(), 1); + dependency_manager_.RemoveTaskDependencies(task_id2); + + AssertNoLeaks(); +} + } // namespace raylet } // namespace ray