From 84d3d4b67b371538f8b3b76d6aed55e6d63daf20 Mon Sep 17 00:00:00 2001 From: micafan <550435771@qq.com> Date: Mon, 23 Dec 2019 17:54:21 +0800 Subject: [PATCH] [GCS]refactor the GCS Client Task Interface (#6556) --- src/ray/core_worker/actor_manager.cc | 3 +- src/ray/core_worker/core_worker.cc | 2 +- src/ray/gcs/accessor.h | 183 ++++++ src/ray/gcs/actor_info_accessor.h | 87 --- src/ray/gcs/gcs_client.h | 11 +- src/ray/gcs/job_info_accessor.h | 54 -- ...tor_info_accessor.cc => redis_accessor.cc} | 97 +++- src/ray/gcs/redis_accessor.h | 139 +++++ src/ray/gcs/redis_actor_info_accessor.h | 68 --- src/ray/gcs/redis_gcs_client.cc | 13 +- src/ray/gcs/redis_gcs_client.h | 7 +- src/ray/gcs/redis_job_info_accessor.cc | 50 -- src/ray/gcs/redis_job_info_accessor.h | 53 -- src/ray/gcs/subscription_executor.cc | 1 + src/ray/gcs/tables.cc | 7 + src/ray/gcs/tables.h | 21 + src/ray/gcs/test/accessor_test_base.h | 1 + src/ray/gcs/test/redis_gcs_client_test.cc | 547 +++++++++--------- .../gcs/test/redis_job_info_accessor_test.cc | 1 - src/ray/raylet/lineage_cache.cc | 45 +- src/ray/raylet/lineage_cache.h | 21 +- src/ray/raylet/lineage_cache_test.cc | 337 ++++++----- src/ray/raylet/node_manager.cc | 79 +-- 23 files changed, 991 insertions(+), 836 deletions(-) create mode 100644 src/ray/gcs/accessor.h delete mode 100644 src/ray/gcs/actor_info_accessor.h delete mode 100644 src/ray/gcs/job_info_accessor.h rename src/ray/gcs/{redis_actor_info_accessor.cc => redis_accessor.cc} (57%) create mode 100644 src/ray/gcs/redis_accessor.h delete mode 100644 src/ray/gcs/redis_actor_info_accessor.h delete mode 100644 src/ray/gcs/redis_job_info_accessor.cc delete mode 100644 src/ray/gcs/redis_job_info_accessor.h diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index d6665440d..6ee5fa5ab 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -1,6 +1,5 @@ #include "ray/core_worker/actor_manager.h" - -#include 
"ray/gcs/redis_actor_info_accessor.h" +#include "ray/gcs/redis_accessor.h" namespace ray { diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 2ae41f399..43d0af509 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -220,7 +220,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, std::shared_ptr data = std::make_shared(); data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage()); - RAY_CHECK_OK(gcs_client_->raylet_task_table().Add(job_id, task_id, data, nullptr)); + RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(data, nullptr)); SetCurrentTaskId(task_id); } diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h new file mode 100644 index 000000000..2066583da --- /dev/null +++ b/src/ray/gcs/accessor.h @@ -0,0 +1,183 @@ +#ifndef RAY_GCS_ACCESSOR_H +#define RAY_GCS_ACCESSOR_H + +#include "ray/common/id.h" +#include "ray/gcs/callback.h" +#include "ray/protobuf/gcs.pb.h" + +namespace ray { + +namespace gcs { + +/// \class ActorInfoAccessor +/// `ActorInfoAccessor` is a sub-interface of `GcsClient`. +/// This class includes all the methods that are related to accessing +/// actor information in the GCS. +class ActorInfoAccessor { + public: + virtual ~ActorInfoAccessor() = default; + + /// Get actor specification from GCS asynchronously. + /// + /// \param actor_id The ID of actor to look up in the GCS. + /// \param callback Callback that will be called after lookup finishes. + /// \return Status + virtual Status AsyncGet(const ActorID &actor_id, + const OptionalItemCallback &callback) = 0; + + /// Register an actor to GCS asynchronously. + /// + /// \param data_ptr The actor that will be registered to the GCS. + /// \param callback Callback that will be called after actor has been registered + /// to the GCS. 
+ /// \return Status + virtual Status AsyncRegister(const std::shared_ptr &data_ptr, + const StatusCallback &callback) = 0; + + /// Update dynamic states of actor in GCS asynchronously. + /// + /// \param actor_id ID of the actor to update. + /// \param data_ptr Data of the actor to update. + /// \param callback Callback that will be called after update finishes. + /// \return Status + /// TODO(micafan) Don't expose the whole `ActorTableData` and only allow + /// updating dynamic states. + virtual Status AsyncUpdate(const ActorID &actor_id, + const std::shared_ptr &data_ptr, + const StatusCallback &callback) = 0; + + /// Subscribe to any register or update operations of actors. + /// + /// \param subscribe Callback that will be called each time when an actor is registered + /// or updated. + /// \param done Callback that will be called when subscription is complete and we + /// are ready to receive notification. + /// \return Status + virtual Status AsyncSubscribeAll( + const SubscribeCallback &subscribe, + const StatusCallback &done) = 0; + + /// Subscribe to any update operations of an actor. + /// + /// \param actor_id The ID of actor to be subscribed to. + /// \param subscribe Callback that will be called each time when the actor is updated. + /// \param done Callback that will be called when subscription is complete. + /// \return Status + virtual Status AsyncSubscribe( + const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) = 0; + + /// Cancel subscription to an actor. + /// + /// \param actor_id The ID of the actor to be unsubscribed to. + /// \param done Callback that will be called when unsubscribe is complete. + /// \return Status + virtual Status AsyncUnsubscribe(const ActorID &actor_id, + const StatusCallback &done) = 0; + + protected: + ActorInfoAccessor() = default; +}; + +/// \class JobInfoAccessor +/// `JobInfoAccessor` is a sub-interface of `GcsClient`. 
+/// This class includes all the methods that are related to accessing +/// job information in the GCS. +class JobInfoAccessor { + public: + virtual ~JobInfoAccessor() = default; + + /// Add a job to GCS asynchronously. + /// + /// \param data_ptr The job that will be add to GCS. + /// \param callback Callback that will be called after job has been added + /// to GCS. + /// \return Status + virtual Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) = 0; + + /// Mark job as finished in GCS asynchronously. + /// + /// \param job_id ID of the job that will be make finished to GCS. + /// \param callback Callback that will be called after update finished. + /// \return Status + virtual Status AsyncMarkFinished(const JobID &job_id, + const StatusCallback &callback) = 0; + + /// Subscribe to finished jobs. + /// + /// \param subscribe Callback that will be called each time when a job finishes. + /// \param done Callback that will be called when subscription is complete. + /// \return Status + virtual Status AsyncSubscribeToFinishedJobs( + const SubscribeCallback &subscribe, + const StatusCallback &done) = 0; + + protected: + JobInfoAccessor() = default; +}; + +/// \class TaskInfoAccessor +/// `TaskInfoAccessor` is a sub-interface of `GcsClient`. +/// This class includes all the methods that are related to accessing +/// task information in the GCS. +class TaskInfoAccessor { + public: + virtual ~TaskInfoAccessor() {} + + /// Add a task to GCS asynchronously. + /// + /// \param data_ptr The task that will be added to GCS. + /// \param callback Callback that will be called after task has been added + /// to GCS. + /// \return Status + virtual Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) = 0; + + /// Get task information from GCS asynchronously. + /// + /// \param task_id The ID of the task to look up in GCS. + /// \param callback Callback that is called after lookup finished. 
+ /// \return Status + virtual Status AsyncGet(const TaskID &task_id, + const OptionalItemCallback &callback) = 0; + + /// Delete tasks from GCS asynchronously. + /// + /// \param task_ids The vector of IDs to delete from GCS. + /// \param callback Callback that is called after delete finished. + /// \return Status + // TODO(micafan) Will support callback of batch deletion in the future. + // Currently this callback will never be called. + virtual Status AsyncDelete(const std::vector &task_ids, + const StatusCallback &callback) = 0; + + /// Subscribe asynchronously to the event that the given task is added in GCS. + /// + /// \param task_id The ID of the task to be subscribed to. + /// \param subscribe Callback that will be called each time when the task is updated. + /// \param done Callback that will be called when subscription is complete. + /// \return Status + virtual Status AsyncSubscribe( + const TaskID &task_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) = 0; + + /// Cancel subscription to a task asynchronously. + /// This method is for node only (core worker shouldn't use this method). + /// + /// \param task_id The ID of the task to be unsubscribed to. + /// \param done Callback that will be called when unsubscribe is complete. 
+ /// \return Status + virtual Status AsyncUnsubscribe(const TaskID &task_id, const StatusCallback &done) = 0; + + protected: + TaskInfoAccessor() = default; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_ACCESSOR_H \ No newline at end of file diff --git a/src/ray/gcs/actor_info_accessor.h b/src/ray/gcs/actor_info_accessor.h deleted file mode 100644 index 4a3a3109c..000000000 --- a/src/ray/gcs/actor_info_accessor.h +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef RAY_GCS_ACTOR_INFO_ACCESSOR_H -#define RAY_GCS_ACTOR_INFO_ACCESSOR_H - -#include "ray/common/id.h" -#include "ray/gcs/callback.h" -#include "ray/protobuf/gcs.pb.h" - -namespace ray { - -namespace gcs { - -/// \class ActorInfoAccessor -/// `ActorInfoAccessor` is a sub-interface of `GcsClient`. -/// This class includes all the methods that are related to accessing -/// actor information in the GCS. -class ActorInfoAccessor { - public: - virtual ~ActorInfoAccessor() = default; - - /// Get actor specification from GCS asynchronously. - /// - /// \param actor_id The ID of actor to look up in the GCS. - /// \param callback Callback that will be called after lookup finishes. - /// \return Status - virtual Status AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) = 0; - - /// Register an actor to GCS asynchronously. - /// - /// \param data_ptr The actor that will be registered to the GCS. - /// \param callback Callback that will be called after actor has been registered - /// to the GCS. - /// \return Status - virtual Status AsyncRegister(const std::shared_ptr &data_ptr, - const StatusCallback &callback) = 0; - - /// Update dynamic states of actor in GCS asynchronously. - /// - /// \param actor_id ID of the actor to update. - /// \param data_ptr Data of the actor to update. - /// \param callback Callback that will be called after update finishes. - /// \return Status - /// TODO(micafan) Don't expose the whole `ActorTableData` and only allow - /// updating dynamic states. 
- virtual Status AsyncUpdate(const ActorID &actor_id, - const std::shared_ptr &data_ptr, - const StatusCallback &callback) = 0; - - /// Subscribe to any register or update operations of actors. - /// - /// \param subscribe Callback that will be called each time when an actor is registered - /// or updated. - /// \param done Callback that will be called when subscription is complete and we - /// are ready to receive notification. - /// \return Status - virtual Status AsyncSubscribeAll( - const SubscribeCallback &subscribe, - const StatusCallback &done) = 0; - - /// Subscribe to any update operations of an actor. - /// - /// \param actor_id The ID of actor to be subscribed to. - /// \param subscribe Callback that will be called each time when the actor is updated. - /// \param done Callback that will be called when subscription is complete. - /// \return Status - virtual Status AsyncSubscribe( - const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) = 0; - - /// Cancel subscription to an actor. - /// - /// \param actor_id The ID of the actor to be unsubscribed to. - /// \param done Callback that will be called when unsubscribe is complete. - /// \return Status - virtual Status AsyncUnsubscribe(const ActorID &actor_id, - const StatusCallback &done) = 0; - - protected: - ActorInfoAccessor() = default; -}; - -} // namespace gcs - -} // namespace ray - -#endif // RAY_GCS_ACTOR_INFO_ACCESSOR_H diff --git a/src/ray/gcs/gcs_client.h b/src/ray/gcs/gcs_client.h index c5ebd8e35..d71bdbbec 100644 --- a/src/ray/gcs/gcs_client.h +++ b/src/ray/gcs/gcs_client.h @@ -6,8 +6,7 @@ #include #include #include "ray/common/status.h" -#include "ray/gcs/actor_info_accessor.h" -#include "ray/gcs/job_info_accessor.h" +#include "ray/gcs/accessor.h" #include "ray/util/logging.h" namespace ray { @@ -75,6 +74,13 @@ class GcsClient : public std::enable_shared_from_this { return *job_accessor_; } + /// Get the sub-interface for accessing task information in GCS. 
+ /// This function is thread safe. + TaskInfoAccessor &Tasks() { + RAY_CHECK(task_accessor_ != nullptr); + return *task_accessor_; + } + protected: /// Constructor of GcsClient. /// @@ -88,6 +94,7 @@ class GcsClient : public std::enable_shared_from_this { std::unique_ptr actor_accessor_; std::unique_ptr job_accessor_; + std::unique_ptr task_accessor_; }; } // namespace gcs diff --git a/src/ray/gcs/job_info_accessor.h b/src/ray/gcs/job_info_accessor.h deleted file mode 100644 index e24a582e9..000000000 --- a/src/ray/gcs/job_info_accessor.h +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef RAY_GCS_JOB_INFO_ACCESSOR_H -#define RAY_GCS_JOB_INFO_ACCESSOR_H - -#include "ray/common/id.h" -#include "ray/gcs/callback.h" -#include "ray/protobuf/gcs.pb.h" - -namespace ray { - -namespace gcs { - -/// \class JobInfoAccessor -/// `JobInfoAccessor` is a sub-interface of `GcsClient`. -/// This class includes all the methods that are related to accessing -/// job information in the GCS. -class JobInfoAccessor { - public: - virtual ~JobInfoAccessor() = default; - - /// Add a job to GCS asynchronously. - /// - /// \param data_ptr The job that will be add to GCS. - /// \param callback Callback that will be called after job has been added - /// to GCS. - /// \return Status - virtual Status AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) = 0; - - /// Mark job as finished in GCS asynchronously. - /// - /// \param job_id ID of the job that will be make finished to GCS. - /// \param callback Callback that will be called after update finished. - /// \return Status - virtual Status AsyncMarkFinished(const JobID &job_id, - const StatusCallback &callback) = 0; - - /// Subscribe to finished jobs. - /// - /// \param subscribe Callback that will be called each time when a job finishes. - /// \param done Callback that will be called when subscription is complete. 
- /// \return Status - virtual Status AsyncSubscribeToFinishedJobs( - const SubscribeCallback &subscribe, - const StatusCallback &done) = 0; - - protected: - JobInfoAccessor() = default; -}; - -} // namespace gcs - -} // namespace ray - -#endif // RAY_GCS_JOB_INFO_ACCESSOR_H diff --git a/src/ray/gcs/redis_actor_info_accessor.cc b/src/ray/gcs/redis_accessor.cc similarity index 57% rename from src/ray/gcs/redis_actor_info_accessor.cc rename to src/ray/gcs/redis_accessor.cc index a093d702f..a4ebee85d 100644 --- a/src/ray/gcs/redis_actor_info_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -1,5 +1,6 @@ -#include "ray/gcs/redis_actor_info_accessor.h" +#include "ray/gcs/redis_accessor.h" #include +#include "ray/gcs/pb_util.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/util/logging.h" @@ -132,6 +133,100 @@ Status RedisActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id, return actor_sub_executor_.AsyncUnsubscribe(node_id_, actor_id, done); } +RedisJobInfoAccessor::RedisJobInfoAccessor(RedisGcsClient *client_impl) + : client_impl_(client_impl), job_sub_executor_(client_impl->job_table()) {} + +Status RedisJobInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + return DoAsyncAppend(data_ptr, callback); +} + +Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, + const StatusCallback &callback) { + std::shared_ptr data_ptr = + CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr), + /*node_manager_address*/ "", /*driver_pid*/ -1); + return DoAsyncAppend(data_ptr, callback); +} + +Status RedisJobInfoAccessor::DoAsyncAppend(const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + JobTable::WriteCallback on_done = nullptr; + if (callback != nullptr) { + on_done = [callback](RedisGcsClient *client, const JobID &job_id, + const JobTableData &data) { callback(Status::OK()); }; + } + + JobID job_id = JobID::FromBinary(data_ptr->job_id()); + return 
client_impl_->job_table().Append(job_id, job_id, data_ptr, on_done); +} + +Status RedisJobInfoAccessor::AsyncSubscribeToFinishedJobs( + const SubscribeCallback &subscribe, const StatusCallback &done) { + RAY_CHECK(subscribe != nullptr); + auto on_subscribe = [subscribe](const JobID &job_id, const JobTableData &job_data) { + if (job_data.is_dead()) { + subscribe(job_id, job_data); + } + }; + return job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), on_subscribe, done); +} + +RedisTaskInfoAccessor::RedisTaskInfoAccessor(RedisGcsClient *client_impl) + : client_impl_(client_impl), task_sub_executor_(client_impl->raylet_task_table()) {} + +Status RedisTaskInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) { + raylet::TaskTable::WriteCallback on_done = nullptr; + if (callback != nullptr) { + on_done = [callback](RedisGcsClient *client, const TaskID &task_id, + const TaskTableData &data) { callback(Status::OK()); }; + } + + TaskID task_id = TaskID::FromBinary(data_ptr->task().task_spec().task_id()); + raylet::TaskTable &task_table = client_impl_->raylet_task_table(); + return task_table.Add(JobID::Nil(), task_id, data_ptr, on_done); +} + +Status RedisTaskInfoAccessor::AsyncGet( + const TaskID &task_id, const OptionalItemCallback &callback) { + RAY_CHECK(callback != nullptr); + auto on_success = [callback](RedisGcsClient *client, const TaskID &task_id, + const TaskTableData &data) { + boost::optional result(data); + callback(Status::OK(), result); + }; + + auto on_failure = [callback](RedisGcsClient *client, const TaskID &task_id) { + boost::optional result; + callback(Status::Invalid("Task not exist."), result); + }; + + raylet::TaskTable &task_table = client_impl_->raylet_task_table(); + return task_table.Lookup(JobID::Nil(), task_id, on_success, on_failure); +} + +Status RedisTaskInfoAccessor::AsyncDelete(const std::vector &task_ids, + const StatusCallback &callback) { + raylet::TaskTable &task_table = 
client_impl_->raylet_task_table(); + task_table.Delete(JobID::Nil(), task_ids); + // TODO(micafan) Always return OK here. + // Confirm if we need to handle the deletion failure and how to handle it. + return Status::OK(); +} + +Status RedisTaskInfoAccessor::AsyncSubscribe( + const TaskID &task_id, const SubscribeCallback &subscribe, + const StatusCallback &done) { + RAY_CHECK(subscribe != nullptr); + return task_sub_executor_.AsyncSubscribe(subscribe_id_, task_id, subscribe, done); +} + +Status RedisTaskInfoAccessor::AsyncUnsubscribe(const TaskID &task_id, + const StatusCallback &done) { + return task_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, done); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h new file mode 100644 index 000000000..53b6668f7 --- /dev/null +++ b/src/ray/gcs/redis_accessor.h @@ -0,0 +1,139 @@ +#ifndef RAY_GCS_REDIS_ACCESSOR_H +#define RAY_GCS_REDIS_ACCESSOR_H + +#include "ray/common/id.h" +#include "ray/common/task/task_spec.h" +#include "ray/gcs/accessor.h" +#include "ray/gcs/callback.h" +#include "ray/gcs/subscription_executor.h" +#include "ray/gcs/tables.h" + +namespace ray { + +namespace gcs { + +class RedisGcsClient; + +std::shared_ptr CreateActorTableData( + const TaskSpecification &task_spec, const rpc::Address &address, + gcs::ActorTableData::ActorState state, uint64_t remaining_reconstructions); + +/// \class RedisActorInfoAccessor +/// `RedisActorInfoAccessor` is an implementation of `ActorInfoAccessor` +/// that uses Redis as the backend storage. 
+class RedisActorInfoAccessor : public ActorInfoAccessor { + public: + explicit RedisActorInfoAccessor(RedisGcsClient *client_impl); + + virtual ~RedisActorInfoAccessor() {} + + Status AsyncGet(const ActorID &actor_id, + const OptionalItemCallback &callback) override; + + Status AsyncRegister(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncUpdate(const ActorID &actor_id, + const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncSubscribeAll(const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncSubscribe(const ActorID &actor_id, + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done) override; + + private: + RedisGcsClient *client_impl_{nullptr}; + // Use a random ClientID for actor subscription. Because: + // If we use ClientID::Nil, GCS will still send all actors' updates to this GCS Client. + // Even we can filter out irrelevant updates, but there will be extra overhead. + // And because the new GCS Client will no longer hold the local ClientID, so we use + // random ClientID instead. + // TODO(micafan): Remove this random id, once GCS becomes a service. + ClientID node_id_{ClientID::FromRandom()}; + + typedef SubscriptionExecutor + ActorSubscriptionExecutor; + ActorSubscriptionExecutor actor_sub_executor_; +}; + +/// \class RedisJobInfoAccessor +/// RedisJobInfoAccessor is an implementation of `JobInfoAccessor` +/// that uses Redis as the backend storage. 
+class RedisJobInfoAccessor : public JobInfoAccessor { + public: + explicit RedisJobInfoAccessor(RedisGcsClient *client_impl); + + virtual ~RedisJobInfoAccessor() {} + + Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback) override; + + Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; + + Status AsyncSubscribeToFinishedJobs( + const SubscribeCallback &subscribe, + const StatusCallback &done) override; + + private: + /// Append job information to GCS asynchronously. + /// + /// \param data_ptr The job information that will be appended to GCS. + /// \param callback Callback that will be called after append done. + /// \return Status + Status DoAsyncAppend(const std::shared_ptr &data_ptr, + const StatusCallback &callback); + + RedisGcsClient *client_impl_{nullptr}; + + typedef SubscriptionExecutor JobSubscriptionExecutor; + JobSubscriptionExecutor job_sub_executor_; +}; + +/// \class RedisTaskInfoAccessor +/// `RedisTaskInfoAccessor` is an implementation of `TaskInfoAccessor` +/// that uses Redis as the backend storage. +class RedisTaskInfoAccessor : public TaskInfoAccessor { + public: + explicit RedisTaskInfoAccessor(RedisGcsClient *client_impl); + + ~RedisTaskInfoAccessor() {} + + Status AsyncAdd(const std::shared_ptr &data_ptr, + const StatusCallback &callback); + + Status AsyncGet(const TaskID &task_id, + const OptionalItemCallback &callback); + + Status AsyncDelete(const std::vector &task_ids, const StatusCallback &callback); + + Status AsyncSubscribe(const TaskID &task_id, + const SubscribeCallback &subscribe, + const StatusCallback &done); + + Status AsyncUnsubscribe(const TaskID &task_id, const StatusCallback &done); + + private: + RedisGcsClient *client_impl_{nullptr}; + // Use a random ClientID for task subscription. Because: + // If we use ClientID::Nil, GCS will still send all tasks' updates to this GCS Client. 
+ // Even we can filter out irrelevant updates, but there will be extra overhead. + // And because the new GCS Client will no longer hold the local ClientID, so we use + // random ClientID instead. + // TODO(micafan): Remove this random id, once GCS becomes a service. + ClientID subscribe_id_{ClientID::FromRandom()}; + + typedef SubscriptionExecutor + TaskSubscriptionExecutor; + TaskSubscriptionExecutor task_sub_executor_; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_REDIS_ACCESSOR_H \ No newline at end of file diff --git a/src/ray/gcs/redis_actor_info_accessor.h b/src/ray/gcs/redis_actor_info_accessor.h deleted file mode 100644 index a5c92ea85..000000000 --- a/src/ray/gcs/redis_actor_info_accessor.h +++ /dev/null @@ -1,68 +0,0 @@ -#ifndef RAY_GCS_REDIS_ACTOR_INFO_ACCESSOR_H -#define RAY_GCS_REDIS_ACTOR_INFO_ACCESSOR_H - -#include "ray/common/id.h" -#include "ray/common/task/task_spec.h" -#include "ray/gcs/actor_info_accessor.h" -#include "ray/gcs/callback.h" -#include "ray/gcs/subscription_executor.h" -#include "ray/gcs/tables.h" - -namespace ray { - -namespace gcs { - -std::shared_ptr CreateActorTableData( - const TaskSpecification &task_spec, const rpc::Address &address, - gcs::ActorTableData::ActorState state, uint64_t remaining_reconstructions); - -class RedisGcsClient; - -/// \class RedisActorInfoAccessor -/// `RedisActorInfoAccessor` is an implementation of `ActorInfoAccessor` -/// that uses Redis as the backend storage. 
-class RedisActorInfoAccessor : public ActorInfoAccessor { - public: - explicit RedisActorInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisActorInfoAccessor() {} - - Status AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) override; - - Status AsyncRegister(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncUpdate(const ActorID &actor_id, - const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncSubscribeAll(const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncSubscribe(const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done) override; - - private: - RedisGcsClient *client_impl_{nullptr}; - // Use a random ClientID for actor subscription. Because: - // If we use ClientID::Nil, GCS will still send all actors' updates to this GCS Client. - // Even we can filter out irrelevant updates, but there will be extra overhead. - // And because the new GCS Client will no longer hold the local ClientID, so we use - // random ClientID instead. - // TODO(micafan): Remove this random id, once GCS becomes a service. 
- ClientID node_id_{ClientID::FromRandom()}; - - typedef SubscriptionExecutor - ActorSubscriptionExecutor; - ActorSubscriptionExecutor actor_sub_executor_; -}; - -} // namespace gcs - -} // namespace ray - -#endif // RAY_GCS_REDIS_ACTOR_INFO_ACCESSOR_H diff --git a/src/ray/gcs/redis_gcs_client.cc b/src/ray/gcs/redis_gcs_client.cc index af2a408dc..3f060e39f 100644 --- a/src/ray/gcs/redis_gcs_client.cc +++ b/src/ray/gcs/redis_gcs_client.cc @@ -2,9 +2,8 @@ #include #include "ray/common/ray_config.h" -#include "ray/gcs/redis_actor_info_accessor.h" +#include "ray/gcs/redis_accessor.h" #include "ray/gcs/redis_context.h" -#include "ray/gcs/redis_job_info_accessor.h" static void GetRedisShards(redisContext *context, std::vector &addresses, std::vector &ports) { @@ -73,13 +72,8 @@ namespace ray { namespace gcs { -RedisGcsClient::RedisGcsClient(const GcsClientOptions &options) : GcsClient(options) { -#if RAY_USE_NEW_GCS - command_type_ = CommandType::kChain; -#else - command_type_ = CommandType::kRegular; -#endif -} +RedisGcsClient::RedisGcsClient(const GcsClientOptions &options) + : GcsClient(options), command_type_(CommandType::kRegular) {} RedisGcsClient::RedisGcsClient(const GcsClientOptions &options, CommandType command_type) : GcsClient(options), command_type_(command_type) {} @@ -150,6 +144,7 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) { actor_accessor_.reset(new RedisActorInfoAccessor(this)); job_accessor_.reset(new RedisJobInfoAccessor(this)); + task_accessor_.reset(new RedisTaskInfoAccessor(this)); is_connected_ = true; diff --git a/src/ray/gcs/redis_gcs_client.h b/src/ray/gcs/redis_gcs_client.h index 5eda50896..7606ba83d 100644 --- a/src/ray/gcs/redis_gcs_client.h +++ b/src/ray/gcs/redis_gcs_client.h @@ -18,12 +18,14 @@ namespace gcs { class RedisContext; class RAY_EXPORT RedisGcsClient : public GcsClient { - // TODO(micafan) Will remove those friend class after we replace RedisGcsClient + // TODO(micafan) Will remove those friend class 
/ method after we replace RedisGcsClient // with interface class GcsClient in raylet. friend class RedisActorInfoAccessor; friend class RedisJobInfoAccessor; + friend class RedisTaskInfoAccessor; friend class SubscriptionExecutorTest; friend class LogSubscribeTestHelper; + friend class TaskTableTestHelper; public: /// Constructor of RedisGcsClient. @@ -55,7 +57,6 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { // TODO: Some API for getting the error on the driver ObjectTable &object_table(); - raylet::TaskTable &raylet_task_table(); TaskReconstructionLog &task_reconstruction_log(); TaskLeaseTable &task_lease_table(); ClientTable &client_table(); @@ -96,6 +97,8 @@ class RAY_EXPORT RedisGcsClient : public GcsClient { ActorTable &actor_table(); /// This method will be deprecated, use method Jobs() instead. JobTable &job_table(); + /// This method will be deprecated, use method Tasks() instead. + raylet::TaskTable &raylet_task_table(); // GCS command type. If CommandType::kChain, chain-replicated versions of the tables // might be used, if available. 
diff --git a/src/ray/gcs/redis_job_info_accessor.cc b/src/ray/gcs/redis_job_info_accessor.cc deleted file mode 100644 index d7028c163..000000000 --- a/src/ray/gcs/redis_job_info_accessor.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "ray/gcs/redis_job_info_accessor.h" -#include "ray/gcs/pb_util.h" -#include "ray/gcs/redis_gcs_client.h" - -namespace ray { - -namespace gcs { - -RedisJobInfoAccessor::RedisJobInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), job_sub_executor_(client_impl->job_table()) {} - -Status RedisJobInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - return DoAsyncAppend(data_ptr, callback); -} - -Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, - const StatusCallback &callback) { - std::shared_ptr data_ptr = - CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr), - /*node_manager_address*/ "", /*driver_pid*/ -1); - return DoAsyncAppend(data_ptr, callback); -} - -Status RedisJobInfoAccessor::DoAsyncAppend(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - JobTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const JobID &job_id, - const JobTableData &data) { callback(Status::OK()); }; - } - - JobID job_id = JobID::FromBinary(data_ptr->job_id()); - return client_impl_->job_table().Append(job_id, job_id, data_ptr, on_done); -} - -Status RedisJobInfoAccessor::AsyncSubscribeToFinishedJobs( - const SubscribeCallback &subscribe, const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const JobID &job_id, const JobTableData &job_data) { - if (job_data.is_dead()) { - subscribe(job_id, job_data); - } - }; - return job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), on_subscribe, done); -} - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/redis_job_info_accessor.h b/src/ray/gcs/redis_job_info_accessor.h 
deleted file mode 100644 index 2fb8ebc75..000000000 --- a/src/ray/gcs/redis_job_info_accessor.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef RAY_GCS_REDIS_JOB_INFO_ACCESSOR_H -#define RAY_GCS_REDIS_JOB_INFO_ACCESSOR_H - -#include "ray/common/id.h" -#include "ray/gcs/callback.h" -#include "ray/gcs/job_info_accessor.h" -#include "ray/gcs/subscription_executor.h" -#include "ray/gcs/tables.h" - -namespace ray { - -namespace gcs { - -class RedisGcsClient; - -/// \class RedisJobInfoAccessor -/// RedisJobInfoAccessor is an implementation of `JobInfoAccessor` -/// that uses Redis as the backend storage. -class RedisJobInfoAccessor : public JobInfoAccessor { - public: - explicit RedisJobInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisJobInfoAccessor() {} - - Status AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; - - Status AsyncSubscribeToFinishedJobs( - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - private: - /// Append job information to GCS asynchronously. - /// - /// \param data_ptr The job information that will be appended to GCS. - /// \param callback Callback that will be called after append done. 
- /// \return Status - Status DoAsyncAppend(const std::shared_ptr &data_ptr, - const StatusCallback &callback); - - RedisGcsClient *client_impl_{nullptr}; - - typedef SubscriptionExecutor JobSubscriptionExecutor; - JobSubscriptionExecutor job_sub_executor_; -}; - -} // namespace gcs - -} // namespace ray - -#endif // RAY_GCS_REDIS_JOB_INFO_ACCESSOR_H diff --git a/src/ray/gcs/subscription_executor.cc b/src/ray/gcs/subscription_executor.cc index 84f2b65b5..8d5b60856 100644 --- a/src/ray/gcs/subscription_executor.cc +++ b/src/ray/gcs/subscription_executor.cc @@ -189,6 +189,7 @@ Status SubscriptionExecutor::AsyncUnsubscribe( template class SubscriptionExecutor; template class SubscriptionExecutor; template class SubscriptionExecutor; +template class SubscriptionExecutor; } // namespace gcs diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ca7b8e7c5..1b357d494 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -306,6 +306,13 @@ Status Table::Subscribe(const JobID &job_id, const ClientID &client_id done); } +template +Status Table::Subscribe(const JobID &job_id, const ClientID &client_id, + const Callback &subscribe, + const SubscriptionCallback &done) { + return Subscribe(job_id, client_id, subscribe, /*failure*/ nullptr, done); +} + template std::string Table::DebugString() const { std::stringstream result; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index d970e6263..b3c52a372 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -311,6 +311,9 @@ class Table : private Log, using Log::RequestNotifications; using Log::CancelNotifications; + /// Expose this interface for use by subscription tools class SubscriptionExecutor. + /// In this way TaskTable() can also reuse class SubscriptionExecutor. + using Log::Subscribe; /// Add an entry to the table. This overwrites any existing data at the key. 
/// @@ -356,6 +359,24 @@ class Table : private Log, const Callback &subscribe, const FailureCallback &failure, const SubscriptionCallback &done); + /// Subscribe to any Add operations to this table. The caller may choose to + /// subscribe to all Adds, or to subscribe only to keys that it requests + /// notifications for. This may only be called once per Table instance. + /// + /// \param job_id The ID of the job. + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each Add to the table will be received. Else, only + /// messages for the given client will be received. In the latter + /// case, the client may request notifications on specific keys in the + /// table via `RequestNotifications`. + /// \param subscribe Callback that is called on each received message. If the + /// callback is called with an empty vector, then there was no data at the key. + /// \param done Callback that is called when subscription is complete and we + /// are ready to receive messages. 
+ /// \return Status + Status Subscribe(const JobID &job_id, const ClientID &client_id, + const Callback &subscribe, const SubscriptionCallback &done); + void Delete(const JobID &job_id, const ID &id) { Log::Delete(job_id, id); } void Delete(const JobID &job_id, const std::vector &ids) { diff --git a/src/ray/gcs/test/accessor_test_base.h b/src/ray/gcs/test/accessor_test_base.h index 0e9be39ec..50563d7d1 100644 --- a/src/ray/gcs/test/accessor_test_base.h +++ b/src/ray/gcs/test/accessor_test_base.h @@ -7,6 +7,7 @@ #include #include #include "gtest/gtest.h" +#include "ray/gcs/redis_accessor.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/util/test_util.h" diff --git a/src/ray/gcs/test/redis_gcs_client_test.cc b/src/ray/gcs/test/redis_gcs_client_test.cc index 76dbdcbf3..24f7aedfe 100644 --- a/src/ray/gcs/test/redis_gcs_client_test.cc +++ b/src/ray/gcs/test/redis_gcs_client_test.cc @@ -87,69 +87,278 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){}; }; -/// A helper function that creates a GCS `TaskTableData` object. -std::shared_ptr CreateTaskTableData(const TaskID &task_id, - uint64_t num_returns = 0) { - auto data = std::make_shared(); - data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); - data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns); - return data; -} +class TaskTableTestHelper { + public: + /// A helper function that creates a GCS `TaskTableData` object. + static std::shared_ptr CreateTaskTableData(const TaskID &task_id, + uint64_t num_returns = 0) { + auto data = std::make_shared(); + data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns); + return data; + } -/// A helper function that compare whether 2 `TaskTableData` objects are equal. -/// Note, this function only compares fields set by `CreateTaskTableData`. 
-bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) { - const auto &spec1 = data1.task().task_spec(); - const auto &spec2 = data2.task().task_spec(); - return (spec1.task_id() == spec2.task_id() && - spec1.num_returns() == spec2.num_returns()); -} + /// A helper function that compares whether 2 `TaskTableData` objects are equal. + /// Note, this function only compares fields set by `CreateTaskTableData`. + static bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) { + const auto &spec1 = data1.task().task_spec(); + const auto &spec2 = data2.task().task_spec(); + return (spec1.task_id() == spec2.task_id() && + spec1.num_returns() == spec2.num_returns()); + } -void TestTableLookup(const JobID &job_id, std::shared_ptr client) { - const auto task_id = RandomTaskId(); - const auto data = CreateTaskTableData(task_id); + static void TestTableLookup(const JobID &job_id, + std::shared_ptr client) { + const auto task_id = RandomTaskId(); + const auto data = CreateTaskTableData(task_id); - // Check that we added the correct task. - auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(TaskTableDataEqual(*data, d)); - }; + // Check that we added the correct task. + auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &d) { + ASSERT_EQ(id, task_id); + ASSERT_TRUE(TaskTableDataEqual(*data, d)); + }; - // Check that the lookup returns the added task. - auto lookup_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(TaskTableDataEqual(*data, d)); - test->Stop(); - }; + // Check that the lookup returns the added task. 
+ auto lookup_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &d) { + ASSERT_EQ(id, task_id); + ASSERT_TRUE(TaskTableDataEqual(*data, d)); + test->Stop(); + }; - // Check that the lookup does not return an empty entry. - auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - }; + // Check that the lookup does not return an empty entry. + auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { + RAY_CHECK(false); + }; - // Add the task, then do a lookup. - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, - failure_callback)); - // Run the event loop. The loop will only stop if the Lookup callback is - // called (or an assertion failure). - test->Start(); -} + // Add the task, then do a lookup. + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, + failure_callback)); + // Run the event loop. The loop will only stop if the Lookup callback is + // called (or an assertion failure). + test->Start(); + } + + static void TestTableLookupFailure(const JobID &job_id, + std::shared_ptr client) { + TaskID task_id = RandomTaskId(); + + // Check that the lookup does not return data. + auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &d) { RAY_CHECK(false); }; + + // Check that the lookup returns an empty entry. + auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id) { + ASSERT_EQ(id, task_id); + test->Stop(); + }; + + // Lookup the task. We have not done any writes, so the key should be empty. + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, + failure_callback)); + // Run the event loop. 
The loop will only stop if the failure callback is + // called (or an assertion failure). + test->Start(); + } + + static void TestDeleteKeysFromTable( + const JobID &job_id, std::shared_ptr client, + std::vector> &data_vector, bool stop_at_end) { + std::vector ids; + TaskID task_id; + for (auto &data : data_vector) { + task_id = RandomTaskId(); + ids.push_back(task_id); + // Check that we added the correct object entries. + auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &d) { + ASSERT_EQ(id, task_id); + ASSERT_TRUE(TaskTableDataEqual(*data, d)); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); + } + for (const auto &task_id : ids) { + auto task_lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &data) { + ASSERT_EQ(id, task_id); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, + task_lookup_callback, nullptr)); + } + if (ids.size() == 1) { + client->raylet_task_table().Delete(job_id, ids[0]); + } else { + client->raylet_task_table().Delete(job_id, ids); + } + auto expected_failure_callback = [](RedisGcsClient *client, const TaskID &id) { + ASSERT_TRUE(true); + test->IncrementNumCallbacks(); + }; + auto undesired_callback = [](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &data) { ASSERT_TRUE(false); }; + for (size_t i = 0; i < ids.size(); ++i) { + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback, + expected_failure_callback)); + } + if (stop_at_end) { + auto stop_callback = [](RedisGcsClient *client, const TaskID &id) { test->Stop(); }; + RAY_CHECK_OK( + client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback)); + } + } + + static void TestTableSubscribeId(const JobID &job_id, + std::shared_ptr client) { + size_t num_modifications = 3; + + // Add a table 
entry. + TaskID task_id1 = RandomTaskId(); + + // Add a table entry at a second key. + TaskID task_id2 = RandomTaskId(); + + // The callback for a notification from the table. This should only be + // received for keys that we requested notifications for. + auto notification_callback = [task_id2, num_modifications]( + gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &data) { + // Check that we only get notifications for the requested key. + ASSERT_EQ(id, task_id2); + // Check that we get notifications in the same order as the writes. + ASSERT_TRUE( + TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks()))); + test->IncrementNumCallbacks(); + if (test->NumCallbacks() == num_modifications) { + test->Stop(); + } + }; + + // The failure callback should be called once since both keys start as empty. + bool failure_notification_received = false; + auto failure_callback = [task_id2, &failure_notification_received]( + gcs::RedisGcsClient *client, const TaskID &id) { + ASSERT_EQ(id, task_id2); + // The failure notification should be the first notification received. + ASSERT_EQ(test->NumCallbacks(), 0); + failure_notification_received = true; + }; + + // The callback for subscription success. Once we've subscribed, request + // notifications for only one of the keys, then write to both keys. + auto subscribe_callback = [job_id, task_id1, task_id2, + num_modifications](gcs::RedisGcsClient *client) { + // Request notifications for one of the keys. + RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( + job_id, task_id2, client->client_table().GetLocalClientId(), nullptr)); + // Write both keys. We should only receive notifications for the key that + // we requested them for. 
+ for (uint64_t i = 0; i < num_modifications; i++) { + auto data = CreateTaskTableData(task_id1, i); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr)); + } + for (uint64_t i = 0; i < num_modifications; i++) { + auto data = CreateTaskTableData(task_id2, i); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr)); + } + }; + + // Subscribe to notifications for this client. This allows us to request and + // receive notifications for specific keys. + RAY_CHECK_OK(client->raylet_task_table().Subscribe( + job_id, client->client_table().GetLocalClientId(), notification_callback, + failure_callback, subscribe_callback)); + // Run the event loop. The loop will only stop if the registered subscription + // callback is called for the requested key. + test->Start(); + // Check that the failure callback was called since the key was initially + // empty. + ASSERT_TRUE(failure_notification_received); + // Check that we received one notification callback for each write to the + // requested key. + ASSERT_EQ(test->NumCallbacks(), num_modifications); + } + + static void TestTableSubscribeCancel(const JobID &job_id, + std::shared_ptr client) { + // Add a table entry. + const auto task_id = RandomTaskId(); + const int num_modifications = 3; + const auto data = CreateTaskTableData(task_id, 0); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); + + // The failure callback should not be called since all keys are non-empty + // when notifications are requested. + auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { + RAY_CHECK(false); + }; + + // The callback for a notification from the table. This should only be + // received for keys that we requested notifications for. 
+ auto notification_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, + const TaskTableData &data) { + ASSERT_EQ(id, task_id); + // Check that we only get notifications for the first and last writes, + // since notifications are canceled in between. + if (test->NumCallbacks() == 0) { + ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0))); + } else { + ASSERT_TRUE(TaskTableDataEqual( + data, *CreateTaskTableData(task_id, num_modifications - 1))); + } + test->IncrementNumCallbacks(); + if (test->NumCallbacks() == num_modifications - 1) { + test->Stop(); + } + }; + + // The callback for a notification from the table. This should only be + // received for keys that we requested notifications for. + auto subscribe_callback = [job_id, task_id](gcs::RedisGcsClient *client) { + // Request notifications, then cancel immediately. We should receive a + // notification for the current value at the key. + RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( + job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); + RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( + job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); + // Write to the key. Since we canceled notifications, we should not receive + // a notification for these writes. + for (uint64_t i = 1; i < num_modifications; i++) { + auto data = CreateTaskTableData(task_id, i); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); + } + // Request notifications again. We should receive a notification for the + // current value at the key. + RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( + job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); + }; + + // Subscribe to notifications for this client. This allows us to request and + // receive notifications for specific keys. 
+ RAY_CHECK_OK(client->raylet_task_table().Subscribe( + job_id, client->client_table().GetLocalClientId(), notification_callback, + failure_callback, subscribe_callback)); + // Run the event loop. The loop will only stop if the registered subscription + // callback is called for the requested key. + test->Start(); + // Check that we received a notification callback for the first and last + // writes to the key, since notifications are canceled in between. + ASSERT_EQ(test->NumCallbacks(), 2); + } +}; // Convenient macro to test across {ae, asio} x {regular, chain} x {the tests}. // Undefined at the end. -#define TEST_MACRO(FIXTURE, TEST) \ - TEST_F(FIXTURE, TEST) { \ - test = this; \ - TEST(job_id_, client_); \ +#define TEST_TASK_TABLE_MACRO(FIXTURE, TEST) \ + TEST_F(FIXTURE, TEST) { \ + test = this; \ + TaskTableTestHelper::TEST(job_id_, client_); \ } -TEST_MACRO(TestGcsWithAsio, TestTableLookup); -#if RAY_USE_NEW_GCS -TEST_MACRO(TestGcsWithChainAsio, TestTableLookup); -#endif +TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableLookup); void TestLogLookup(const JobID &job_id, std::shared_ptr client) { // Append some entries to the log at an object ID. @@ -196,32 +405,7 @@ TEST_F(TestGcsWithAsio, TestLogLookup) { TestLogLookup(job_id_, client_); } -void TestTableLookupFailure(const JobID &job_id, - std::shared_ptr client) { - TaskID task_id = RandomTaskId(); - - // Check that the lookup does not return data. - auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { RAY_CHECK(false); }; - - // Check that the lookup returns an empty entry. - auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id) { - ASSERT_EQ(id, task_id); - test->Stop(); - }; - - // Lookup the task. We have not done any writes, so the key should be empty. - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, - failure_callback)); - // Run the event loop. 
The loop will only stop if the failure callback is - // called (or an assertion failure). - test->Start(); -} - -TEST_MACRO(TestGcsWithAsio, TestTableLookupFailure); -#if RAY_USE_NEW_GCS -TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure); -#endif +TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableLookupFailure); void TestLogAppendAt(const JobID &job_id, std::shared_ptr client) { TaskID task_id = RandomTaskId(); @@ -395,55 +579,6 @@ void TestDeleteKeysFromLog( } } -void TestDeleteKeysFromTable(const JobID &job_id, - std::shared_ptr client, - std::vector> &data_vector, - bool stop_at_end) { - std::vector ids; - TaskID task_id; - for (auto &data : data_vector) { - task_id = RandomTaskId(); - ids.push_back(task_id); - // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(TaskTableDataEqual(*data, d)); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); - } - for (const auto &task_id : ids) { - auto task_lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { - ASSERT_EQ(id, task_id); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, task_lookup_callback, - nullptr)); - } - if (ids.size() == 1) { - client->raylet_task_table().Delete(job_id, ids[0]); - } else { - client->raylet_task_table().Delete(job_id, ids); - } - auto expected_failure_callback = [](RedisGcsClient *client, const TaskID &id) { - ASSERT_TRUE(true); - test->IncrementNumCallbacks(); - }; - auto undesired_callback = [](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { ASSERT_TRUE(false); }; - for (size_t i = 0; i < ids.size(); ++i) { - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback, - expected_failure_callback)); - } 
- if (stop_at_end) { - auto stop_callback = [](RedisGcsClient *client, const TaskID &id) { test->Stop(); }; - RAY_CHECK_OK( - client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback)); - } -} - void TestDeleteKeysFromSet(const JobID &job_id, std::shared_ptr client, std::vector> &data_vector) { @@ -523,21 +658,21 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr cl std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - task_vector.push_back(CreateTaskTableData(RandomTaskId())); + task_vector.push_back(TaskTableTestHelper::CreateTaskTableData(RandomTaskId())); } }; AppendTaskData(1); ASSERT_EQ(task_vector.size(), 1); - TestDeleteKeysFromTable(job_id, client, task_vector, false); + TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, false); AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(task_vector.size(), 1); ASSERT_LT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromTable(job_id, client, task_vector, false); + TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, false); AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromTable(job_id, client, task_vector, true); + TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, true); test->Start(); ASSERT_GT(test->NumCallbacks(), @@ -841,81 +976,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { TestSetSubscribeAll(job_id_, client_); } -void TestTableSubscribeId(const JobID &job_id, - std::shared_ptr client) { - size_t num_modifications = 3; - - // Add a table entry. - TaskID task_id1 = RandomTaskId(); - - // Add a table entry at a second key. - TaskID task_id2 = RandomTaskId(); - - // The callback for a notification from the table. 
This should only be - // received for keys that we requested notifications for. - auto notification_callback = [task_id2, num_modifications](gcs::RedisGcsClient *client, - const TaskID &id, - const TaskTableData &data) { - // Check that we only get notifications for the requested key. - ASSERT_EQ(id, task_id2); - // Check that we get notifications in the same order as the writes. - ASSERT_TRUE( - TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks()))); - test->IncrementNumCallbacks(); - if (test->NumCallbacks() == num_modifications) { - test->Stop(); - } - }; - - // The failure callback should be called once since both keys start as empty. - bool failure_notification_received = false; - auto failure_callback = [task_id2, &failure_notification_received]( - gcs::RedisGcsClient *client, const TaskID &id) { - ASSERT_EQ(id, task_id2); - // The failure notification should be the first notification received. - ASSERT_EQ(test->NumCallbacks(), 0); - failure_notification_received = true; - }; - - // The callback for subscription success. Once we've subscribed, request - // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [job_id, task_id1, task_id2, - num_modifications](gcs::RedisGcsClient *client) { - // Request notifications for one of the keys. - RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id2, client->client_table().GetLocalClientId(), nullptr)); - // Write both keys. We should only receive notifications for the key that - // we requested them for. - for (uint64_t i = 0; i < num_modifications; i++) { - auto data = CreateTaskTableData(task_id1, i); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr)); - } - for (uint64_t i = 0; i < num_modifications; i++) { - auto data = CreateTaskTableData(task_id2, i); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr)); - } - }; - - // Subscribe to notifications for this client. 
This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->raylet_task_table().Subscribe( - job_id, client->client_table().GetLocalClientId(), notification_callback, - failure_callback, subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that the failure callback was called since the key was initially - // empty. - ASSERT_TRUE(failure_notification_received); - // Check that we received one notification callback for each write to the - // requested key. - ASSERT_EQ(test->NumCallbacks(), num_modifications); -} - -TEST_MACRO(TestGcsWithAsio, TestTableSubscribeId); -#if RAY_USE_NEW_GCS -TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId); -#endif +TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableSubscribeId); TEST_F(TestGcsWithAsio, TestLogSubscribeId) { test = this; @@ -998,77 +1059,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) { TestSetSubscribeId(job_id_, client_); } -void TestTableSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { - // Add a table entry. - const auto task_id = RandomTaskId(); - const int num_modifications = 3; - const auto data = CreateTaskTableData(task_id, 0); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); - - // The failure callback should not be called since all keys are non-empty - // when notifications are requested. - auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - }; - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto notification_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { - ASSERT_EQ(id, task_id); - // Check that we only get notifications for the first and last writes, - // since notifications are canceled in between. 
- if (test->NumCallbacks() == 0) { - ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0))); - } else { - ASSERT_TRUE( - TaskTableDataEqual(data, *CreateTaskTableData(task_id, num_modifications - 1))); - } - test->IncrementNumCallbacks(); - if (test->NumCallbacks() == num_modifications - 1) { - test->Stop(); - } - }; - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto subscribe_callback = [job_id, task_id](gcs::RedisGcsClient *client) { - // Request notifications, then cancel immediately. We should receive a - // notification for the current value at the key. - RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); - RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( - job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); - // Write to the key. Since we canceled notifications, we should not receive - // a notification for these writes. - for (uint64_t i = 1; i < num_modifications; i++) { - auto data = CreateTaskTableData(task_id, i); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); - } - // Request notifications again. We should receive a notification for the - // current value at the key. - RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, client->client_table().GetLocalClientId(), nullptr)); - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->raylet_task_table().Subscribe( - job_id, client->client_table().GetLocalClientId(), notification_callback, - failure_callback, subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. 
- test->Start(); - // Check that we received a notification callback for the first and least - // writes to the key, since notifications are canceled in between. - ASSERT_EQ(test->NumCallbacks(), 2); -} - -TEST_MACRO(TestGcsWithAsio, TestTableSubscribeCancel); -#if RAY_USE_NEW_GCS -TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel); -#endif +TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableSubscribeCancel); TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { test = this; @@ -1431,7 +1422,7 @@ TEST_F(TestGcsWithAsio, TestHashTable) { TestHashTable(job_id_, client_); } -#undef TEST_MACRO +#undef TEST_TASK_TABLE_MACRO } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/test/redis_job_info_accessor_test.cc b/src/ray/gcs/test/redis_job_info_accessor_test.cc index d93f7c40d..68f43f506 100644 --- a/src/ray/gcs/test/redis_job_info_accessor_test.cc +++ b/src/ray/gcs/test/redis_job_info_accessor_test.cc @@ -1,4 +1,3 @@ -#include "ray/gcs/redis_job_info_accessor.h" #include #include "gtest/gtest.h" #include "ray/gcs/pb_util.h" diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 863761440..b7ed3e62f 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -1,7 +1,7 @@ #include "lineage_cache.h" -#include "ray/stats/stats.h" - #include +#include "ray/gcs/redis_gcs_client.h" +#include "ray/stats/stats.h" namespace ray { @@ -152,16 +152,15 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } } -LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, - gcs::PubsubInterface &task_pubsub, +LineageCache::LineageCache(std::shared_ptr gcs_client, uint64_t max_lineage_size) - : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} + : gcs_client_(gcs_client) {} /// A helper function to add some uncommitted lineage to the local cache. 
void LineageCache::AddUncommittedLineage(const TaskID &task_id, const Lineage &uncommitted_lineage) { - RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " << client_id_; + RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " + << gcs_client_->client_table().GetLocalClientId(); // If the entry is not found in the lineage to merge, then we stop since // there is nothing to copy into the merged lineage. auto entry = uncommitted_lineage.GetEntry(task_id); @@ -192,7 +191,8 @@ bool LineageCache::CommitTask(const Task &task) { return true; } const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Committing task " << task_id << " on " << client_id_; + RAY_LOG(DEBUG) << "Committing task " << task_id << " on " + << gcs_client_->client_table().GetLocalClientId(); if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED) || lineage_.GetEntry(task_id)->GetStatus() == GcsStatus::UNCOMMITTED) { @@ -275,17 +275,17 @@ void LineageCache::FlushTask(const TaskID &task_id) { RAY_CHECK(entry); RAY_CHECK(entry->GetStatus() < GcsStatus::COMMITTING); - gcs::raylet::TaskTable::WriteCallback task_callback = - [this](ray::gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { HandleEntryCommitted(id); }; + auto task_callback = [this, task_id](Status status) { + RAY_CHECK(status.ok()); + HandleEntryCommitted(task_id); + }; auto task = lineage_.GetEntry(task_id); auto task_data = std::make_shared(); task_data->mutable_task()->mutable_task_spec()->CopyFrom( task->TaskData().GetTaskSpecification().GetMessage()); task_data->mutable_task()->mutable_task_execution_spec()->CopyFrom( task->TaskData().GetTaskExecutionSpec().GetMessage()); - RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()), - task_id, task_data, task_callback)); + RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(task_data, task_callback)); // We successfully wrote the task, so mark it as committing. 
// TODO(swang): Use a batched interface and write with all object entries. @@ -296,10 +296,12 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) { auto inserted = subscribed_tasks_.insert(task_id); bool unsubscribed = inserted.second; if (unsubscribed) { - // Request notifications for the task if we haven't already requested - // notifications for it. - RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_, - /*done*/ nullptr)); + auto subscribe = [this](const TaskID &task_id, const TaskTableData) { + HandleEntryCommitted(task_id); + }; + // Subscribe to the task. + RAY_CHECK_OK(gcs_client_->Tasks().AsyncSubscribe(task_id, subscribe, + /*done*/ nullptr)); } // Return whether we were previously unsubscribed to this task and are now // subscribed. @@ -310,10 +312,8 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) { auto it = subscribed_tasks_.find(task_id); bool subscribed = (it != subscribed_tasks_.end()); if (subscribed) { - // Cancel notifications for the task if we previously requested - // notifications for it. - RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_, - /*done*/ nullptr)); + // Cancel subscribe to the task. + RAY_CHECK_OK(gcs_client_->Tasks().AsyncUnsubscribe(task_id, /*done*/ nullptr)); subscribed_tasks_.erase(it); } // Return whether we were previously subscribed to this task and are now @@ -339,7 +339,8 @@ void LineageCache::EvictTask(const TaskID &task_id) { } // Evict the task. - RAY_LOG(DEBUG) << "Evicting task " << task_id << " on " << client_id_; + RAY_LOG(DEBUG) << "Evicting task " << task_id << " on " + << gcs_client_->client_table().GetLocalClientId(); lineage_.PopEntry(task_id); // Try to evict the children of the evict task. These are the tasks that have // a dependency on the evicted task. 
diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index b41e69278..7c14c6255 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -12,7 +12,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/common/task/task.h" -#include "ray/gcs/tables.h" +#include "ray/gcs/redis_gcs_client.h" namespace ray { @@ -209,9 +209,8 @@ class LineageCache { public: /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). - LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, - gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); + LineageCache(std::shared_ptr gcs_client, + uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. /// @@ -303,19 +302,13 @@ class LineageCache { /// was successful (whether we were subscribed). bool UnsubscribeTask(const TaskID &task_id); - /// The client ID, used to request notifications for specific tasks. - /// TODO(swang): Move the ClientID into the generic Table implementation. - ClientID client_id_; - /// The durable storage system for task information. - gcs::TableInterface &task_storage_; - /// The pubsub storage system for task information. This can be used to - /// request notifications for the commit of a task entry. - gcs::PubsubInterface &task_pubsub_; + /// A client connection to the GCS. + std::shared_ptr gcs_client_; /// All tasks and objects that we are responsible for writing back to the /// GCS, and the tasks and objects in their lineage. Lineage lineage_; - /// The tasks that we've subscribed to notifications for from the pubsub - /// storage system. We will receive a notification for these tasks on commit. + /// The tasks that we've subscribed to. + /// We will receive a notification for these tasks on commit. 
std::unordered_set subscribed_tasks_; }; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 7250a7627..4de75ee4a 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -1,4 +1,5 @@ #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -9,6 +10,8 @@ #include "ray/common/task/task_util.h" #include "ray/gcs/callback.h" +#include "ray/gcs/redis_accessor.h" +#include "ray/gcs/redis_gcs_client.h" #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/lineage_cache.h" @@ -23,66 +26,68 @@ const static JobID kDefaultJobId = JobID::FromInt(1); const static TaskID kDefaultDriverTaskId = TaskID::ForDriverTask(kDefaultJobId); -class MockGcs : public gcs::TableInterface, - public gcs::PubsubInterface { - public: - MockGcs() {} +class MockGcsClient; - void Subscribe(const gcs::raylet::TaskTable::WriteCallback &notification_callback) { +class MockTaskInfoAccessor : public gcs::RedisTaskInfoAccessor { + public: + MockTaskInfoAccessor(gcs::RedisGcsClient *gcs_client) + : RedisTaskInfoAccessor(gcs_client) {} + + virtual ~MockTaskInfoAccessor() {} + + void RegisterSubscribeCallback( + const gcs::SubscribeCallback &notification_callback) { notification_callback_ = notification_callback; } - Status Add(const JobID &job_id, const TaskID &task_id, - const std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + Status AsyncAdd(const std::shared_ptr &task_data, + const gcs::StatusCallback &done) { + TaskID task_id = TaskID::FromBinary(task_data->task().task_spec().task_id()); task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback.
if (subscribed_tasks_.count(task_id) == 1) { - callback = [this, done](ray::gcs::RedisGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { - done(client, task_id, data); + callback = [this, done, task_id, task_data](Status status) { + done(status); // If we're subscribed to the task to be added, also send a // subscription notification. - notification_callback_(client, task_id, data); + notification_callback_(task_id, *task_data); }; } - callbacks_.push_back( - std::pair(callback, task_id)); + callbacks_.push_back({callback, task_id}); num_task_adds_++; return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(std::shared_ptr task_data) { + TaskID task_id = TaskID::FromBinary(task_data->task().task_spec().task_id()); task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); - auto callback = [this, send_notification](ray::gcs::RedisGcsClient *client, - const TaskID &task_id, - const TaskTableData &data) { + auto callback = [this, send_notification, task_id, task_data](Status status) { if (send_notification) { - notification_callback_(client, task_id, data); + notification_callback_(task_id, *task_data); } }; - return Add(JobID::Nil(), task_id, task_data, callback); + return AsyncAdd(task_data, callback); } - Status RequestNotifications(const JobID &job_id, const TaskID &task_id, - const ClientID &client_id, - const gcs::StatusCallback &done) { + Status AsyncSubscribe( + const TaskID &task_id, + const gcs::SubscribeCallback &notification_callback, + const gcs::StatusCallback &done) { subscribed_tasks_.insert(task_id); if (task_table_.count(task_id) == 1) { - callbacks_.push_back({notification_callback_, task_id}); + notification_callbacks_.push_back({notification_callback_, task_id}); } num_requested_notifications_ += 1; return ray::Status::OK(); } -
Status CancelNotifications(const JobID &job_id, const TaskID &task_id, - const ClientID &client_id, const gcs::StatusCallback &done) { + Status AsyncUnsubscribe(const TaskID &task_id, const gcs::StatusCallback &done) { subscribed_tasks_.erase(task_id); return ray::Status::OK(); } @@ -91,7 +96,10 @@ class MockGcs : public gcs::TableInterface, auto callbacks = std::move(callbacks_); callbacks_.clear(); for (const auto &callback : callbacks) { - callback.first(NULL, callback.second, *task_table_[callback.second]); + callback.first(Status::OK()); + } + for (const auto &callback : notification_callbacks_) { + callback.first(callback.second, *task_table_[callback.second]); } } @@ -107,32 +115,59 @@ class MockGcs : public gcs::TableInterface, private: std::unordered_map> task_table_; - std::vector> callbacks_; - gcs::raylet::TaskTable::WriteCallback notification_callback_; + std::vector> callbacks_; + + typedef gcs::SubscribeCallback TaskSubscribeCallback; + TaskSubscribeCallback notification_callback_; + std::vector> notification_callbacks_; + std::unordered_set subscribed_tasks_; int num_requested_notifications_ = 0; int num_task_adds_ = 0; }; +class MockGcsClient : public gcs::RedisGcsClient { + public: + MockGcsClient(const gcs::GcsClientOptions &options) : RedisGcsClient(options) { + client_table_fake_.reset( + new gcs::ClientTable({nullptr}, this, ClientID::FromRandom())); + task_table_fake_.reset(new gcs::raylet::TaskTable({nullptr}, this)); + task_accessor_.reset(new MockTaskInfoAccessor(this)); + } + + gcs::ClientTable &client_table() { return *client_table_fake_; } + + gcs::raylet::TaskTable &raylet_task_table() { return *task_table_fake_; } + + MockTaskInfoAccessor &MockTasks() { + return *dynamic_cast(task_accessor_.get()); + } + + private: + std::unique_ptr client_table_fake_; + std::unique_ptr task_table_fake_; +}; + class LineageCacheTest : public ::testing::Test { public: - LineageCacheTest() - : max_lineage_size_(10), - num_notifications_(0), - 
mock_gcs_(), - lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { - mock_gcs_.Subscribe([this](ray::gcs::RedisGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { - lineage_cache_.HandleEntryCommitted(task_id); - num_notifications_++; - }); + LineageCacheTest() : max_lineage_size_(10), num_notifications_(0) { + gcs::GcsClientOptions options("10.10.10.10", 12100, ""); + mock_gcs_ = std::make_shared(options); + + lineage_cache_.reset(new LineageCache(mock_gcs_, max_lineage_size_)); + + mock_gcs_->MockTasks().RegisterSubscribeCallback( + [this](const TaskID &task_id, const TaskTableData &data) { + lineage_cache_->HandleEntryCommitted(task_id); + num_notifications_++; + }); } protected: uint64_t max_lineage_size_; uint64_t num_notifications_; - MockGcs mock_gcs_; - LineageCache lineage_cache_; + std::shared_ptr mock_gcs_; + std::unique_ptr lineage_cache_; }; static inline Task ExampleTask(const std::vector &arguments, @@ -179,7 +214,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { // Insert two independent chains of tasks. std::vector tasks1; auto return_values1 = - InsertTaskChain(lineage_cache_, tasks1, 3, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks1, 3, std::vector(), 1); std::vector task_ids1; for (const auto &task : tasks1) { task_ids1.push_back(task.GetTaskSpecification().TaskId()); @@ -187,7 +222,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { std::vector tasks2; auto return_values2 = - InsertTaskChain(lineage_cache_, tasks2, 2, std::vector(), 2); + InsertTaskChain(*lineage_cache_, tasks2, 2, std::vector(), 2); std::vector task_ids2; for (const auto &task : tasks2) { task_ids2.push_back(task.GetTaskSpecification().TaskId()); @@ -195,7 +230,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { // Get the uncommitted lineage for the last task (the leaf) of one of the chains. 
auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineage(task_ids1.back(), ClientID::Nil()); + lineage_cache_->GetUncommittedLineage(task_ids1.back(), ClientID::Nil()); // Check that the uncommitted lineage is exactly equal to the first chain of tasks. ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size()); for (auto &task_id : task_ids1) { @@ -208,7 +243,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { std::vector combined_arguments = return_values1; combined_arguments.insert(combined_arguments.end(), return_values2.begin(), return_values2.end()); - InsertTaskChain(lineage_cache_, combined_tasks, 1, combined_arguments, 1); + InsertTaskChain(*lineage_cache_, combined_tasks, 1, combined_arguments, 1); std::vector combined_task_ids; for (const auto &task : combined_tasks) { combined_task_ids.push_back(task.GetTaskSpecification().TaskId()); @@ -216,7 +251,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) { // Get the uncommitted lineage for the inserted task. uncommitted_lineage = - lineage_cache_.GetUncommittedLineage(combined_task_ids.back(), ClientID::Nil()); + lineage_cache_->GetUncommittedLineage(combined_task_ids.back(), ClientID::Nil()); // Check that the uncommitted lineage is exactly equal to the entire set of // tasks inserted so far. ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size()); @@ -229,13 +264,13 @@ TEST_F(LineageCacheTest, TestDuplicateUncommittedLineage) { // Insert a chain of tasks. std::vector tasks; auto return_values = - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, 3, std::vector(), 1); std::vector task_ids; for (const auto &task : tasks) { task_ids.push_back(task.GetTaskSpecification().TaskId()); } // Check that we subscribed to each of the uncommitted tasks. 
- ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size()); // Check that if we add the same tasks as UNCOMMITTED again, we do not issue // duplicate subscribe requests. @@ -243,21 +278,21 @@ TEST_F(LineageCacheTest, TestDuplicateUncommittedLineage) { for (const auto &task : tasks) { duplicate_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); } - lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage); - ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + lineage_cache_->AddUncommittedLineage(task_ids.back(), duplicate_lineage); + ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size()); // Check that if we commit one of the tasks, we still do not issue any // duplicate subscribe requests. - lineage_cache_.CommitTask(tasks.front()); - lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage); - ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + lineage_cache_->CommitTask(tasks.front()); + lineage_cache_->AddUncommittedLineage(task_ids.back(), duplicate_lineage); + ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size()); } TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { // Insert chain of tasks. 
std::vector tasks; auto return_values = - InsertTaskChain(lineage_cache_, tasks, 4, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, 4, std::vector(), 1); std::vector task_ids; for (const auto &task : tasks) { task_ids.push_back(task.GetTaskSpecification().TaskId()); @@ -267,12 +302,12 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { auto node_id2 = ClientID::FromRandom(); auto forwarded_task_id = task_ids[task_ids.size() - 2]; auto remaining_task_id = task_ids[task_ids.size() - 1]; - lineage_cache_.MarkTaskAsForwarded(forwarded_task_id, node_id); + lineage_cache_->MarkTaskAsForwarded(forwarded_task_id, node_id); auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id); + lineage_cache_->GetUncommittedLineage(remaining_task_id, node_id); auto uncommitted_lineage_all = - lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id2); + lineage_cache_->GetUncommittedLineage(remaining_task_id, node_id2); ASSERT_EQ(1, uncommitted_lineage.GetEntries().size()); ASSERT_EQ(4, uncommitted_lineage_all.GetEntries().size()); @@ -281,7 +316,7 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { // Check that lineage of requested task includes itself, regardless of whether // it has been forwarded before. auto uncommitted_lineage_forwarded = - lineage_cache_.GetUncommittedLineage(forwarded_task_id, node_id); + lineage_cache_->GetUncommittedLineage(forwarded_task_id, node_id); ASSERT_EQ(1, uncommitted_lineage_forwarded.GetEntries().size()); } @@ -289,30 +324,30 @@ TEST_F(LineageCacheTest, TestWritebackReady) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, 3, std::vector(), 1); // Check that when no tasks have been marked as ready, we do not flush any // entries. 
- ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); // Check that after marking the first task as ready, we flush only that task. - ASSERT_TRUE(lineage_cache_.CommitTask(tasks.front())); + ASSERT_TRUE(lineage_cache_->CommitTask(tasks.front())); num_tasks_flushed++; - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); } TEST_F(LineageCacheTest, TestWritebackOrder) { // Insert a chain of dependent tasks. std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, 3, std::vector(), 1); size_t num_tasks_flushed = tasks.size(); // Mark all tasks as ready. All tasks should be flushed. for (const auto &task : tasks) { - ASSERT_TRUE(lineage_cache_.CommitTask(task)); + ASSERT_TRUE(lineage_cache_->CommitTask(task)); } - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); } TEST_F(LineageCacheTest, TestEvictChain) { @@ -331,42 +366,45 @@ TEST_F(LineageCacheTest, TestEvictChain) { uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); } // Mark the last task as ready to flush. - lineage_cache_.AddUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), - uncommitted_lineage); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); - ASSERT_TRUE(lineage_cache_.CommitTask(tasks.back())); + lineage_cache_->AddUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), + uncommitted_lineage); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), tasks.size()); + ASSERT_TRUE(lineage_cache_->CommitTask(tasks.back())); num_tasks_flushed++; - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); // Flush acknowledgements. 
The lineage cache should receive the commit for // the flushed task, but its lineage should not be evicted yet. - mock_gcs_.Flush(); + mock_gcs_->MockTasks().Flush(); ASSERT_EQ(lineage_cache_ - .GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), - ClientID::Nil()) + ->GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), + ClientID::Nil()) .GetEntries() .size(), tasks.size()); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. + auto task_id = tasks.at(1).GetTaskSpecification().TaskId(); auto task_data = std::make_shared(); - RAY_CHECK_OK( - mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); - mock_gcs_.Flush(); + task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data)); + mock_gcs_->MockTasks().Flush(); ASSERT_EQ(lineage_cache_ - .GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), - ClientID::Nil()) + ->GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), + ClientID::Nil()) .GetEntries() .size(), tasks.size()); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. 
- RAY_CHECK_OK( - mock_gcs_.RemoteAdd(tasks.at(0).GetTaskSpecification().TaskId(), task_data)); - mock_gcs_.Flush(); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); + task_id = tasks.at(0).GetTaskSpecification().TaskId(); + auto task_data_2 = std::make_shared(); + task_data_2->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data_2)); + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0); + ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0); } TEST_F(LineageCacheTest, TestEvictManyParents) { @@ -378,50 +416,50 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { parent_tasks.push_back(task); arguments.push_back(task.GetTaskSpecification().ReturnIdForPlasma(0)); auto lineage = CreateSingletonLineage(task); - lineage_cache_.AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage); + lineage_cache_->AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage); } // Create a child task that is dependent on all of the previous tasks. auto child_task = ExampleTask(arguments, 1); auto lineage = CreateSingletonLineage(child_task); - lineage_cache_.AddUncommittedLineage(child_task.GetTaskSpecification().TaskId(), - lineage); + lineage_cache_->AddUncommittedLineage(child_task.GetTaskSpecification().TaskId(), + lineage); // Flush the child task. Make sure that it remains in the cache, since none // of its parents have been committed yet, and that the uncommitted lineage // still includes all of the parent tasks. 
size_t total_tasks = parent_tasks.size() + 1; - lineage_cache_.CommitTask(child_task); - mock_gcs_.Flush(); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks); + lineage_cache_->CommitTask(child_task); + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), total_tasks); ASSERT_EQ(lineage_cache_ - .GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(), - ClientID::Nil()) + ->GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(), + ClientID::Nil()) .GetEntries() .size(), total_tasks); // Flush each parent task and check for eviction safety. for (const auto &parent_task : parent_tasks) { - lineage_cache_.CommitTask(parent_task); - mock_gcs_.Flush(); + lineage_cache_->CommitTask(parent_task); + mock_gcs_->MockTasks().Flush(); total_tasks--; if (total_tasks > 1) { // Each task should be evicted as soon as its commit is acknowledged, // since the parent tasks have no dependencies. - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), total_tasks); ASSERT_EQ(lineage_cache_ - .GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(), - ClientID::Nil()) + ->GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(), + ClientID::Nil()) .GetEntries() .size(), total_tasks); } else { // After the last task has been committed, then the child task should // also be evicted. The lineage cache should now be empty. 
- ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0); } } - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); + ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0); } TEST_F(LineageCacheTest, TestEviction) { @@ -429,51 +467,56 @@ TEST_F(LineageCacheTest, TestEviction) { uint64_t lineage_size = max_lineage_size_ + 1; size_t num_tasks_flushed = 0; std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, lineage_size, std::vector(), 1); // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::Nil()); + lineage_cache_->GetUncommittedLineage(last_task_id, ClientID::Nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); // Simulate executing the first task on a remote node and adding it to the // GCS. - auto task_data = std::make_shared(); auto it = tasks.begin(); - RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); + auto task_id = it->GetTaskSpecification().TaskId(); + auto task_data = std::make_shared(); + task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data)); it++; // Check that the remote task is flushed. num_tasks_flushed++; - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); // Check that the last task in the chain still has all tasks in its // uncommitted lineage. 
ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size - num_tasks_flushed); // Simulate executing all the rest of the tasks except the last one on a // remote node and adding them to the GCS. tasks.pop_back(); for (; it != tasks.end(); it++) { - RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); + auto task_id = it->GetTaskSpecification().TaskId(); + auto task_data = std::make_shared(); + task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data)); // Check that the remote task is flushed. num_tasks_flushed++; - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size - num_tasks_flushed); } // All tasks have now been flushed. Check that enough lineage has been // evicted that the uncommitted lineage is now less than the maximum size. uncommitted_lineage = - lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::Nil()); + lineage_cache_->GetUncommittedLineage(last_task_id, ClientID::Nil()); ASSERT_TRUE(uncommitted_lineage.GetEntries().size() < max_lineage_size_); // The remaining task should have no uncommitted lineage. 
ASSERT_EQ(uncommitted_lineage.GetEntries().size(), 1); - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 1); + ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 1); } TEST_F(LineageCacheTest, TestOutOfOrderEviction) { @@ -483,38 +526,42 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { uint64_t lineage_size = (2 * max_lineage_size_) + 2; size_t num_tasks_flushed = 0; std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, lineage_size, std::vector(), 1); // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::Nil()); + lineage_cache_->GetUncommittedLineage(last_task_id, ClientID::Nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size); // Simulate executing the rest of the tasks on a remote node and receiving // the notifications from the GCS in reverse order of execution. auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { + auto task_id = it->GetTaskSpecification().TaskId(); auto task_data = std::make_shared(); - RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); + task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data)); // Check that the remote task is flushed. 
num_tasks_flushed++; - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size); + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size); } // Flush the last task. The lineage should not get evicted until this task's // commit is received. + auto task_id = last_task.GetTaskSpecification().TaskId(); auto task_data = std::make_shared(); - RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); + task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data)); num_tasks_flushed++; - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0); + ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0); } TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { @@ -522,7 +569,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { size_t num_tasks_flushed = 0; uint64_t lineage_size = max_lineage_size_ + 1; std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, lineage_size, std::vector(), 1); // Add more tasks to the lineage cache that will remain local. Each of these // tasks is dependent one of the tasks that was forwarded above. 
@@ -530,9 +577,9 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { auto return_id = task.GetTaskSpecification().ReturnIdForPlasma(0); auto dependent_task = ExampleTask({return_id}, 1); auto lineage = CreateSingletonLineage(dependent_task); - lineage_cache_.AddUncommittedLineage(dependent_task.GetTaskSpecification().TaskId(), - lineage); - ASSERT_TRUE(lineage_cache_.CommitTask(dependent_task)); + lineage_cache_->AddUncommittedLineage(dependent_task.GetTaskSpecification().TaskId(), + lineage); + ASSERT_TRUE(lineage_cache_->CommitTask(dependent_task)); // Once the forwarded tasks are evicted from the lineage cache, we expect // each of these dependent tasks to be flushed, since all of their // dependencies have been committed. @@ -544,50 +591,52 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { + auto task_id = it->GetTaskSpecification().TaskId(); auto task_data = std::make_shared(); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); - RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); + task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size * 2); + RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data)); num_tasks_flushed++; - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); + mock_gcs_->MockTasks().Flush(); + ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed); } // Check that after the final remote task is executed, all local lineage is // now evicted. 
- ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); + ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0); + ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0); } TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) { // Insert a chain of tasks. std::vector tasks; auto return_values = - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + InsertTaskChain(*lineage_cache_, tasks, 3, std::vector(), 1); std::vector task_ids; for (const auto &task : tasks) { task_ids.push_back(task.GetTaskSpecification().TaskId()); } // Check that we subscribed to each of the uncommitted tasks. - ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size()); // Flush all uncommitted tasks and make sure we add all tasks to // the task table. - lineage_cache_.FlushAllUncommittedTasks(); - ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + lineage_cache_->FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_->MockTasks().NumTaskAdds(), tasks.size()); // Flush again and make sure there are no new tasks added to the // task table. - lineage_cache_.FlushAllUncommittedTasks(); - ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + lineage_cache_->FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_->MockTasks().NumTaskAdds(), tasks.size()); // Flush all GCS notifications. - mock_gcs_.Flush(); + mock_gcs_->MockTasks().Flush(); // Make sure that we unsubscribed to the uncommitted tasks before // we flushed them. ASSERT_EQ(num_notifications_, 0); // Flush again and make sure there are no new tasks added to the // task table. 
- lineage_cache_.FlushAllUncommittedTasks(); - ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + lineage_cache_->FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_->MockTasks().NumTaskAdds(), tasks.size()); } } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 0a9d492ff..c2f427252 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -102,9 +102,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_->client_table().GetLocalClientId(), RayConfig::instance().initial_reconstruction_timeout_milliseconds(), gcs_client_->task_lease_table()), - lineage_cache_(gcs_client_->client_table().GetLocalClientId(), - gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), - config.max_lineage_size), + lineage_cache_(gcs_client_, config.max_lineage_size), actor_registry_(), node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), @@ -140,18 +138,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, ray::Status NodeManager::RegisterGcs() { object_manager_.RegisterGcs(); - // Subscribe to task entry commits in the GCS. These notifications are - // forwarded to the lineage cache, which requests notifications about tasks - // that were executed remotely. 
- const auto task_committed_callback = [this](gcs::RedisGcsClient *client, - const TaskID &task_id, - const TaskTableData &task_data) { - lineage_cache_.HandleEntryCommitted(task_id); - }; - RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( - JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), - task_committed_callback, nullptr, nullptr)); - const auto task_lease_notification_callback = [this](gcs::RedisGcsClient *client, const TaskID &task_id, const TaskLeaseData &task_lease) { @@ -984,7 +970,7 @@ void NodeManager::ProcessClientMessage( for (const auto &object_id : object_ids) { creating_task_ids.push_back(object_id.TaskId()); } - gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids); + RAY_CHECK_OK(gcs_client_->Tasks().AsyncDelete(creating_task_ids, nullptr)); } } break; case protocol::MessageType::PrepareActorCheckpointRequest: { @@ -2437,27 +2423,25 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { auto parent_task_id = task_spec.ParentTaskId(); int port = worker.Port(); RAY_CHECK_OK( - gcs_client_->raylet_task_table().Lookup( - JobID::Nil(), parent_task_id, - /*success_callback=*/ - [this, task_spec, resumed_from_checkpoint, port]( - ray::gcs::RedisGcsClient *client, const TaskID &parent_task_id, - const TaskTableData &parent_task_data) { - // The task was in the GCS task table. Use the stored task spec to - // get the parent actor id. 
- Task parent_task(parent_task_data.task()); - ActorID parent_actor_id = ActorID::Nil(); - if (parent_task.GetTaskSpecification().IsActorCreationTask()) { - parent_actor_id = parent_task.GetTaskSpecification().ActorCreationId(); - } else if (parent_task.GetTaskSpecification().IsActorTask()) { - parent_actor_id = parent_task.GetTaskSpecification().ActorId(); + gcs_client_->Tasks().AsyncGet( + parent_task_id, + /*callback=*/ + [this, task_spec, resumed_from_checkpoint, port, parent_task_id]( + Status status, const boost::optional &parent_task_data) { + if (parent_task_data) { + // The task was in the GCS task table. Use the stored task spec to + // get the parent actor id. + Task parent_task(parent_task_data->task()); + ActorID parent_actor_id = ActorID::Nil(); + if (parent_task.GetTaskSpecification().IsActorCreationTask()) { + parent_actor_id = parent_task.GetTaskSpecification().ActorCreationId(); + } else if (parent_task.GetTaskSpecification().IsActorTask()) { + parent_actor_id = parent_task.GetTaskSpecification().ActorId(); + } + FinishAssignedActorCreationTask(parent_actor_id, task_spec, + resumed_from_checkpoint, port); + return; } - FinishAssignedActorCreationTask(parent_actor_id, task_spec, - resumed_from_checkpoint, port); - }, - /*failure_callback=*/ - [this, task_spec, resumed_from_checkpoint, port]( - ray::gcs::RedisGcsClient *client, const TaskID &parent_task_id) { // The parent task was not in the GCS task table. It should most likely be // in the lineage cache. ActorID parent_actor_id = ActorID::Nil(); @@ -2574,18 +2558,17 @@ void NodeManager::FinishAssignedActorCreationTask(const ActorID &parent_actor_id void NodeManager::HandleTaskReconstruction(const TaskID &task_id, const ObjectID &required_object_id) { // Retrieve the task spec in order to re-execute the task. 
- RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup( - JobID::Nil(), task_id, - /*success_callback=*/ - [this, required_object_id](ray::gcs::RedisGcsClient *client, const TaskID &task_id, - const TaskTableData &task_data) { - // The task was in the GCS task table. Use the stored task spec to - // re-execute the task. - ResubmitTask(Task(task_data.task()), required_object_id); - }, - /*failure_callback=*/ - [this, required_object_id](ray::gcs::RedisGcsClient *client, - const TaskID &task_id) { + RAY_CHECK_OK(gcs_client_->Tasks().AsyncGet( + task_id, + /*callback=*/ + [this, required_object_id, task_id]( + Status status, const boost::optional &task_data) { + if (task_data) { + // The task was in the GCS task table. Use the stored task spec to + // re-execute the task. + ResubmitTask(Task(task_data->task()), required_object_id); + return; + } // The task was not in the GCS task table. It must therefore be in the // lineage cache. if (lineage_cache_.ContainsTask(task_id)) {