remove id.h dependence for c++ worker headers (#16055)

This commit is contained in:
SongGuyang 2021-05-26 11:56:24 +08:00 committed by GitHub
parent 08de5a36e1
commit 7c3874b38e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 134 additions and 82 deletions

View file

@ -78,7 +78,7 @@ class Ray {
static std::once_flag is_inited_; static std::once_flag is_inited_;
template <typename T> template <typename T>
static std::vector<std::shared_ptr<T>> Get(const std::vector<ObjectID> &ids); static std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids);
template <typename FuncType> template <typename FuncType>
static TaskCaller<FuncType> TaskInternal(FuncType &func); static TaskCaller<FuncType> TaskInternal(FuncType &func);
@ -96,9 +96,9 @@ namespace ray {
namespace api { namespace api {
template <typename T> template <typename T>
inline static std::vector<ObjectID> ObjectRefsToObjectIDs( inline static std::vector<std::string> ObjectRefsToObjectIDs(
const std::vector<ObjectRef<T>> &object_refs) { const std::vector<ObjectRef<T>> &object_refs) {
std::vector<ObjectID> object_ids; std::vector<std::string> object_ids;
for (auto it = object_refs.begin(); it != object_refs.end(); it++) { for (auto it = object_refs.begin(); it != object_refs.end(); it++) {
object_ids.push_back(it->ID()); object_ids.push_back(it->ID());
} }
@ -118,7 +118,7 @@ inline std::shared_ptr<T> Ray::Get(const ObjectRef<T> &object) {
} }
template <typename T> template <typename T>
inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<ObjectID> &ids) { inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<std::string> &ids) {
auto result = ray::internal::RayRuntime()->Get(ids); auto result = ray::internal::RayRuntime()->Get(ids);
std::vector<std::shared_ptr<T>> return_objects; std::vector<std::shared_ptr<T>> return_objects;
return_objects.reserve(result.size()); return_objects.reserve(result.size());

View file

@ -21,10 +21,10 @@ class ActorHandle {
public: public:
ActorHandle(); ActorHandle();
ActorHandle(const ActorID &id); ActorHandle(const std::string &id);
/// Get a untyped ID of the actor /// Get a untyped ID of the actor
const ActorID &ID() const; const std::string &ID() const;
/// Include the `Call` methods for calling remote functions. /// Include the `Call` methods for calling remote functions.
template <typename F> template <typename F>
@ -34,7 +34,7 @@ class ActorHandle {
MSGPACK_DEFINE(id_); MSGPACK_DEFINE(id_);
private: private:
ActorID id_; std::string id_;
}; };
// ---------- implementation ---------- // ---------- implementation ----------
@ -42,12 +42,12 @@ template <typename ActorType>
ActorHandle<ActorType>::ActorHandle() {} ActorHandle<ActorType>::ActorHandle() {}
template <typename ActorType> template <typename ActorType>
ActorHandle<ActorType>::ActorHandle(const ActorID &id) { ActorHandle<ActorType>::ActorHandle(const std::string &id) {
id_ = id; id_ = id;
} }
template <typename ActorType> template <typename ActorType>
const ActorID &ActorHandle<ActorType>::ID() const { const std::string &ActorHandle<ActorType>::ID() const {
return id_; return id_;
} }

View file

@ -14,7 +14,7 @@ class ActorTaskCaller {
public: public:
ActorTaskCaller() = default; ActorTaskCaller() = default;
ActorTaskCaller(RayRuntime *runtime, ActorID id, ActorTaskCaller(RayRuntime *runtime, std::string id,
RemoteFunctionHolder remote_function_holder) RemoteFunctionHolder remote_function_holder)
: runtime_(runtime), id_(id), remote_function_holder_(remote_function_holder) {} : runtime_(runtime), id_(id), remote_function_holder_(remote_function_holder) {}
@ -23,7 +23,7 @@ class ActorTaskCaller {
private: private:
RayRuntime *runtime_; RayRuntime *runtime_;
ActorID id_; std::string id_;
RemoteFunctionHolder remote_function_holder_; RemoteFunctionHolder remote_function_holder_;
std::vector<ray::api::TaskArg> args_; std::vector<ray::api::TaskArg> args_;
}; };

View file

@ -41,7 +41,7 @@ struct TaskArg {
/// If the buf is initialized shows it is a value argument. /// If the buf is initialized shows it is a value argument.
boost::optional<msgpack::sbuffer> buf; boost::optional<msgpack::sbuffer> buf;
/// If the id is initialized shows it is a reference argument. /// If the id is initialized shows it is a reference argument.
boost::optional<ObjectID> id; boost::optional<std::string> id;
}; };
} // namespace api } // namespace api

View file

@ -28,12 +28,14 @@ inline void CheckResult(const std::shared_ptr<msgpack::sbuffer> &packed_object)
} }
} }
inline void CopyAndAddRefrence(ObjectID &dest_id, const ObjectID &id) { inline void CopyAndAddReference(std::string &dest_id, const std::string &id) {
dest_id = id; dest_id = id;
AddLocalReference(id); ray::internal::RayRuntime()->AddLocalReference(id);
} }
inline void SubRefrence(const ObjectID &id) { RemoveLocalReference(id); } inline void SubReference(const std::string &id) {
ray::internal::RayRuntime()->RemoveLocalReference(id);
}
/// Represents an object in the object store.. /// Represents an object in the object store..
/// \param T The type of object. /// \param T The type of object.
@ -43,19 +45,19 @@ class ObjectRef {
ObjectRef(); ObjectRef();
~ObjectRef(); ~ObjectRef();
ObjectRef(const ObjectRef &rhs) { CopyAndAddRefrence(id_, rhs.id_); } ObjectRef(const ObjectRef &rhs) { CopyAndAddReference(id_, rhs.id_); }
ObjectRef &operator=(const ObjectRef &rhs) { ObjectRef &operator=(const ObjectRef &rhs) {
CopyAndAddRefrence(id_, rhs.id_); CopyAndAddReference(id_, rhs.id_);
return *this; return *this;
} }
ObjectRef(const ObjectID &id); ObjectRef(const std::string &id);
bool operator==(const ObjectRef<T> &object) const; bool operator==(const ObjectRef<T> &object) const;
/// Get a untyped ID of the object /// Get a untyped ID of the object
const ObjectID &ID() const; const std::string &ID() const;
/// Get the object from the object store. /// Get the object from the object store.
/// This method will be blocked until the object is ready. /// This method will be blocked until the object is ready.
@ -67,7 +69,7 @@ class ObjectRef {
MSGPACK_DEFINE(id_); MSGPACK_DEFINE(id_);
private: private:
ObjectID id_; std::string id_;
}; };
// ---------- implementation ---------- // ---------- implementation ----------
@ -84,13 +86,13 @@ template <typename T>
ObjectRef<T>::ObjectRef() {} ObjectRef<T>::ObjectRef() {}
template <typename T> template <typename T>
ObjectRef<T>::ObjectRef(const ObjectID &id) { ObjectRef<T>::ObjectRef(const std::string &id) {
CopyAndAddRefrence(id_, id); CopyAndAddReference(id_, id);
} }
template <typename T> template <typename T>
ObjectRef<T>::~ObjectRef() { ObjectRef<T>::~ObjectRef() {
SubRefrence(id_); SubReference(id_);
} }
template <typename T> template <typename T>
@ -99,7 +101,7 @@ inline bool ObjectRef<T>::operator==(const ObjectRef<T> &object) const {
} }
template <typename T> template <typename T>
const ObjectID &ObjectRef<T>::ID() const { const std::string &ObjectRef<T>::ID() const {
return id_; return id_;
} }
@ -112,21 +114,21 @@ template <>
class ObjectRef<void> { class ObjectRef<void> {
public: public:
ObjectRef() = default; ObjectRef() = default;
~ObjectRef() { SubRefrence(id_); } ~ObjectRef() { SubReference(id_); }
ObjectRef(const ObjectRef &rhs) { CopyAndAddRefrence(id_, rhs.id_); } ObjectRef(const ObjectRef &rhs) { CopyAndAddReference(id_, rhs.id_); }
ObjectRef &operator=(const ObjectRef &rhs) { ObjectRef &operator=(const ObjectRef &rhs) {
CopyAndAddRefrence(id_, rhs.id_); CopyAndAddReference(id_, rhs.id_);
return *this; return *this;
} }
ObjectRef(const ObjectID &id) { CopyAndAddRefrence(id_, id); } ObjectRef(const std::string &id) { CopyAndAddReference(id_, id); }
bool operator==(const ObjectRef<void> &object) const { return id_ == object.id_; } bool operator==(const ObjectRef<void> &object) const { return id_ == object.id_; }
/// Get a untyped ID of the object /// Get a untyped ID of the object
const ObjectID &ID() const { return id_; } const std::string &ID() const { return id_; }
/// Get the object from the object store. /// Get the object from the object store.
/// This method will be blocked until the object is ready. /// This method will be blocked until the object is ready.
@ -141,7 +143,7 @@ class ObjectRef<void> {
MSGPACK_DEFINE(id_); MSGPACK_DEFINE(id_);
private: private:
ObjectID id_; std::string id_;
}; };
} // namespace api } // namespace api
} // namespace ray } // namespace ray

View file

@ -32,27 +32,24 @@ struct RemoteFunctionHolder {
class RayRuntime { class RayRuntime {
public: public:
virtual ObjectID Put(std::shared_ptr<msgpack::sbuffer> data) = 0; virtual std::string Put(std::shared_ptr<msgpack::sbuffer> data) = 0;
virtual std::shared_ptr<msgpack::sbuffer> Get(const ObjectID &id) = 0; virtual std::shared_ptr<msgpack::sbuffer> Get(const std::string &id) = 0;
virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get( virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get(
const std::vector<ObjectID> &ids) = 0; const std::vector<std::string> &ids) = 0;
virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, virtual std::vector<bool> Wait(const std::vector<std::string> &ids, int num_objects,
int timeout_ms) = 0; int timeout_ms) = 0;
virtual ObjectID Call(const RemoteFunctionHolder &remote_function_holder, virtual std::string Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) = 0; std::vector<ray::api::TaskArg> &args) = 0;
virtual ActorID CreateActor(const RemoteFunctionHolder &remote_function_holder, virtual std::string CreateActor(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) = 0; std::vector<ray::api::TaskArg> &args) = 0;
virtual ObjectID CallActor(const RemoteFunctionHolder &remote_function_holder, virtual std::string CallActor(const RemoteFunctionHolder &remote_function_holder,
const ActorID &actor, const std::string &actor,
std::vector<ray::api::TaskArg> &args) = 0; std::vector<ray::api::TaskArg> &args) = 0;
virtual void AddLocalReference(const std::string &id) = 0;
virtual void RemoveLocalReference(const std::string &id) = 0;
}; };
void AddLocalReference(const ObjectID &id);
void RemoveLocalReference(const ObjectID &id);
} // namespace api } // namespace api
} // namespace ray } // namespace ray

View file

@ -3,7 +3,6 @@
#include "ray/common/buffer.h" #include "ray/common/buffer.h"
#include "ray/common/function_descriptor.h" #include "ray/common/function_descriptor.h"
#include "ray/common/id.h"
#include "ray/common/status.h" #include "ray/common/status.h"
#include "ray/common/task/task_common.h" #include "ray/common/task/task_common.h"
#include "ray/common/task/task_spec.h" #include "ray/common/task/task_spec.h"

View file

@ -62,25 +62,34 @@ void AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data,
object_store_->Put(data, object_id); object_store_->Put(data, object_id);
} }
ObjectID AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data) { std::string AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data) {
ObjectID object_id = ObjectID object_id =
ObjectID::FromIndex(worker_->GetCurrentTaskID(), worker_->GetNextPutIndex()); ObjectID::FromIndex(worker_->GetCurrentTaskID(), worker_->GetNextPutIndex());
Put(data, &object_id); Put(data, &object_id);
return object_id; return object_id.Binary();
} }
std::shared_ptr<msgpack::sbuffer> AbstractRayRuntime::Get(const ObjectID &object_id) { std::shared_ptr<msgpack::sbuffer> AbstractRayRuntime::Get(const std::string &object_id) {
return object_store_->Get(object_id, -1); return object_store_->Get(ObjectID::FromBinary(object_id), -1);
}
inline static std::vector<ObjectID> StringIDsToObjectIDs(
const std::vector<std::string> &ids) {
std::vector<ObjectID> object_ids;
for (std::string id : ids) {
object_ids.push_back(ObjectID::FromBinary(id));
}
return object_ids;
} }
std::vector<std::shared_ptr<msgpack::sbuffer>> AbstractRayRuntime::Get( std::vector<std::shared_ptr<msgpack::sbuffer>> AbstractRayRuntime::Get(
const std::vector<ObjectID> &ids) { const std::vector<std::string> &ids) {
return object_store_->Get(ids, -1); return object_store_->Get(StringIDsToObjectIDs(ids), -1);
} }
std::vector<bool> AbstractRayRuntime::Wait(const std::vector<ObjectID> &ids, std::vector<bool> AbstractRayRuntime::Wait(const std::vector<std::string> &ids,
int num_objects, int timeout_ms) { int num_objects, int timeout_ms) {
return object_store_->Wait(ids, num_objects, timeout_ms); return object_store_->Wait(StringIDsToObjectIDs(ids), num_objects, timeout_ms);
} }
std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs( std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs(
@ -96,7 +105,8 @@ std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs(
memory_buffer, nullptr, std::vector<ObjectID>())); memory_buffer, nullptr, std::vector<ObjectID>()));
} else { } else {
RAY_CHECK(arg.id); RAY_CHECK(arg.id);
ray_arg = absl::make_unique<ray::TaskArgByReference>(*arg.id, ray::rpc::Address{}); ray_arg = absl::make_unique<ray::TaskArgByReference>(ObjectID::FromBinary(*arg.id),
ray::rpc::Address{});
} }
ray_args.push_back(std::move(ray_arg)); ray_args.push_back(std::move(ray_arg));
} }
@ -119,29 +129,30 @@ InvocationSpec BuildInvocationSpec1(TaskType task_type, std::string lib_name,
return invocation_spec; return invocation_spec;
} }
ObjectID AbstractRayRuntime::Call(const RemoteFunctionHolder &remote_function_holder, std::string AbstractRayRuntime::Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) { std::vector<ray::api::TaskArg> &args) {
auto invocation_spec = auto invocation_spec =
BuildInvocationSpec1(TaskType::NORMAL_TASK, this->config_->lib_name, BuildInvocationSpec1(TaskType::NORMAL_TASK, this->config_->lib_name,
remote_function_holder, args, ActorID::Nil()); remote_function_holder, args, ActorID::Nil());
return task_submitter_->SubmitTask(invocation_spec); return task_submitter_->SubmitTask(invocation_spec).Binary();
} }
ActorID AbstractRayRuntime::CreateActor( std::string AbstractRayRuntime::CreateActor(
const RemoteFunctionHolder &remote_function_holder, const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) { std::vector<ray::api::TaskArg> &args) {
auto invocation_spec = auto invocation_spec =
BuildInvocationSpec1(TaskType::ACTOR_CREATION_TASK, this->config_->lib_name, BuildInvocationSpec1(TaskType::ACTOR_CREATION_TASK, this->config_->lib_name,
remote_function_holder, args, ActorID::Nil()); remote_function_holder, args, ActorID::Nil());
return task_submitter_->CreateActor(invocation_spec); return task_submitter_->CreateActor(invocation_spec).Binary();
} }
ObjectID AbstractRayRuntime::CallActor(const RemoteFunctionHolder &remote_function_holder, std::string AbstractRayRuntime::CallActor(
const ActorID &actor, const RemoteFunctionHolder &remote_function_holder, const std::string &actor,
std::vector<ray::api::TaskArg> &args) { std::vector<ray::api::TaskArg> &args) {
auto invocation_spec = BuildInvocationSpec1( auto invocation_spec =
TaskType::ACTOR_TASK, this->config_->lib_name, remote_function_holder, args, actor); BuildInvocationSpec1(TaskType::ACTOR_TASK, this->config_->lib_name,
return task_submitter_->SubmitActorTask(invocation_spec); remote_function_holder, args, ActorID::FromBinary(actor));
return task_submitter_->SubmitActorTask(invocation_spec).Binary();
} }
const TaskID &AbstractRayRuntime::GetCurrentTaskId() { const TaskID &AbstractRayRuntime::GetCurrentTaskId() {
@ -154,17 +165,17 @@ const std::unique_ptr<WorkerContext> &AbstractRayRuntime::GetWorkerContext() {
return worker_; return worker_;
} }
void AddLocalReference(const ObjectID &id) { void AbstractRayRuntime::AddLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) { if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker(); auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.AddLocalReference(id); core_worker.AddLocalReference(ObjectID::FromBinary(id));
} }
} }
void RemoveLocalReference(const ObjectID &id) { void AbstractRayRuntime::RemoveLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) { if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker(); auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.RemoveLocalReference(id); core_worker.RemoveLocalReference(ObjectID::FromBinary(id));
} }
} }

View file

@ -23,23 +23,27 @@ class AbstractRayRuntime : public RayRuntime {
void Put(std::shared_ptr<msgpack::sbuffer> data, const ObjectID &object_id); void Put(std::shared_ptr<msgpack::sbuffer> data, const ObjectID &object_id);
ObjectID Put(std::shared_ptr<msgpack::sbuffer> data); std::string Put(std::shared_ptr<msgpack::sbuffer> data);
std::shared_ptr<msgpack::sbuffer> Get(const ObjectID &id); std::shared_ptr<msgpack::sbuffer> Get(const std::string &id);
std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<ObjectID> &ids); std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<std::string> &ids);
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, std::vector<bool> Wait(const std::vector<std::string> &ids, int num_objects,
int timeout_ms); int timeout_ms);
ObjectID Call(const RemoteFunctionHolder &remote_function_holder, std::string Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args); std::vector<ray::api::TaskArg> &args);
ActorID CreateActor(const RemoteFunctionHolder &remote_function_holder, std::string CreateActor(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args); std::vector<ray::api::TaskArg> &args);
ObjectID CallActor(const RemoteFunctionHolder &remote_function_holder, std::string CallActor(const RemoteFunctionHolder &remote_function_holder,
const ActorID &actor, std::vector<ray::api::TaskArg> &args); const std::string &actor, std::vector<ray::api::TaskArg> &args);
void AddLocalReference(const std::string &id);
void RemoveLocalReference(const std::string &id);
const TaskID &GetCurrentTaskId(); const TaskID &GetCurrentTaskId();

View file

@ -88,5 +88,9 @@ std::vector<bool> LocalModeObjectStore::Wait(const std::vector<ObjectID> &ids,
} }
return result; return result;
} }
void LocalModeObjectStore::AddLocalReference(const std::string &id) { return; }
void LocalModeObjectStore::RemoveLocalReference(const std::string &id) { return; }
} // namespace api } // namespace api
} // namespace ray } // namespace ray

View file

@ -17,6 +17,10 @@ class LocalModeObjectStore : public ObjectStore {
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms); int timeout_ms);
void AddLocalReference(const std::string &id);
void RemoveLocalReference(const std::string &id);
private: private:
void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id); void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id);

View file

@ -83,5 +83,19 @@ std::vector<bool> NativeObjectStore::Wait(const std::vector<ObjectID> &ids,
} }
return results; return results;
} }
void NativeObjectStore::AddLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.AddLocalReference(ObjectID::FromBinary(id));
}
}
void NativeObjectStore::RemoveLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.RemoveLocalReference(ObjectID::FromBinary(id));
}
}
} // namespace api } // namespace api
} // namespace ray } // namespace ray

View file

@ -17,6 +17,10 @@ class NativeObjectStore : public ObjectStore {
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms); int timeout_ms);
void AddLocalReference(const std::string &id);
void RemoveLocalReference(const std::string &id);
private: private:
void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id); void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id);

View file

@ -56,6 +56,19 @@ class ObjectStore {
virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects, virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms) = 0; int timeout_ms) = 0;
/// Increase the reference count for this object ID.
/// Increase the local reference count for this object ID. Should be called
/// by the language frontend when a new reference is created.
///
/// \param[in] id The binary string ID to increase the reference count for.
virtual void AddLocalReference(const std::string &id) = 0;
/// Decrease the reference count for this object ID. Should be called
/// by the language frontend when a reference is destroyed.
///
/// \param[in] id The binary string ID to decrease the reference count for.
virtual void RemoveLocalReference(const std::string &id) = 0;
private: private:
virtual void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id) = 0; virtual void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id) = 0;

View file

@ -148,7 +148,7 @@ void TaskExecutor::Invoke(
std::vector<msgpack::sbuffer> args_buffer; std::vector<msgpack::sbuffer> args_buffer;
for (size_t i = 0; i < task_spec.NumArgs(); i++) { for (size_t i = 0; i < task_spec.NumArgs(); i++) {
if (task_spec.ArgByRef(i)) { if (task_spec.ArgByRef(i)) {
auto arg = runtime->Get(task_spec.ArgId(i)); auto arg = runtime->Get(task_spec.ArgId(i).Binary());
args_buffer.push_back(std::move(*arg)); args_buffer.push_back(std::move(*arg));
} else { } else {
msgpack::sbuffer sbuf; msgpack::sbuffer sbuf;