[GCS]refactor the GCS Client Task Interface (#6556)

This commit is contained in:
micafan 2019-12-23 17:54:21 +08:00 committed by Hao Chen
parent bac6f3b61e
commit 84d3d4b67b
23 changed files with 991 additions and 836 deletions

View file

@ -1,6 +1,5 @@
#include "ray/core_worker/actor_manager.h"
#include "ray/gcs/redis_actor_info_accessor.h"
#include "ray/gcs/redis_accessor.h"
namespace ray {

View file

@ -220,7 +220,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
RAY_CHECK_OK(gcs_client_->raylet_task_table().Add(job_id, task_id, data, nullptr));
RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(data, nullptr));
SetCurrentTaskId(task_id);
}

183
src/ray/gcs/accessor.h Normal file
View file

@ -0,0 +1,183 @@
#ifndef RAY_GCS_ACCESSOR_H
#define RAY_GCS_ACCESSOR_H
#include "ray/common/id.h"
#include "ray/gcs/callback.h"
#include "ray/protobuf/gcs.pb.h"
namespace ray {
namespace gcs {
/// \class ActorInfoAccessor
/// `ActorInfoAccessor` is a sub-interface of `GcsClient`.
/// This class includes all the methods that are related to accessing
/// actor information in the GCS.
class ActorInfoAccessor {
public:
virtual ~ActorInfoAccessor() = default;
/// Get actor specification from GCS asynchronously.
///
/// \param actor_id The ID of actor to look up in the GCS.
/// \param callback Callback that will be called after lookup finishes.
/// \return Status
virtual Status AsyncGet(const ActorID &actor_id,
const OptionalItemCallback<rpc::ActorTableData> &callback) = 0;
/// Register an actor to GCS asynchronously.
///
/// \param data_ptr The actor that will be registered to the GCS.
/// \param callback Callback that will be called after actor has been registered
/// to the GCS.
/// \return Status
virtual Status AsyncRegister(const std::shared_ptr<rpc::ActorTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Update dynamic states of actor in GCS asynchronously.
///
/// \param actor_id ID of the actor to update.
/// \param data_ptr Data of the actor to update.
/// \param callback Callback that will be called after update finishes.
/// \return Status
/// TODO(micafan) Don't expose the whole `ActorTableData` and only allow
/// updating dynamic states.
virtual Status AsyncUpdate(const ActorID &actor_id,
const std::shared_ptr<rpc::ActorTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Subscribe to any register or update operations of actors.
///
/// \param subscribe Callback that will be called each time when an actor is registered
/// or updated.
/// \param done Callback that will be called when subscription is complete and we
/// are ready to receive notification.
/// \return Status
virtual Status AsyncSubscribeAll(
const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe,
const StatusCallback &done) = 0;
/// Subscribe to any update operations of an actor.
///
/// \param actor_id The ID of actor to be subscribed to.
/// \param subscribe Callback that will be called each time when the actor is updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
virtual Status AsyncSubscribe(
const ActorID &actor_id,
const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe,
const StatusCallback &done) = 0;
/// Cancel subscription to an actor.
///
/// \param actor_id The ID of the actor to be unsubscribed to.
/// \param done Callback that will be called when unsubscribe is complete.
/// \return Status
virtual Status AsyncUnsubscribe(const ActorID &actor_id,
const StatusCallback &done) = 0;
protected:
ActorInfoAccessor() = default;
};
/// \class JobInfoAccessor
/// `JobInfoAccessor` is a sub-interface of `GcsClient`.
/// This class includes all the methods that are related to accessing
/// job information in the GCS.
class JobInfoAccessor {
public:
virtual ~JobInfoAccessor() = default;
/// Add a job to GCS asynchronously.
///
/// \param data_ptr The job that will be add to GCS.
/// \param callback Callback that will be called after job has been added
/// to GCS.
/// \return Status
virtual Status AsyncAdd(const std::shared_ptr<rpc::JobTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Mark job as finished in GCS asynchronously.
///
/// \param job_id ID of the job that will be make finished to GCS.
/// \param callback Callback that will be called after update finished.
/// \return Status
virtual Status AsyncMarkFinished(const JobID &job_id,
const StatusCallback &callback) = 0;
/// Subscribe to finished jobs.
///
/// \param subscribe Callback that will be called each time when a job finishes.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
virtual Status AsyncSubscribeToFinishedJobs(
const SubscribeCallback<JobID, rpc::JobTableData> &subscribe,
const StatusCallback &done) = 0;
protected:
JobInfoAccessor() = default;
};
/// \class TaskInfoAccessor
/// `TaskInfoAccessor` is a sub-interface of `GcsClient`.
/// This class includes all the methods that are related to accessing
/// task information in the GCS.
class TaskInfoAccessor {
public:
virtual ~TaskInfoAccessor() {}
/// Add a task to GCS asynchronously.
///
/// \param data_ptr The task that will be added to GCS.
/// \param callback Callback that will be called after task has been added
/// to GCS.
/// \return Status
virtual Status AsyncAdd(const std::shared_ptr<rpc::TaskTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Get task information from GCS asynchronously.
///
/// \param task_id The ID of the task to look up in GCS.
/// \param callback Callback that is called after lookup finished.
/// \return Status
virtual Status AsyncGet(const TaskID &task_id,
const OptionalItemCallback<rpc::TaskTableData> &callback) = 0;
/// Delete tasks from GCS asynchronously.
///
/// \param task_ids The vector of IDs to delete from GCS.
/// \param callback Callback that is called after delete finished.
/// \return Status
// TODO(micafan) Will support callback of batch deletion in the future.
// Currently this callback will never be called.
virtual Status AsyncDelete(const std::vector<TaskID> &task_ids,
const StatusCallback &callback) = 0;
/// Subscribe asynchronously to the event that the given task is added in GCS.
///
/// \param task_id The ID of the task to be subscribed to.
/// \param subscribe Callback that will be called each time when the task is updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
virtual Status AsyncSubscribe(
const TaskID &task_id,
const SubscribeCallback<TaskID, rpc::TaskTableData> &subscribe,
const StatusCallback &done) = 0;
/// Cancel subscription to a task asynchronously.
/// This method is for node only (core worker shouldn't use this method).
///
/// \param task_id The ID of the task to be unsubscribed to.
/// \param done Callback that will be called when unsubscribe is complete.
/// \return Status
virtual Status AsyncUnsubscribe(const TaskID &task_id, const StatusCallback &done) = 0;
protected:
TaskInfoAccessor() = default;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_ACCESSOR_H

View file

@ -1,87 +0,0 @@
#ifndef RAY_GCS_ACTOR_INFO_ACCESSOR_H
#define RAY_GCS_ACTOR_INFO_ACCESSOR_H
#include "ray/common/id.h"
#include "ray/gcs/callback.h"
#include "ray/protobuf/gcs.pb.h"
namespace ray {
namespace gcs {
/// \class ActorInfoAccessor
/// `ActorInfoAccessor` is a sub-interface of `GcsClient`.
/// This class includes all the methods that are related to accessing
/// actor information in the GCS.
class ActorInfoAccessor {
public:
virtual ~ActorInfoAccessor() = default;
/// Get actor specification from GCS asynchronously.
///
/// \param actor_id The ID of actor to look up in the GCS.
/// \param callback Callback that will be called after lookup finishes.
/// \return Status
virtual Status AsyncGet(const ActorID &actor_id,
const OptionalItemCallback<rpc::ActorTableData> &callback) = 0;
/// Register an actor to GCS asynchronously.
///
/// \param data_ptr The actor that will be registered to the GCS.
/// \param callback Callback that will be called after actor has been registered
/// to the GCS.
/// \return Status
virtual Status AsyncRegister(const std::shared_ptr<rpc::ActorTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Update dynamic states of actor in GCS asynchronously.
///
/// \param actor_id ID of the actor to update.
/// \param data_ptr Data of the actor to update.
/// \param callback Callback that will be called after update finishes.
/// \return Status
/// TODO(micafan) Don't expose the whole `ActorTableData` and only allow
/// updating dynamic states.
virtual Status AsyncUpdate(const ActorID &actor_id,
const std::shared_ptr<rpc::ActorTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Subscribe to any register or update operations of actors.
///
/// \param subscribe Callback that will be called each time when an actor is registered
/// or updated.
/// \param done Callback that will be called when subscription is complete and we
/// are ready to receive notification.
/// \return Status
virtual Status AsyncSubscribeAll(
const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe,
const StatusCallback &done) = 0;
/// Subscribe to any update operations of an actor.
///
/// \param actor_id The ID of actor to be subscribed to.
/// \param subscribe Callback that will be called each time when the actor is updated.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
virtual Status AsyncSubscribe(
const ActorID &actor_id,
const SubscribeCallback<ActorID, rpc::ActorTableData> &subscribe,
const StatusCallback &done) = 0;
/// Cancel subscription to an actor.
///
/// \param actor_id The ID of the actor to be unsubscribed to.
/// \param done Callback that will be called when unsubscribe is complete.
/// \return Status
virtual Status AsyncUnsubscribe(const ActorID &actor_id,
const StatusCallback &done) = 0;
protected:
ActorInfoAccessor() = default;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_ACTOR_INFO_ACCESSOR_H

View file

@ -6,8 +6,7 @@
#include <string>
#include <vector>
#include "ray/common/status.h"
#include "ray/gcs/actor_info_accessor.h"
#include "ray/gcs/job_info_accessor.h"
#include "ray/gcs/accessor.h"
#include "ray/util/logging.h"
namespace ray {
@ -75,6 +74,13 @@ class GcsClient : public std::enable_shared_from_this<GcsClient> {
return *job_accessor_;
}
/// Get the sub-interface for accessing task information in GCS.
/// This function is thread safe.
TaskInfoAccessor &Tasks() {
RAY_CHECK(task_accessor_ != nullptr);
return *task_accessor_;
}
protected:
/// Constructor of GcsClient.
///
@ -88,6 +94,7 @@ class GcsClient : public std::enable_shared_from_this<GcsClient> {
std::unique_ptr<ActorInfoAccessor> actor_accessor_;
std::unique_ptr<JobInfoAccessor> job_accessor_;
std::unique_ptr<TaskInfoAccessor> task_accessor_;
};
} // namespace gcs

View file

@ -1,54 +0,0 @@
#ifndef RAY_GCS_JOB_INFO_ACCESSOR_H
#define RAY_GCS_JOB_INFO_ACCESSOR_H
#include "ray/common/id.h"
#include "ray/gcs/callback.h"
#include "ray/protobuf/gcs.pb.h"
namespace ray {
namespace gcs {
/// \class JobInfoAccessor
/// `JobInfoAccessor` is a sub-interface of `GcsClient`.
/// This class includes all the methods that are related to accessing
/// job information in the GCS.
class JobInfoAccessor {
public:
virtual ~JobInfoAccessor() = default;
/// Add a job to GCS asynchronously.
///
/// \param data_ptr The job that will be add to GCS.
/// \param callback Callback that will be called after job has been added
/// to GCS.
/// \return Status
virtual Status AsyncAdd(const std::shared_ptr<rpc::JobTableData> &data_ptr,
const StatusCallback &callback) = 0;
/// Mark job as finished in GCS asynchronously.
///
/// \param job_id ID of the job that will be make finished to GCS.
/// \param callback Callback that will be called after update finished.
/// \return Status
virtual Status AsyncMarkFinished(const JobID &job_id,
const StatusCallback &callback) = 0;
/// Subscribe to finished jobs.
///
/// \param subscribe Callback that will be called each time when a job finishes.
/// \param done Callback that will be called when subscription is complete.
/// \return Status
virtual Status AsyncSubscribeToFinishedJobs(
const SubscribeCallback<JobID, rpc::JobTableData> &subscribe,
const StatusCallback &done) = 0;
protected:
JobInfoAccessor() = default;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_JOB_INFO_ACCESSOR_H

View file

@ -1,5 +1,6 @@
#include "ray/gcs/redis_actor_info_accessor.h"
#include "ray/gcs/redis_accessor.h"
#include <boost/none.hpp>
#include "ray/gcs/pb_util.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/util/logging.h"
@ -132,6 +133,100 @@ Status RedisActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id,
return actor_sub_executor_.AsyncUnsubscribe(node_id_, actor_id, done);
}
RedisJobInfoAccessor::RedisJobInfoAccessor(RedisGcsClient *client_impl)
: client_impl_(client_impl), job_sub_executor_(client_impl->job_table()) {}
Status RedisJobInfoAccessor::AsyncAdd(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback) {
return DoAsyncAppend(data_ptr, callback);
}
Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id,
const StatusCallback &callback) {
std::shared_ptr<JobTableData> data_ptr =
CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr),
/*node_manager_address*/ "", /*driver_pid*/ -1);
return DoAsyncAppend(data_ptr, callback);
}
Status RedisJobInfoAccessor::DoAsyncAppend(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback) {
JobTable::WriteCallback on_done = nullptr;
if (callback != nullptr) {
on_done = [callback](RedisGcsClient *client, const JobID &job_id,
const JobTableData &data) { callback(Status::OK()); };
}
JobID job_id = JobID::FromBinary(data_ptr->job_id());
return client_impl_->job_table().Append(job_id, job_id, data_ptr, on_done);
}
Status RedisJobInfoAccessor::AsyncSubscribeToFinishedJobs(
const SubscribeCallback<JobID, JobTableData> &subscribe, const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr);
auto on_subscribe = [subscribe](const JobID &job_id, const JobTableData &job_data) {
if (job_data.is_dead()) {
subscribe(job_id, job_data);
}
};
return job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), on_subscribe, done);
}
RedisTaskInfoAccessor::RedisTaskInfoAccessor(RedisGcsClient *client_impl)
: client_impl_(client_impl), task_sub_executor_(client_impl->raylet_task_table()) {}
Status RedisTaskInfoAccessor::AsyncAdd(const std::shared_ptr<TaskTableData> &data_ptr,
const StatusCallback &callback) {
raylet::TaskTable::WriteCallback on_done = nullptr;
if (callback != nullptr) {
on_done = [callback](RedisGcsClient *client, const TaskID &task_id,
const TaskTableData &data) { callback(Status::OK()); };
}
TaskID task_id = TaskID::FromBinary(data_ptr->task().task_spec().task_id());
raylet::TaskTable &task_table = client_impl_->raylet_task_table();
return task_table.Add(JobID::Nil(), task_id, data_ptr, on_done);
}
Status RedisTaskInfoAccessor::AsyncGet(
const TaskID &task_id, const OptionalItemCallback<TaskTableData> &callback) {
RAY_CHECK(callback != nullptr);
auto on_success = [callback](RedisGcsClient *client, const TaskID &task_id,
const TaskTableData &data) {
boost::optional<TaskTableData> result(data);
callback(Status::OK(), result);
};
auto on_failure = [callback](RedisGcsClient *client, const TaskID &task_id) {
boost::optional<TaskTableData> result;
callback(Status::Invalid("Task not exist."), result);
};
raylet::TaskTable &task_table = client_impl_->raylet_task_table();
return task_table.Lookup(JobID::Nil(), task_id, on_success, on_failure);
}
Status RedisTaskInfoAccessor::AsyncDelete(const std::vector<TaskID> &task_ids,
const StatusCallback &callback) {
raylet::TaskTable &task_table = client_impl_->raylet_task_table();
task_table.Delete(JobID::Nil(), task_ids);
// TODO(micafan) Always return OK here.
// Confirm if we need to handle the deletion failure and how to handle it.
return Status::OK();
}
Status RedisTaskInfoAccessor::AsyncSubscribe(
const TaskID &task_id, const SubscribeCallback<TaskID, TaskTableData> &subscribe,
const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr);
return task_sub_executor_.AsyncSubscribe(subscribe_id_, task_id, subscribe, done);
}
Status RedisTaskInfoAccessor::AsyncUnsubscribe(const TaskID &task_id,
const StatusCallback &done) {
return task_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, done);
}
} // namespace gcs
} // namespace ray

View file

@ -0,0 +1,139 @@
#ifndef RAY_GCS_REDIS_ACCESSOR_H
#define RAY_GCS_REDIS_ACCESSOR_H
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/gcs/accessor.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/subscription_executor.h"
#include "ray/gcs/tables.h"
namespace ray {
namespace gcs {
class RedisGcsClient;
std::shared_ptr<gcs::ActorTableData> CreateActorTableData(
const TaskSpecification &task_spec, const rpc::Address &address,
gcs::ActorTableData::ActorState state, uint64_t remaining_reconstructions);
/// \class RedisActorInfoAccessor
/// `RedisActorInfoAccessor` is an implementation of `ActorInfoAccessor`
/// that uses Redis as the backend storage.
class RedisActorInfoAccessor : public ActorInfoAccessor {
public:
explicit RedisActorInfoAccessor(RedisGcsClient *client_impl);
virtual ~RedisActorInfoAccessor() {}
Status AsyncGet(const ActorID &actor_id,
const OptionalItemCallback<ActorTableData> &callback) override;
Status AsyncRegister(const std::shared_ptr<ActorTableData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncUpdate(const ActorID &actor_id,
const std::shared_ptr<ActorTableData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncSubscribeAll(const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) override;
Status AsyncSubscribe(const ActorID &actor_id,
const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) override;
Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done) override;
private:
RedisGcsClient *client_impl_{nullptr};
// Use a random ClientID for actor subscription. Because:
// If we use ClientID::Nil, GCS will still send all actors' updates to this GCS Client.
// Even we can filter out irrelevant updates, but there will be extra overhead.
// And because the new GCS Client will no longer hold the local ClientID, so we use
// random ClientID instead.
// TODO(micafan): Remove this random id, once GCS becomes a service.
ClientID node_id_{ClientID::FromRandom()};
typedef SubscriptionExecutor<ActorID, ActorTableData, ActorTable>
ActorSubscriptionExecutor;
ActorSubscriptionExecutor actor_sub_executor_;
};
/// \class RedisJobInfoAccessor
/// RedisJobInfoAccessor is an implementation of `JobInfoAccessor`
/// that uses Redis as the backend storage.
class RedisJobInfoAccessor : public JobInfoAccessor {
public:
explicit RedisJobInfoAccessor(RedisGcsClient *client_impl);
virtual ~RedisJobInfoAccessor() {}
Status AsyncAdd(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override;
Status AsyncSubscribeToFinishedJobs(
const SubscribeCallback<JobID, JobTableData> &subscribe,
const StatusCallback &done) override;
private:
/// Append job information to GCS asynchronously.
///
/// \param data_ptr The job information that will be appended to GCS.
/// \param callback Callback that will be called after append done.
/// \return Status
Status DoAsyncAppend(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback);
RedisGcsClient *client_impl_{nullptr};
typedef SubscriptionExecutor<JobID, JobTableData, JobTable> JobSubscriptionExecutor;
JobSubscriptionExecutor job_sub_executor_;
};
/// \class RedisTaskInfoAccessor
/// `RedisTaskInfoAccessor` is an implementation of `TaskInfoAccessor`
/// that uses Redis as the backend storage.
class RedisTaskInfoAccessor : public TaskInfoAccessor {
public:
explicit RedisTaskInfoAccessor(RedisGcsClient *client_impl);
~RedisTaskInfoAccessor() {}
Status AsyncAdd(const std::shared_ptr<TaskTableData> &data_ptr,
const StatusCallback &callback);
Status AsyncGet(const TaskID &task_id,
const OptionalItemCallback<TaskTableData> &callback);
Status AsyncDelete(const std::vector<TaskID> &task_ids, const StatusCallback &callback);
Status AsyncSubscribe(const TaskID &task_id,
const SubscribeCallback<TaskID, TaskTableData> &subscribe,
const StatusCallback &done);
Status AsyncUnsubscribe(const TaskID &task_id, const StatusCallback &done);
private:
RedisGcsClient *client_impl_{nullptr};
// Use a random ClientID for task subscription. Because:
// If we use ClientID::Nil, GCS will still send all tasks' updates to this GCS Client.
// Even we can filter out irrelevant updates, but there will be extra overhead.
// And because the new GCS Client will no longer hold the local ClientID, so we use
// random ClientID instead.
// TODO(micafan): Remove this random id, once GCS becomes a service.
ClientID subscribe_id_{ClientID::FromRandom()};
typedef SubscriptionExecutor<TaskID, TaskTableData, raylet::TaskTable>
TaskSubscriptionExecutor;
TaskSubscriptionExecutor task_sub_executor_;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_REDIS_ACCESSOR_H

View file

@ -1,68 +0,0 @@
#ifndef RAY_GCS_REDIS_ACTOR_INFO_ACCESSOR_H
#define RAY_GCS_REDIS_ACTOR_INFO_ACCESSOR_H
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/gcs/actor_info_accessor.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/subscription_executor.h"
#include "ray/gcs/tables.h"
namespace ray {
namespace gcs {
std::shared_ptr<gcs::ActorTableData> CreateActorTableData(
const TaskSpecification &task_spec, const rpc::Address &address,
gcs::ActorTableData::ActorState state, uint64_t remaining_reconstructions);
class RedisGcsClient;
/// \class RedisActorInfoAccessor
/// `RedisActorInfoAccessor` is an implementation of `ActorInfoAccessor`
/// that uses Redis as the backend storage.
class RedisActorInfoAccessor : public ActorInfoAccessor {
public:
explicit RedisActorInfoAccessor(RedisGcsClient *client_impl);
virtual ~RedisActorInfoAccessor() {}
Status AsyncGet(const ActorID &actor_id,
const OptionalItemCallback<ActorTableData> &callback) override;
Status AsyncRegister(const std::shared_ptr<ActorTableData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncUpdate(const ActorID &actor_id,
const std::shared_ptr<ActorTableData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncSubscribeAll(const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) override;
Status AsyncSubscribe(const ActorID &actor_id,
const SubscribeCallback<ActorID, ActorTableData> &subscribe,
const StatusCallback &done) override;
Status AsyncUnsubscribe(const ActorID &actor_id, const StatusCallback &done) override;
private:
RedisGcsClient *client_impl_{nullptr};
// Use a random ClientID for actor subscription. Because:
// If we use ClientID::Nil, GCS will still send all actors' updates to this GCS Client.
// Even we can filter out irrelevant updates, but there will be extra overhead.
// And because the new GCS Client will no longer hold the local ClientID, so we use
// random ClientID instead.
// TODO(micafan): Remove this random id, once GCS becomes a service.
ClientID node_id_{ClientID::FromRandom()};
typedef SubscriptionExecutor<ActorID, ActorTableData, ActorTable>
ActorSubscriptionExecutor;
ActorSubscriptionExecutor actor_sub_executor_;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_REDIS_ACTOR_INFO_ACCESSOR_H

View file

@ -2,9 +2,8 @@
#include <unistd.h>
#include "ray/common/ray_config.h"
#include "ray/gcs/redis_actor_info_accessor.h"
#include "ray/gcs/redis_accessor.h"
#include "ray/gcs/redis_context.h"
#include "ray/gcs/redis_job_info_accessor.h"
static void GetRedisShards(redisContext *context, std::vector<std::string> &addresses,
std::vector<int> &ports) {
@ -73,13 +72,8 @@ namespace ray {
namespace gcs {
RedisGcsClient::RedisGcsClient(const GcsClientOptions &options) : GcsClient(options) {
#if RAY_USE_NEW_GCS
command_type_ = CommandType::kChain;
#else
command_type_ = CommandType::kRegular;
#endif
}
RedisGcsClient::RedisGcsClient(const GcsClientOptions &options)
: GcsClient(options), command_type_(CommandType::kRegular) {}
RedisGcsClient::RedisGcsClient(const GcsClientOptions &options, CommandType command_type)
: GcsClient(options), command_type_(command_type) {}
@ -150,6 +144,7 @@ Status RedisGcsClient::Connect(boost::asio::io_service &io_service) {
actor_accessor_.reset(new RedisActorInfoAccessor(this));
job_accessor_.reset(new RedisJobInfoAccessor(this));
task_accessor_.reset(new RedisTaskInfoAccessor(this));
is_connected_ = true;

View file

@ -18,12 +18,14 @@ namespace gcs {
class RedisContext;
class RAY_EXPORT RedisGcsClient : public GcsClient {
// TODO(micafan) Will remove those friend class after we replace RedisGcsClient
// TODO(micafan) Will remove those friend class / method after we replace RedisGcsClient
// with interface class GcsClient in raylet.
friend class RedisActorInfoAccessor;
friend class RedisJobInfoAccessor;
friend class RedisTaskInfoAccessor;
friend class SubscriptionExecutorTest;
friend class LogSubscribeTestHelper;
friend class TaskTableTestHelper;
public:
/// Constructor of RedisGcsClient.
@ -55,7 +57,6 @@ class RAY_EXPORT RedisGcsClient : public GcsClient {
// TODO: Some API for getting the error on the driver
ObjectTable &object_table();
raylet::TaskTable &raylet_task_table();
TaskReconstructionLog &task_reconstruction_log();
TaskLeaseTable &task_lease_table();
ClientTable &client_table();
@ -96,6 +97,8 @@ class RAY_EXPORT RedisGcsClient : public GcsClient {
ActorTable &actor_table();
/// This method will be deprecated, use method Jobs() instead.
JobTable &job_table();
/// This method will be deprecated, use method Tasks() instead.
raylet::TaskTable &raylet_task_table();
// GCS command type. If CommandType::kChain, chain-replicated versions of the tables
// might be used, if available.

View file

@ -1,50 +0,0 @@
#include "ray/gcs/redis_job_info_accessor.h"
#include "ray/gcs/pb_util.h"
#include "ray/gcs/redis_gcs_client.h"
namespace ray {
namespace gcs {
RedisJobInfoAccessor::RedisJobInfoAccessor(RedisGcsClient *client_impl)
: client_impl_(client_impl), job_sub_executor_(client_impl->job_table()) {}
Status RedisJobInfoAccessor::AsyncAdd(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback) {
return DoAsyncAppend(data_ptr, callback);
}
Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id,
const StatusCallback &callback) {
std::shared_ptr<JobTableData> data_ptr =
CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr),
/*node_manager_address*/ "", /*driver_pid*/ -1);
return DoAsyncAppend(data_ptr, callback);
}
Status RedisJobInfoAccessor::DoAsyncAppend(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback) {
JobTable::WriteCallback on_done = nullptr;
if (callback != nullptr) {
on_done = [callback](RedisGcsClient *client, const JobID &job_id,
const JobTableData &data) { callback(Status::OK()); };
}
JobID job_id = JobID::FromBinary(data_ptr->job_id());
return client_impl_->job_table().Append(job_id, job_id, data_ptr, on_done);
}
Status RedisJobInfoAccessor::AsyncSubscribeToFinishedJobs(
const SubscribeCallback<JobID, JobTableData> &subscribe, const StatusCallback &done) {
RAY_CHECK(subscribe != nullptr);
auto on_subscribe = [subscribe](const JobID &job_id, const JobTableData &job_data) {
if (job_data.is_dead()) {
subscribe(job_id, job_data);
}
};
return job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), on_subscribe, done);
}
} // namespace gcs
} // namespace ray

View file

@ -1,53 +0,0 @@
#ifndef RAY_GCS_REDIS_JOB_INFO_ACCESSOR_H
#define RAY_GCS_REDIS_JOB_INFO_ACCESSOR_H
#include "ray/common/id.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/job_info_accessor.h"
#include "ray/gcs/subscription_executor.h"
#include "ray/gcs/tables.h"
namespace ray {
namespace gcs {
class RedisGcsClient;
/// \class RedisJobInfoAccessor
/// RedisJobInfoAccessor is an implementation of `JobInfoAccessor`
/// that uses Redis as the backend storage.
class RedisJobInfoAccessor : public JobInfoAccessor {
public:
explicit RedisJobInfoAccessor(RedisGcsClient *client_impl);
virtual ~RedisJobInfoAccessor() {}
Status AsyncAdd(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback) override;
Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override;
Status AsyncSubscribeToFinishedJobs(
const SubscribeCallback<JobID, JobTableData> &subscribe,
const StatusCallback &done) override;
private:
/// Append job information to GCS asynchronously.
///
/// \param data_ptr The job information that will be appended to GCS.
/// \param callback Callback that will be called after append done.
/// \return Status
Status DoAsyncAppend(const std::shared_ptr<JobTableData> &data_ptr,
const StatusCallback &callback);
RedisGcsClient *client_impl_{nullptr};
typedef SubscriptionExecutor<JobID, JobTableData, JobTable> JobSubscriptionExecutor;
JobSubscriptionExecutor job_sub_executor_;
};
} // namespace gcs
} // namespace ray
#endif // RAY_GCS_REDIS_JOB_INFO_ACCESSOR_H

View file

@ -189,6 +189,7 @@ Status SubscriptionExecutor<ID, Data, Table>::AsyncUnsubscribe(
template class SubscriptionExecutor<ActorID, ActorTableData, ActorTable>;
template class SubscriptionExecutor<ActorID, ActorTableData, DirectActorTable>;
template class SubscriptionExecutor<JobID, JobTableData, JobTable>;
template class SubscriptionExecutor<TaskID, TaskTableData, raylet::TaskTable>;
} // namespace gcs

View file

@ -306,6 +306,13 @@ Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id
done);
}
template <typename ID, typename Data>
Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
const Callback &subscribe,
const SubscriptionCallback &done) {
return Subscribe(job_id, client_id, subscribe, /*failure*/ nullptr, done);
}
template <typename ID, typename Data>
std::string Table<ID, Data>::DebugString() const {
std::stringstream result;

View file

@ -311,6 +311,9 @@ class Table : private Log<ID, Data>,
using Log<ID, Data>::RequestNotifications;
using Log<ID, Data>::CancelNotifications;
/// Expose this interface for use by subscription tools class SubscriptionExecutor.
/// In this way TaskTable() can also reuse class SubscriptionExecutor.
using Log<ID, Data>::Subscribe;
/// Add an entry to the table. This overwrites any existing data at the key.
///
@ -356,6 +359,24 @@ class Table : private Log<ID, Data>,
const Callback &subscribe, const FailureCallback &failure,
const SubscriptionCallback &done);
/// Subscribe to any Add operations to this table. The caller may choose to
/// subscribe to all Adds, or to subscribe only to keys that it requests
/// notifications for. This may only be called once per Table instance.
///
/// \param job_id The ID of the job.
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each Add to the table will be received. Else, only
/// messages for the given client will be received. In the latter
/// case, the client may request notifications on specific keys in the
/// table via `RequestNotifications`.
/// \param subscribe Callback that is called on each received message. If the
/// callback is called with an empty vector, then there was no data at the key.
/// \param done Callback that is called when subscription is complete and we
/// are ready to receive messages.
/// \return Status
Status Subscribe(const JobID &job_id, const ClientID &client_id,
const Callback &subscribe, const SubscriptionCallback &done);
void Delete(const JobID &job_id, const ID &id) { Log<ID, Data>::Delete(job_id, id); }
void Delete(const JobID &job_id, const std::vector<ID> &ids) {

View file

@ -7,6 +7,7 @@
#include <thread>
#include <vector>
#include "gtest/gtest.h"
#include "ray/gcs/redis_accessor.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/util/test_util.h"

View file

@ -87,69 +87,278 @@ class TestGcsWithChainAsio : public TestGcsWithAsio {
TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){};
};
/// A helper function that creates a GCS `TaskTableData` object.
std::shared_ptr<TaskTableData> CreateTaskTableData(const TaskID &task_id,
uint64_t num_returns = 0) {
auto data = std::make_shared<TaskTableData>();
data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns);
return data;
}
class TaskTableTestHelper {
public:
/// A helper function that creates a GCS `TaskTableData` object.
static std::shared_ptr<TaskTableData> CreateTaskTableData(const TaskID &task_id,
uint64_t num_returns = 0) {
auto data = std::make_shared<TaskTableData>();
data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns);
return data;
}
/// A helper function that compare whether 2 `TaskTableData` objects are equal.
/// Note, this function only compares fields set by `CreateTaskTableData`.
bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) {
const auto &spec1 = data1.task().task_spec();
const auto &spec2 = data2.task().task_spec();
return (spec1.task_id() == spec2.task_id() &&
spec1.num_returns() == spec2.num_returns());
}
/// A helper function that compare whether 2 `TaskTableData` objects are equal.
/// Note, this function only compares fields set by `CreateTaskTableData`.
static bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) {
const auto &spec1 = data1.task().task_spec();
const auto &spec2 = data2.task().task_spec();
return (spec1.task_id() == spec2.task_id() &&
spec1.num_returns() == spec2.num_returns());
}
void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> client) {
const auto task_id = RandomTaskId();
const auto data = CreateTaskTableData(task_id);
static void TestTableLookup(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
const auto task_id = RandomTaskId();
const auto data = CreateTaskTableData(task_id);
// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(TaskTableDataEqual(*data, d));
};
// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(TaskTableDataEqual(*data, d));
};
// Check that the lookup returns the added task.
auto lookup_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->Stop();
};
// Check that the lookup returns the added task.
auto lookup_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->Stop();
};
// Check that the lookup does not return an empty entry.
auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) {
RAY_CHECK(false);
};
// Check that the lookup does not return an empty entry.
auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) {
RAY_CHECK(false);
};
// Add the task, then do a lookup.
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback,
failure_callback));
// Run the event loop. The loop will only stop if the Lookup callback is
// called (or an assertion failure).
test->Start();
}
// Add the task, then do a lookup.
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback,
failure_callback));
// Run the event loop. The loop will only stop if the Lookup callback is
// called (or an assertion failure).
test->Start();
}
static void TestTableLookupFailure(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
TaskID task_id = RandomTaskId();
// Check that the lookup does not return data.
auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) { RAY_CHECK(false); };
// Check that the lookup returns an empty entry.
auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id) {
ASSERT_EQ(id, task_id);
test->Stop();
};
// Lookup the task. We have not done any writes, so the key should be empty.
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback,
failure_callback));
// Run the event loop. The loop will only stop if the failure callback is
// called (or an assertion failure).
test->Start();
}
static void TestDeleteKeysFromTable(
const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> client,
std::vector<std::shared_ptr<TaskTableData>> &data_vector, bool stop_at_end) {
std::vector<TaskID> ids;
TaskID task_id;
for (auto &data : data_vector) {
task_id = RandomTaskId();
ids.push_back(task_id);
// Check that we added the correct object entries.
auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
}
for (const auto &task_id : ids) {
auto task_lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) {
ASSERT_EQ(id, task_id);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id,
task_lookup_callback, nullptr));
}
if (ids.size() == 1) {
client->raylet_task_table().Delete(job_id, ids[0]);
} else {
client->raylet_task_table().Delete(job_id, ids);
}
auto expected_failure_callback = [](RedisGcsClient *client, const TaskID &id) {
ASSERT_TRUE(true);
test->IncrementNumCallbacks();
};
auto undesired_callback = [](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) { ASSERT_TRUE(false); };
for (size_t i = 0; i < ids.size(); ++i) {
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback,
expected_failure_callback));
}
if (stop_at_end) {
auto stop_callback = [](RedisGcsClient *client, const TaskID &id) { test->Stop(); };
RAY_CHECK_OK(
client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback));
}
}
static void TestTableSubscribeId(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
size_t num_modifications = 3;
// Add a table entry.
TaskID task_id1 = RandomTaskId();
// Add a table entry at a second key.
TaskID task_id2 = RandomTaskId();
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [task_id2, num_modifications](
gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, task_id2);
// Check that we get notifications in the same order as the writes.
ASSERT_TRUE(
TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks())));
test->IncrementNumCallbacks();
if (test->NumCallbacks() == num_modifications) {
test->Stop();
}
};
// The failure callback should be called once since both keys start as empty.
bool failure_notification_received = false;
auto failure_callback = [task_id2, &failure_notification_received](
gcs::RedisGcsClient *client, const TaskID &id) {
ASSERT_EQ(id, task_id2);
// The failure notification should be the first notification received.
ASSERT_EQ(test->NumCallbacks(), 0);
failure_notification_received = true;
};
// The callback for subscription success. Once we've subscribed, request
// notifications for only one of the keys, then write to both keys.
auto subscribe_callback = [job_id, task_id1, task_id2,
num_modifications](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
for (uint64_t i = 0; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id1, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr));
}
for (uint64_t i = 0; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id2, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr));
}
};
// Subscribe to notifications for this client. This allows us to request and
// receive notifications for specific keys.
RAY_CHECK_OK(client->raylet_task_table().Subscribe(
job_id, client->client_table().GetLocalClientId(), notification_callback,
failure_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called for the requested key.
test->Start();
// Check that the failure callback was called since the key was initially
// empty.
ASSERT_TRUE(failure_notification_received);
// Check that we received one notification callback for each write to the
// requested key.
ASSERT_EQ(test->NumCallbacks(), num_modifications);
}
static void TestTableSubscribeCancel(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
// Add a table entry.
const auto task_id = RandomTaskId();
const int num_modifications = 3;
const auto data = CreateTaskTableData(task_id, 0);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
// The failure callback should not be called since all keys are non-empty
// when notifications are requested.
auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) {
RAY_CHECK(false);
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) {
ASSERT_EQ(id, task_id);
// Check that we only get notifications for the first and last writes,
// since notifications are canceled in between.
if (test->NumCallbacks() == 0) {
ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0)));
} else {
ASSERT_TRUE(TaskTableDataEqual(
data, *CreateTaskTableData(task_id, num_modifications - 1)));
}
test->IncrementNumCallbacks();
if (test->NumCallbacks() == num_modifications - 1) {
test->Stop();
}
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto subscribe_callback = [job_id, task_id](gcs::RedisGcsClient *client) {
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->raylet_task_table().CancelNotifications(
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
// Write to the key. Since we canceled notifications, we should not receive
// a notification for these writes.
for (uint64_t i = 1; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
// current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
};
// Subscribe to notifications for this client. This allows us to request and
// receive notifications for specific keys.
RAY_CHECK_OK(client->raylet_task_table().Subscribe(
job_id, client->client_table().GetLocalClientId(), notification_callback,
failure_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called for the requested key.
test->Start();
// Check that we received a notification callback for the first and least
// writes to the key, since notifications are canceled in between.
ASSERT_EQ(test->NumCallbacks(), 2);
}
};
// Convenient macro to test across {ae, asio} x {regular, chain} x {the tests}.
// Undefined at the end.
#define TEST_MACRO(FIXTURE, TEST) \
TEST_F(FIXTURE, TEST) { \
test = this; \
TEST(job_id_, client_); \
#define TEST_TASK_TABLE_MACRO(FIXTURE, TEST) \
TEST_F(FIXTURE, TEST) { \
test = this; \
TaskTableTestHelper::TEST(job_id_, client_); \
}
TEST_MACRO(TestGcsWithAsio, TestTableLookup);
#if RAY_USE_NEW_GCS
TEST_MACRO(TestGcsWithChainAsio, TestTableLookup);
#endif
TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableLookup);
void TestLogLookup(const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> client) {
// Append some entries to the log at an object ID.
@ -196,32 +405,7 @@ TEST_F(TestGcsWithAsio, TestLogLookup) {
TestLogLookup(job_id_, client_);
}
void TestTableLookupFailure(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
TaskID task_id = RandomTaskId();
// Check that the lookup does not return data.
auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) { RAY_CHECK(false); };
// Check that the lookup returns an empty entry.
auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id) {
ASSERT_EQ(id, task_id);
test->Stop();
};
// Lookup the task. We have not done any writes, so the key should be empty.
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback,
failure_callback));
// Run the event loop. The loop will only stop if the failure callback is
// called (or an assertion failure).
test->Start();
}
TEST_MACRO(TestGcsWithAsio, TestTableLookupFailure);
#if RAY_USE_NEW_GCS
TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure);
#endif
TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableLookupFailure);
void TestLogAppendAt(const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> client) {
TaskID task_id = RandomTaskId();
@ -395,55 +579,6 @@ void TestDeleteKeysFromLog(
}
}
void TestDeleteKeysFromTable(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client,
std::vector<std::shared_ptr<TaskTableData>> &data_vector,
bool stop_at_end) {
std::vector<TaskID> ids;
TaskID task_id;
for (auto &data : data_vector) {
task_id = RandomTaskId();
ids.push_back(task_id);
// Check that we added the correct object entries.
auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &d) {
ASSERT_EQ(id, task_id);
ASSERT_TRUE(TaskTableDataEqual(*data, d));
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback));
}
for (const auto &task_id : ids) {
auto task_lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) {
ASSERT_EQ(id, task_id);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, task_lookup_callback,
nullptr));
}
if (ids.size() == 1) {
client->raylet_task_table().Delete(job_id, ids[0]);
} else {
client->raylet_task_table().Delete(job_id, ids);
}
auto expected_failure_callback = [](RedisGcsClient *client, const TaskID &id) {
ASSERT_TRUE(true);
test->IncrementNumCallbacks();
};
auto undesired_callback = [](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) { ASSERT_TRUE(false); };
for (size_t i = 0; i < ids.size(); ++i) {
RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback,
expected_failure_callback));
}
if (stop_at_end) {
auto stop_callback = [](RedisGcsClient *client, const TaskID &id) { test->Stop(); };
RAY_CHECK_OK(
client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback));
}
}
void TestDeleteKeysFromSet(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client,
std::vector<std::shared_ptr<ObjectTableData>> &data_vector) {
@ -523,21 +658,21 @@ void TestDeleteKeys(const JobID &job_id, std::shared_ptr<gcs::RedisGcsClient> cl
std::vector<std::shared_ptr<TaskTableData>> task_vector;
auto AppendTaskData = [&task_vector](size_t add_count) {
for (size_t i = 0; i < add_count; ++i) {
task_vector.push_back(CreateTaskTableData(RandomTaskId()));
task_vector.push_back(TaskTableTestHelper::CreateTaskTableData(RandomTaskId()));
}
};
AppendTaskData(1);
ASSERT_EQ(task_vector.size(), 1);
TestDeleteKeysFromTable(job_id, client, task_vector, false);
TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, false);
AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2);
ASSERT_GT(task_vector.size(), 1);
ASSERT_LT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size());
TestDeleteKeysFromTable(job_id, client, task_vector, false);
TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, false);
AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2);
ASSERT_GT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size());
TestDeleteKeysFromTable(job_id, client, task_vector, true);
TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, true);
test->Start();
ASSERT_GT(test->NumCallbacks(),
@ -841,81 +976,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) {
TestSetSubscribeAll(job_id_, client_);
}
void TestTableSubscribeId(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
size_t num_modifications = 3;
// Add a table entry.
TaskID task_id1 = RandomTaskId();
// Add a table entry at a second key.
TaskID task_id2 = RandomTaskId();
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [task_id2, num_modifications](gcs::RedisGcsClient *client,
const TaskID &id,
const TaskTableData &data) {
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, task_id2);
// Check that we get notifications in the same order as the writes.
ASSERT_TRUE(
TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks())));
test->IncrementNumCallbacks();
if (test->NumCallbacks() == num_modifications) {
test->Stop();
}
};
// The failure callback should be called once since both keys start as empty.
bool failure_notification_received = false;
auto failure_callback = [task_id2, &failure_notification_received](
gcs::RedisGcsClient *client, const TaskID &id) {
ASSERT_EQ(id, task_id2);
// The failure notification should be the first notification received.
ASSERT_EQ(test->NumCallbacks(), 0);
failure_notification_received = true;
};
// The callback for subscription success. Once we've subscribed, request
// notifications for only one of the keys, then write to both keys.
auto subscribe_callback = [job_id, task_id1, task_id2,
num_modifications](gcs::RedisGcsClient *client) {
// Request notifications for one of the keys.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id2, client->client_table().GetLocalClientId(), nullptr));
// Write both keys. We should only receive notifications for the key that
// we requested them for.
for (uint64_t i = 0; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id1, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr));
}
for (uint64_t i = 0; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id2, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr));
}
};
// Subscribe to notifications for this client. This allows us to request and
// receive notifications for specific keys.
RAY_CHECK_OK(client->raylet_task_table().Subscribe(
job_id, client->client_table().GetLocalClientId(), notification_callback,
failure_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called for the requested key.
test->Start();
// Check that the failure callback was called since the key was initially
// empty.
ASSERT_TRUE(failure_notification_received);
// Check that we received one notification callback for each write to the
// requested key.
ASSERT_EQ(test->NumCallbacks(), num_modifications);
}
TEST_MACRO(TestGcsWithAsio, TestTableSubscribeId);
#if RAY_USE_NEW_GCS
TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId);
#endif
TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableSubscribeId);
TEST_F(TestGcsWithAsio, TestLogSubscribeId) {
test = this;
@ -998,77 +1059,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) {
TestSetSubscribeId(job_id_, client_);
}
void TestTableSubscribeCancel(const JobID &job_id,
std::shared_ptr<gcs::RedisGcsClient> client) {
// Add a table entry.
const auto task_id = RandomTaskId();
const int num_modifications = 3;
const auto data = CreateTaskTableData(task_id, 0);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
// The failure callback should not be called since all keys are non-empty
// when notifications are requested.
auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) {
RAY_CHECK(false);
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) {
ASSERT_EQ(id, task_id);
// Check that we only get notifications for the first and last writes,
// since notifications are canceled in between.
if (test->NumCallbacks() == 0) {
ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0)));
} else {
ASSERT_TRUE(
TaskTableDataEqual(data, *CreateTaskTableData(task_id, num_modifications - 1)));
}
test->IncrementNumCallbacks();
if (test->NumCallbacks() == num_modifications - 1) {
test->Stop();
}
};
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto subscribe_callback = [job_id, task_id](gcs::RedisGcsClient *client) {
// Request notifications, then cancel immediately. We should receive a
// notification for the current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
RAY_CHECK_OK(client->raylet_task_table().CancelNotifications(
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
// Write to the key. Since we canceled notifications, we should not receive
// a notification for these writes.
for (uint64_t i = 1; i < num_modifications; i++) {
auto data = CreateTaskTableData(task_id, i);
RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr));
}
// Request notifications again. We should receive a notification for the
// current value at the key.
RAY_CHECK_OK(client->raylet_task_table().RequestNotifications(
job_id, task_id, client->client_table().GetLocalClientId(), nullptr));
};
// Subscribe to notifications for this client. This allows us to request and
// receive notifications for specific keys.
RAY_CHECK_OK(client->raylet_task_table().Subscribe(
job_id, client->client_table().GetLocalClientId(), notification_callback,
failure_callback, subscribe_callback));
// Run the event loop. The loop will only stop if the registered subscription
// callback is called for the requested key.
test->Start();
// Check that we received a notification callback for the first and least
// writes to the key, since notifications are canceled in between.
ASSERT_EQ(test->NumCallbacks(), 2);
}
TEST_MACRO(TestGcsWithAsio, TestTableSubscribeCancel);
#if RAY_USE_NEW_GCS
TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel);
#endif
TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableSubscribeCancel);
TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) {
test = this;
@ -1431,7 +1422,7 @@ TEST_F(TestGcsWithAsio, TestHashTable) {
TestHashTable(job_id_, client_);
}
#undef TEST_MACRO
#undef TEST_TASK_TABLE_MACRO
} // namespace gcs
} // namespace ray

View file

@ -1,4 +1,3 @@
#include "ray/gcs/redis_job_info_accessor.h"
#include <memory>
#include "gtest/gtest.h"
#include "ray/gcs/pb_util.h"

View file

@ -1,7 +1,7 @@
#include "lineage_cache.h"
#include "ray/stats/stats.h"
#include <sstream>
#include "ray/gcs/redis_gcs_client.h"
#include "ray/stats/stats.h"
namespace ray {
@ -152,16 +152,15 @@ const std::unordered_set<TaskID> &Lineage::GetChildren(const TaskID &task_id) co
}
}
LineageCache::LineageCache(const ClientID &client_id,
gcs::TableInterface<TaskID, TaskTableData> &task_storage,
gcs::PubsubInterface<TaskID> &task_pubsub,
LineageCache::LineageCache(std::shared_ptr<gcs::RedisGcsClient> gcs_client,
uint64_t max_lineage_size)
: client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {}
: gcs_client_(gcs_client) {}
/// A helper function to add some uncommitted lineage to the local cache.
void LineageCache::AddUncommittedLineage(const TaskID &task_id,
const Lineage &uncommitted_lineage) {
RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " << client_id_;
RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on "
<< gcs_client_->client_table().GetLocalClientId();
// If the entry is not found in the lineage to merge, then we stop since
// there is nothing to copy into the merged lineage.
auto entry = uncommitted_lineage.GetEntry(task_id);
@ -192,7 +191,8 @@ bool LineageCache::CommitTask(const Task &task) {
return true;
}
const TaskID task_id = task.GetTaskSpecification().TaskId();
RAY_LOG(DEBUG) << "Committing task " << task_id << " on " << client_id_;
RAY_LOG(DEBUG) << "Committing task " << task_id << " on "
<< gcs_client_->client_table().GetLocalClientId();
if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED) ||
lineage_.GetEntry(task_id)->GetStatus() == GcsStatus::UNCOMMITTED) {
@ -275,17 +275,17 @@ void LineageCache::FlushTask(const TaskID &task_id) {
RAY_CHECK(entry);
RAY_CHECK(entry->GetStatus() < GcsStatus::COMMITTING);
gcs::raylet::TaskTable::WriteCallback task_callback =
[this](ray::gcs::RedisGcsClient *client, const TaskID &id,
const TaskTableData &data) { HandleEntryCommitted(id); };
auto task_callback = [this, task_id](Status status) {
RAY_CHECK(status.ok());
HandleEntryCommitted(task_id);
};
auto task = lineage_.GetEntry(task_id);
auto task_data = std::make_shared<TaskTableData>();
task_data->mutable_task()->mutable_task_spec()->CopyFrom(
task->TaskData().GetTaskSpecification().GetMessage());
task_data->mutable_task()->mutable_task_execution_spec()->CopyFrom(
task->TaskData().GetTaskExecutionSpec().GetMessage());
RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()),
task_id, task_data, task_callback));
RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(task_data, task_callback));
// We successfully wrote the task, so mark it as committing.
// TODO(swang): Use a batched interface and write with all object entries.
@ -296,10 +296,12 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) {
auto inserted = subscribed_tasks_.insert(task_id);
bool unsubscribed = inserted.second;
if (unsubscribed) {
// Request notifications for the task if we haven't already requested
// notifications for it.
RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_,
/*done*/ nullptr));
auto subscribe = [this](const TaskID &task_id, const TaskTableData) {
HandleEntryCommitted(task_id);
};
// Subscribe to the task.
RAY_CHECK_OK(gcs_client_->Tasks().AsyncSubscribe(task_id, subscribe,
/*done*/ nullptr));
}
// Return whether we were previously unsubscribed to this task and are now
// subscribed.
@ -310,10 +312,8 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) {
auto it = subscribed_tasks_.find(task_id);
bool subscribed = (it != subscribed_tasks_.end());
if (subscribed) {
// Cancel notifications for the task if we previously requested
// notifications for it.
RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_,
/*done*/ nullptr));
// Cancel subscribe to the task.
RAY_CHECK_OK(gcs_client_->Tasks().AsyncUnsubscribe(task_id, /*done*/ nullptr));
subscribed_tasks_.erase(it);
}
// Return whether we were previously subscribed to this task and are now
@ -339,7 +339,8 @@ void LineageCache::EvictTask(const TaskID &task_id) {
}
// Evict the task.
RAY_LOG(DEBUG) << "Evicting task " << task_id << " on " << client_id_;
RAY_LOG(DEBUG) << "Evicting task " << task_id << " on "
<< gcs_client_->client_table().GetLocalClientId();
lineage_.PopEntry(task_id);
// Try to evict the children of the evict task. These are the tasks that have
// a dependency on the evicted task.

View file

@ -12,7 +12,7 @@
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/common/task/task.h"
#include "ray/gcs/tables.h"
#include "ray/gcs/redis_gcs_client.h"
namespace ray {
@ -209,9 +209,8 @@ class LineageCache {
public:
/// Create a lineage cache for the given task storage system.
/// TODO(swang): Pass in the policy (interface?).
LineageCache(const ClientID &client_id,
gcs::TableInterface<TaskID, TaskTableData> &task_storage,
gcs::PubsubInterface<TaskID> &task_pubsub, uint64_t max_lineage_size);
LineageCache(std::shared_ptr<gcs::RedisGcsClient> gcs_client,
uint64_t max_lineage_size);
/// Asynchronously commit a task to the GCS.
///
@ -303,19 +302,13 @@ class LineageCache {
/// was successful (whether we were subscribed).
bool UnsubscribeTask(const TaskID &task_id);
/// The client ID, used to request notifications for specific tasks.
/// TODO(swang): Move the ClientID into the generic Table implementation.
ClientID client_id_;
/// The durable storage system for task information.
gcs::TableInterface<TaskID, TaskTableData> &task_storage_;
/// The pubsub storage system for task information. This can be used to
/// request notifications for the commit of a task entry.
gcs::PubsubInterface<TaskID> &task_pubsub_;
/// A client connection to the GCS.
std::shared_ptr<gcs::RedisGcsClient> gcs_client_;
/// All tasks and objects that we are responsible for writing back to the
/// GCS, and the tasks and objects in their lineage.
Lineage lineage_;
/// The tasks that we've subscribed to notifications for from the pubsub
/// storage system. We will receive a notification for these tasks on commit.
/// The tasks that we've subscribed to.
/// We will receive a notification for these tasks on commit.
std::unordered_set<TaskID> subscribed_tasks_;
};

View file

@ -1,4 +1,5 @@
#include <list>
#include <memory>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@ -9,6 +10,8 @@
#include "ray/common/task/task_util.h"
#include "ray/gcs/callback.h"
#include "ray/gcs/redis_accessor.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/lineage_cache.h"
@ -23,66 +26,68 @@ const static JobID kDefaultJobId = JobID::FromInt(1);
const static TaskID kDefaultDriverTaskId = TaskID::ForDriverTask(kDefaultJobId);
class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
public gcs::PubsubInterface<TaskID> {
public:
MockGcs() {}
class MockGcsClient;
void Subscribe(const gcs::raylet::TaskTable::WriteCallback &notification_callback) {
class MockTaskInfoAccessor : public gcs::RedisTaskInfoAccessor {
public:
MockTaskInfoAccessor(gcs::RedisGcsClient *gcs_client)
: RedisTaskInfoAccessor(gcs_client) {}
virtual ~MockTaskInfoAccessor() {}
void RegisterSubscribeCallback(
const gcs::SubscribeCallback<TaskID, rpc::TaskTableData> &notification_callback) {
notification_callback_ = notification_callback;
}
Status Add(const JobID &job_id, const TaskID &task_id,
const std::shared_ptr<TaskTableData> &task_data,
const gcs::TableInterface<TaskID, TaskTableData>::WriteCallback &done) {
Status AsyncAdd(const std::shared_ptr<TaskTableData> &task_data,
const gcs::StatusCallback &done) {
TaskID task_id = TaskID::FromBinary(task_data->task().task_spec().task_id());
task_table_[task_id] = task_data;
auto callback = done;
// If we requested notifications for this task ID, send the notification as
// part of the callback.
if (subscribed_tasks_.count(task_id) == 1) {
callback = [this, done](ray::gcs::RedisGcsClient *client, const TaskID &task_id,
const TaskTableData &data) {
done(client, task_id, data);
callback = [this, done, task_id, task_data](Status status) {
done(status);
// If we're subscribed to the task to be added, also send a
// subscription notification.
notification_callback_(client, task_id, data);
notification_callback_(task_id, *task_data);
};
}
callbacks_.push_back(
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(callback, task_id));
callbacks_.push_back({callback, task_id});
num_task_adds_++;
return ray::Status::OK();
}
Status RemoteAdd(const TaskID &task_id, std::shared_ptr<TaskTableData> task_data) {
Status RemoteAdd(std::shared_ptr<TaskTableData> task_data) {
TaskID task_id = TaskID::FromBinary(task_data->task().task_spec().task_id());
task_table_[task_id] = task_data;
// Send a notification after the add if the lineage cache requested
// notifications for this key.
bool send_notification = (subscribed_tasks_.count(task_id) == 1);
auto callback = [this, send_notification](ray::gcs::RedisGcsClient *client,
const TaskID &task_id,
const TaskTableData &data) {
auto callback = [this, send_notification, task_id, task_data](Status status) {
if (send_notification) {
notification_callback_(client, task_id, data);
notification_callback_(task_id, *task_data);
}
};
return Add(JobID::Nil(), task_id, task_data, callback);
return AsyncAdd(task_data, callback);
}
Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id,
const gcs::StatusCallback &done) {
Status AsyncSubscribe(
const TaskID &task_id,
const gcs::SubscribeCallback<TaskID, rpc::TaskTableData> &notification_callback,
const gcs::StatusCallback &done) {
subscribed_tasks_.insert(task_id);
if (task_table_.count(task_id) == 1) {
callbacks_.push_back({notification_callback_, task_id});
notification_callbacks_.push_back({notification_callback_, task_id});
}
num_requested_notifications_ += 1;
return ray::Status::OK();
}
Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
const ClientID &client_id, const gcs::StatusCallback &done) {
Status AsyncUnsubscribe(const TaskID &task_id, const gcs::StatusCallback &done) {
subscribed_tasks_.erase(task_id);
return ray::Status::OK();
}
@ -91,7 +96,10 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
auto callbacks = std::move(callbacks_);
callbacks_.clear();
for (const auto &callback : callbacks) {
callback.first(NULL, callback.second, *task_table_[callback.second]);
callback.first(Status::OK());
}
for (const auto &callback : notification_callbacks_) {
callback.first(callback.second, *task_table_[callback.second]);
}
}
@ -107,32 +115,59 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
private:
std::unordered_map<TaskID, std::shared_ptr<TaskTableData>> task_table_;
std::vector<std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>> callbacks_;
gcs::raylet::TaskTable::WriteCallback notification_callback_;
std::vector<std::pair<gcs::StatusCallback, TaskID>> callbacks_;
typedef gcs::SubscribeCallback<TaskID, rpc::TaskTableData> TaskSubscribeCallback;
TaskSubscribeCallback notification_callback_;
std::vector<std::pair<TaskSubscribeCallback, TaskID>> notification_callbacks_;
std::unordered_set<TaskID> subscribed_tasks_;
int num_requested_notifications_ = 0;
int num_task_adds_ = 0;
};
class MockGcsClient : public gcs::RedisGcsClient {
public:
MockGcsClient(const gcs::GcsClientOptions &options) : RedisGcsClient(options) {
client_table_fake_.reset(
new gcs::ClientTable({nullptr}, this, ClientID::FromRandom()));
task_table_fake_.reset(new gcs::raylet::TaskTable({nullptr}, this));
task_accessor_.reset(new MockTaskInfoAccessor(this));
}
gcs::ClientTable &client_table() { return *client_table_fake_; }
gcs::raylet::TaskTable &raylet_task_table() { return *task_table_fake_; }
MockTaskInfoAccessor &MockTasks() {
return *dynamic_cast<MockTaskInfoAccessor *>(task_accessor_.get());
}
private:
std::unique_ptr<gcs::ClientTable> client_table_fake_;
std::unique_ptr<gcs::raylet::TaskTable> task_table_fake_;
};
class LineageCacheTest : public ::testing::Test {
public:
LineageCacheTest()
: max_lineage_size_(10),
num_notifications_(0),
mock_gcs_(),
lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) {
mock_gcs_.Subscribe([this](ray::gcs::RedisGcsClient *client, const TaskID &task_id,
const TaskTableData &data) {
lineage_cache_.HandleEntryCommitted(task_id);
num_notifications_++;
});
LineageCacheTest() : max_lineage_size_(10), num_notifications_(0) {
gcs::GcsClientOptions options("10.10.10.10", 12100, "");
mock_gcs_ = std::make_shared<MockGcsClient>(options);
lineage_cache_.reset(new LineageCache(mock_gcs_, max_lineage_size_));
mock_gcs_->MockTasks().RegisterSubscribeCallback(
[this](const TaskID &task_id, const TaskTableData &data) {
lineage_cache_->HandleEntryCommitted(task_id);
num_notifications_++;
});
}
protected:
uint64_t max_lineage_size_;
uint64_t num_notifications_;
MockGcs mock_gcs_;
LineageCache lineage_cache_;
std::shared_ptr<MockGcsClient> mock_gcs_;
std::unique_ptr<LineageCache> lineage_cache_;
};
static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
@ -179,7 +214,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
// Insert two independent chains of tasks.
std::vector<Task> tasks1;
auto return_values1 =
InsertTaskChain(lineage_cache_, tasks1, 3, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks1, 3, std::vector<ObjectID>(), 1);
std::vector<TaskID> task_ids1;
for (const auto &task : tasks1) {
task_ids1.push_back(task.GetTaskSpecification().TaskId());
@ -187,7 +222,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
std::vector<Task> tasks2;
auto return_values2 =
InsertTaskChain(lineage_cache_, tasks2, 2, std::vector<ObjectID>(), 2);
InsertTaskChain(*lineage_cache_, tasks2, 2, std::vector<ObjectID>(), 2);
std::vector<TaskID> task_ids2;
for (const auto &task : tasks2) {
task_ids2.push_back(task.GetTaskSpecification().TaskId());
@ -195,7 +230,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
// Get the uncommitted lineage for the last task (the leaf) of one of the chains.
auto uncommitted_lineage =
lineage_cache_.GetUncommittedLineage(task_ids1.back(), ClientID::Nil());
lineage_cache_->GetUncommittedLineage(task_ids1.back(), ClientID::Nil());
// Check that the uncommitted lineage is exactly equal to the first chain of tasks.
ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size());
for (auto &task_id : task_ids1) {
@ -208,7 +243,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
std::vector<ObjectID> combined_arguments = return_values1;
combined_arguments.insert(combined_arguments.end(), return_values2.begin(),
return_values2.end());
InsertTaskChain(lineage_cache_, combined_tasks, 1, combined_arguments, 1);
InsertTaskChain(*lineage_cache_, combined_tasks, 1, combined_arguments, 1);
std::vector<TaskID> combined_task_ids;
for (const auto &task : combined_tasks) {
combined_task_ids.push_back(task.GetTaskSpecification().TaskId());
@ -216,7 +251,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
// Get the uncommitted lineage for the inserted task.
uncommitted_lineage =
lineage_cache_.GetUncommittedLineage(combined_task_ids.back(), ClientID::Nil());
lineage_cache_->GetUncommittedLineage(combined_task_ids.back(), ClientID::Nil());
// Check that the uncommitted lineage is exactly equal to the entire set of
// tasks inserted so far.
ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size());
@ -229,13 +264,13 @@ TEST_F(LineageCacheTest, TestDuplicateUncommittedLineage) {
// Insert a chain of tasks.
std::vector<Task> tasks;
auto return_values =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
std::vector<TaskID> task_ids;
for (const auto &task : tasks) {
task_ids.push_back(task.GetTaskSpecification().TaskId());
}
// Check that we subscribed to each of the uncommitted tasks.
ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size());
ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size());
// Check that if we add the same tasks as UNCOMMITTED again, we do not issue
// duplicate subscribe requests.
@ -243,21 +278,21 @@ TEST_F(LineageCacheTest, TestDuplicateUncommittedLineage) {
for (const auto &task : tasks) {
duplicate_lineage.SetEntry(task, GcsStatus::UNCOMMITTED);
}
lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage);
ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size());
lineage_cache_->AddUncommittedLineage(task_ids.back(), duplicate_lineage);
ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size());
// Check that if we commit one of the tasks, we still do not issue any
// duplicate subscribe requests.
lineage_cache_.CommitTask(tasks.front());
lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage);
ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size());
lineage_cache_->CommitTask(tasks.front());
lineage_cache_->AddUncommittedLineage(task_ids.back(), duplicate_lineage);
ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size());
}
TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) {
// Insert chain of tasks.
std::vector<Task> tasks;
auto return_values =
InsertTaskChain(lineage_cache_, tasks, 4, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, 4, std::vector<ObjectID>(), 1);
std::vector<TaskID> task_ids;
for (const auto &task : tasks) {
task_ids.push_back(task.GetTaskSpecification().TaskId());
@ -267,12 +302,12 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) {
auto node_id2 = ClientID::FromRandom();
auto forwarded_task_id = task_ids[task_ids.size() - 2];
auto remaining_task_id = task_ids[task_ids.size() - 1];
lineage_cache_.MarkTaskAsForwarded(forwarded_task_id, node_id);
lineage_cache_->MarkTaskAsForwarded(forwarded_task_id, node_id);
auto uncommitted_lineage =
lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id);
lineage_cache_->GetUncommittedLineage(remaining_task_id, node_id);
auto uncommitted_lineage_all =
lineage_cache_.GetUncommittedLineage(remaining_task_id, node_id2);
lineage_cache_->GetUncommittedLineage(remaining_task_id, node_id2);
ASSERT_EQ(1, uncommitted_lineage.GetEntries().size());
ASSERT_EQ(4, uncommitted_lineage_all.GetEntries().size());
@ -281,7 +316,7 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) {
// Check that lineage of requested task includes itself, regardless of whether
// it has been forwarded before.
auto uncommitted_lineage_forwarded =
lineage_cache_.GetUncommittedLineage(forwarded_task_id, node_id);
lineage_cache_->GetUncommittedLineage(forwarded_task_id, node_id);
ASSERT_EQ(1, uncommitted_lineage_forwarded.GetEntries().size());
}
@ -289,30 +324,30 @@ TEST_F(LineageCacheTest, TestWritebackReady) {
// Insert a chain of dependent tasks.
size_t num_tasks_flushed = 0;
std::vector<Task> tasks;
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
// Check that when no tasks have been marked as ready, we do not flush any
// entries.
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
// Check that after marking the first task as ready, we flush only that task.
ASSERT_TRUE(lineage_cache_.CommitTask(tasks.front()));
ASSERT_TRUE(lineage_cache_->CommitTask(tasks.front()));
num_tasks_flushed++;
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
}
TEST_F(LineageCacheTest, TestWritebackOrder) {
// Insert a chain of dependent tasks.
std::vector<Task> tasks;
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
size_t num_tasks_flushed = tasks.size();
// Mark all tasks as ready. All tasks should be flushed.
for (const auto &task : tasks) {
ASSERT_TRUE(lineage_cache_.CommitTask(task));
ASSERT_TRUE(lineage_cache_->CommitTask(task));
}
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
}
TEST_F(LineageCacheTest, TestEvictChain) {
@ -331,42 +366,45 @@ TEST_F(LineageCacheTest, TestEvictChain) {
uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED);
}
// Mark the last task as ready to flush.
lineage_cache_.AddUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(),
uncommitted_lineage);
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size());
ASSERT_TRUE(lineage_cache_.CommitTask(tasks.back()));
lineage_cache_->AddUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(),
uncommitted_lineage);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), tasks.size());
ASSERT_TRUE(lineage_cache_->CommitTask(tasks.back()));
num_tasks_flushed++;
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
// Flush acknowledgements. The lineage cache should receive the commit for
// the flushed task, but its lineage should not be evicted yet.
mock_gcs_.Flush();
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(lineage_cache_
.GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(),
ClientID::Nil())
->GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(),
ClientID::Nil())
.GetEntries()
.size(),
tasks.size());
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size());
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), tasks.size());
// Simulate executing the task on a remote node and adding it to the GCS.
auto task_id = tasks.at(1).GetTaskSpecification().TaskId();
auto task_data = std::make_shared<TaskTableData>();
RAY_CHECK_OK(
mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data));
mock_gcs_.Flush();
task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data));
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(lineage_cache_
.GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(),
ClientID::Nil())
->GetUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(),
ClientID::Nil())
.GetEntries()
.size(),
tasks.size());
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size());
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), tasks.size());
// Simulate executing the task on a remote node and adding it to the GCS.
RAY_CHECK_OK(
mock_gcs_.RemoteAdd(tasks.at(0).GetTaskSpecification().TaskId(), task_data));
mock_gcs_.Flush();
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0);
task_id = tasks.at(0).GetTaskSpecification().TaskId();
auto task_data_2 = std::make_shared<TaskTableData>();
task_data_2->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data_2));
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0);
}
TEST_F(LineageCacheTest, TestEvictManyParents) {
@ -378,50 +416,50 @@ TEST_F(LineageCacheTest, TestEvictManyParents) {
parent_tasks.push_back(task);
arguments.push_back(task.GetTaskSpecification().ReturnIdForPlasma(0));
auto lineage = CreateSingletonLineage(task);
lineage_cache_.AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage);
lineage_cache_->AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage);
}
// Create a child task that is dependent on all of the previous tasks.
auto child_task = ExampleTask(arguments, 1);
auto lineage = CreateSingletonLineage(child_task);
lineage_cache_.AddUncommittedLineage(child_task.GetTaskSpecification().TaskId(),
lineage);
lineage_cache_->AddUncommittedLineage(child_task.GetTaskSpecification().TaskId(),
lineage);
// Flush the child task. Make sure that it remains in the cache, since none
// of its parents have been committed yet, and that the uncommitted lineage
// still includes all of the parent tasks.
size_t total_tasks = parent_tasks.size() + 1;
lineage_cache_.CommitTask(child_task);
mock_gcs_.Flush();
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks);
lineage_cache_->CommitTask(child_task);
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), total_tasks);
ASSERT_EQ(lineage_cache_
.GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(),
ClientID::Nil())
->GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(),
ClientID::Nil())
.GetEntries()
.size(),
total_tasks);
// Flush each parent task and check for eviction safety.
for (const auto &parent_task : parent_tasks) {
lineage_cache_.CommitTask(parent_task);
mock_gcs_.Flush();
lineage_cache_->CommitTask(parent_task);
mock_gcs_->MockTasks().Flush();
total_tasks--;
if (total_tasks > 1) {
// Each task should be evicted as soon as its commit is acknowledged,
// since the parent tasks have no dependencies.
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), total_tasks);
ASSERT_EQ(lineage_cache_
.GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(),
ClientID::Nil())
->GetUncommittedLineage(child_task.GetTaskSpecification().TaskId(),
ClientID::Nil())
.GetEntries()
.size(),
total_tasks);
} else {
// After the last task has been committed, then the child task should
// also be evicted. The lineage cache should now be empty.
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0);
}
}
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0);
ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0);
}
TEST_F(LineageCacheTest, TestEviction) {
@ -429,51 +467,56 @@ TEST_F(LineageCacheTest, TestEviction) {
uint64_t lineage_size = max_lineage_size_ + 1;
size_t num_tasks_flushed = 0;
std::vector<Task> tasks;
InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, lineage_size, std::vector<ObjectID>(), 1);
// Check that the last task in the chain still has all tasks in its
// uncommitted lineage.
const auto last_task_id = tasks.back().GetTaskSpecification().TaskId();
auto uncommitted_lineage =
lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::Nil());
lineage_cache_->GetUncommittedLineage(last_task_id, ClientID::Nil());
ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size);
// Simulate executing the first task on a remote node and adding it to the
// GCS.
auto task_data = std::make_shared<TaskTableData>();
auto it = tasks.begin();
RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data));
auto task_id = it->GetTaskSpecification().TaskId();
auto task_data = std::make_shared<TaskTableData>();
task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data));
it++;
// Check that the remote task is flushed.
num_tasks_flushed++;
mock_gcs_.Flush();
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
// Check that the last task in the chain still has all tasks in its
// uncommitted lineage.
ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size);
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(),
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(),
lineage_size - num_tasks_flushed);
// Simulate executing all the rest of the tasks except the last one on a
// remote node and adding them to the GCS.
tasks.pop_back();
for (; it != tasks.end(); it++) {
RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data));
auto task_id = it->GetTaskSpecification().TaskId();
auto task_data = std::make_shared<TaskTableData>();
task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data));
// Check that the remote task is flushed.
num_tasks_flushed++;
mock_gcs_.Flush();
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(),
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(),
lineage_size - num_tasks_flushed);
}
// All tasks have now been flushed. Check that enough lineage has been
// evicted that the uncommitted lineage is now less than the maximum size.
uncommitted_lineage =
lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::Nil());
lineage_cache_->GetUncommittedLineage(last_task_id, ClientID::Nil());
ASSERT_TRUE(uncommitted_lineage.GetEntries().size() < max_lineage_size_);
// The remaining task should have no uncommitted lineage.
ASSERT_EQ(uncommitted_lineage.GetEntries().size(), 1);
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 1);
ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 1);
}
TEST_F(LineageCacheTest, TestOutOfOrderEviction) {
@ -483,38 +526,42 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) {
uint64_t lineage_size = (2 * max_lineage_size_) + 2;
size_t num_tasks_flushed = 0;
std::vector<Task> tasks;
InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, lineage_size, std::vector<ObjectID>(), 1);
// Check that the last task in the chain still has all tasks in its
// uncommitted lineage.
const auto last_task_id = tasks.back().GetTaskSpecification().TaskId();
auto uncommitted_lineage =
lineage_cache_.GetUncommittedLineage(last_task_id, ClientID::Nil());
lineage_cache_->GetUncommittedLineage(last_task_id, ClientID::Nil());
ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size);
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size);
// Simulate executing the rest of the tasks on a remote node and receiving
// the notifications from the GCS in reverse order of execution.
auto last_task = tasks.front();
tasks.erase(tasks.begin());
for (auto it = tasks.rbegin(); it != tasks.rend(); it++) {
auto task_id = it->GetTaskSpecification().TaskId();
auto task_data = std::make_shared<TaskTableData>();
RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data));
task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data));
// Check that the remote task is flushed.
num_tasks_flushed++;
mock_gcs_.Flush();
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size);
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size);
}
// Flush the last task. The lineage should not get evicted until this task's
// commit is received.
auto task_id = last_task.GetTaskSpecification().TaskId();
auto task_data = std::make_shared<TaskTableData>();
RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data));
task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data));
num_tasks_flushed++;
mock_gcs_.Flush();
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0);
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0);
}
TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) {
@ -522,7 +569,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) {
size_t num_tasks_flushed = 0;
uint64_t lineage_size = max_lineage_size_ + 1;
std::vector<Task> tasks;
InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, lineage_size, std::vector<ObjectID>(), 1);
// Add more tasks to the lineage cache that will remain local. Each of these
// tasks is dependent one of the tasks that was forwarded above.
@ -530,9 +577,9 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) {
auto return_id = task.GetTaskSpecification().ReturnIdForPlasma(0);
auto dependent_task = ExampleTask({return_id}, 1);
auto lineage = CreateSingletonLineage(dependent_task);
lineage_cache_.AddUncommittedLineage(dependent_task.GetTaskSpecification().TaskId(),
lineage);
ASSERT_TRUE(lineage_cache_.CommitTask(dependent_task));
lineage_cache_->AddUncommittedLineage(dependent_task.GetTaskSpecification().TaskId(),
lineage);
ASSERT_TRUE(lineage_cache_->CommitTask(dependent_task));
// Once the forwarded tasks are evicted from the lineage cache, we expect
// each of these dependent tasks to be flushed, since all of their
// dependencies have been committed.
@ -544,50 +591,52 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) {
// until after the final remote task is executed, since a task can only be
// evicted once all of its ancestors have been committed.
for (auto it = tasks.rbegin(); it != tasks.rend(); it++) {
auto task_id = it->GetTaskSpecification().TaskId();
auto task_data = std::make_shared<TaskTableData>();
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2);
RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data));
task_data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary());
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), lineage_size * 2);
RAY_CHECK_OK(mock_gcs_->MockTasks().RemoteAdd(task_data));
num_tasks_flushed++;
mock_gcs_.Flush();
ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed);
mock_gcs_->MockTasks().Flush();
ASSERT_EQ(mock_gcs_->MockTasks().TaskTable().size(), num_tasks_flushed);
}
// Check that after the final remote task is executed, all local lineage is
// now evicted.
ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0);
ASSERT_EQ(lineage_cache_->GetLineage().GetEntries().size(), 0);
ASSERT_EQ(lineage_cache_->GetLineage().GetChildrenSize(), 0);
}
TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) {
// Insert a chain of tasks.
std::vector<Task> tasks;
auto return_values =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
InsertTaskChain(*lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
std::vector<TaskID> task_ids;
for (const auto &task : tasks) {
task_ids.push_back(task.GetTaskSpecification().TaskId());
}
// Check that we subscribed to each of the uncommitted tasks.
ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size());
ASSERT_EQ(mock_gcs_->MockTasks().NumRequestedNotifications(), task_ids.size());
// Flush all uncommitted tasks and make sure we add all tasks to
// the task table.
lineage_cache_.FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size());
lineage_cache_->FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_->MockTasks().NumTaskAdds(), tasks.size());
// Flush again and make sure there are no new tasks added to the
// task table.
lineage_cache_.FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size());
lineage_cache_->FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_->MockTasks().NumTaskAdds(), tasks.size());
// Flush all GCS notifications.
mock_gcs_.Flush();
mock_gcs_->MockTasks().Flush();
// Make sure that we unsubscribed to the uncommitted tasks before
// we flushed them.
ASSERT_EQ(num_notifications_, 0);
// Flush again and make sure there are no new tasks added to the
// task table.
lineage_cache_.FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size());
lineage_cache_->FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_->MockTasks().NumTaskAdds(), tasks.size());
}
} // namespace raylet

View file

@ -102,9 +102,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
gcs_client_->client_table().GetLocalClientId(),
RayConfig::instance().initial_reconstruction_timeout_milliseconds(),
gcs_client_->task_lease_table()),
lineage_cache_(gcs_client_->client_table().GetLocalClientId(),
gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(),
config.max_lineage_size),
lineage_cache_(gcs_client_, config.max_lineage_size),
actor_registry_(),
node_manager_server_("NodeManager", config.node_manager_port),
node_manager_service_(io_service, *this),
@ -140,18 +138,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
ray::Status NodeManager::RegisterGcs() {
object_manager_.RegisterGcs();
// Subscribe to task entry commits in the GCS. These notifications are
// forwarded to the lineage cache, which requests notifications about tasks
// that were executed remotely.
const auto task_committed_callback = [this](gcs::RedisGcsClient *client,
const TaskID &task_id,
const TaskTableData &task_data) {
lineage_cache_.HandleEntryCommitted(task_id);
};
RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe(
JobID::Nil(), gcs_client_->client_table().GetLocalClientId(),
task_committed_callback, nullptr, nullptr));
const auto task_lease_notification_callback = [this](gcs::RedisGcsClient *client,
const TaskID &task_id,
const TaskLeaseData &task_lease) {
@ -984,7 +970,7 @@ void NodeManager::ProcessClientMessage(
for (const auto &object_id : object_ids) {
creating_task_ids.push_back(object_id.TaskId());
}
gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids);
RAY_CHECK_OK(gcs_client_->Tasks().AsyncDelete(creating_task_ids, nullptr));
}
} break;
case protocol::MessageType::PrepareActorCheckpointRequest: {
@ -2437,27 +2423,25 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
auto parent_task_id = task_spec.ParentTaskId();
int port = worker.Port();
RAY_CHECK_OK(
gcs_client_->raylet_task_table().Lookup(
JobID::Nil(), parent_task_id,
/*success_callback=*/
[this, task_spec, resumed_from_checkpoint, port](
ray::gcs::RedisGcsClient *client, const TaskID &parent_task_id,
const TaskTableData &parent_task_data) {
// The task was in the GCS task table. Use the stored task spec to
// get the parent actor id.
Task parent_task(parent_task_data.task());
ActorID parent_actor_id = ActorID::Nil();
if (parent_task.GetTaskSpecification().IsActorCreationTask()) {
parent_actor_id = parent_task.GetTaskSpecification().ActorCreationId();
} else if (parent_task.GetTaskSpecification().IsActorTask()) {
parent_actor_id = parent_task.GetTaskSpecification().ActorId();
gcs_client_->Tasks().AsyncGet(
parent_task_id,
/*callback=*/
[this, task_spec, resumed_from_checkpoint, port, parent_task_id](
Status status, const boost::optional<TaskTableData> &parent_task_data) {
if (parent_task_data) {
// The task was in the GCS task table. Use the stored task spec to
// get the parent actor id.
Task parent_task(parent_task_data->task());
ActorID parent_actor_id = ActorID::Nil();
if (parent_task.GetTaskSpecification().IsActorCreationTask()) {
parent_actor_id = parent_task.GetTaskSpecification().ActorCreationId();
} else if (parent_task.GetTaskSpecification().IsActorTask()) {
parent_actor_id = parent_task.GetTaskSpecification().ActorId();
}
FinishAssignedActorCreationTask(parent_actor_id, task_spec,
resumed_from_checkpoint, port);
return;
}
FinishAssignedActorCreationTask(parent_actor_id, task_spec,
resumed_from_checkpoint, port);
},
/*failure_callback=*/
[this, task_spec, resumed_from_checkpoint, port](
ray::gcs::RedisGcsClient *client, const TaskID &parent_task_id) {
// The parent task was not in the GCS task table. It should most likely be
// in the lineage cache.
ActorID parent_actor_id = ActorID::Nil();
@ -2574,18 +2558,17 @@ void NodeManager::FinishAssignedActorCreationTask(const ActorID &parent_actor_id
void NodeManager::HandleTaskReconstruction(const TaskID &task_id,
const ObjectID &required_object_id) {
// Retrieve the task spec in order to re-execute the task.
RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup(
JobID::Nil(), task_id,
/*success_callback=*/
[this, required_object_id](ray::gcs::RedisGcsClient *client, const TaskID &task_id,
const TaskTableData &task_data) {
// The task was in the GCS task table. Use the stored task spec to
// re-execute the task.
ResubmitTask(Task(task_data.task()), required_object_id);
},
/*failure_callback=*/
[this, required_object_id](ray::gcs::RedisGcsClient *client,
const TaskID &task_id) {
RAY_CHECK_OK(gcs_client_->Tasks().AsyncGet(
task_id,
/*callback=*/
[this, required_object_id, task_id](
Status status, const boost::optional<TaskTableData> &task_data) {
if (task_data) {
// The task was in the GCS task table. Use the stored task spec to
// re-execute the task.
ResubmitTask(Task(task_data->task()), required_object_id);
return;
}
// The task was not in the GCS task table. It must therefore be in the
// lineage cache.
if (lineage_cache_.ContainsTask(task_id)) {