remove id.h dependence for c++ worker headers (#16055)

This commit is contained in:
SongGuyang 2021-05-26 11:56:24 +08:00 committed by GitHub
parent 08de5a36e1
commit 7c3874b38e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 134 additions and 82 deletions

View file

@ -78,7 +78,7 @@ class Ray {
static std::once_flag is_inited_;
template <typename T>
static std::vector<std::shared_ptr<T>> Get(const std::vector<ObjectID> &ids);
static std::vector<std::shared_ptr<T>> Get(const std::vector<std::string> &ids);
template <typename FuncType>
static TaskCaller<FuncType> TaskInternal(FuncType &func);
@ -96,9 +96,9 @@ namespace ray {
namespace api {
template <typename T>
inline static std::vector<ObjectID> ObjectRefsToObjectIDs(
inline static std::vector<std::string> ObjectRefsToObjectIDs(
const std::vector<ObjectRef<T>> &object_refs) {
std::vector<ObjectID> object_ids;
std::vector<std::string> object_ids;
for (auto it = object_refs.begin(); it != object_refs.end(); it++) {
object_ids.push_back(it->ID());
}
@ -118,7 +118,7 @@ inline std::shared_ptr<T> Ray::Get(const ObjectRef<T> &object) {
}
template <typename T>
inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<ObjectID> &ids) {
inline std::vector<std::shared_ptr<T>> Ray::Get(const std::vector<std::string> &ids) {
auto result = ray::internal::RayRuntime()->Get(ids);
std::vector<std::shared_ptr<T>> return_objects;
return_objects.reserve(result.size());

View file

@ -21,10 +21,10 @@ class ActorHandle {
public:
ActorHandle();
ActorHandle(const ActorID &id);
ActorHandle(const std::string &id);
/// Get a untyped ID of the actor
const ActorID &ID() const;
const std::string &ID() const;
/// Include the `Call` methods for calling remote functions.
template <typename F>
@ -34,7 +34,7 @@ class ActorHandle {
MSGPACK_DEFINE(id_);
private:
ActorID id_;
std::string id_;
};
// ---------- implementation ----------
@ -42,12 +42,12 @@ template <typename ActorType>
ActorHandle<ActorType>::ActorHandle() {}
template <typename ActorType>
ActorHandle<ActorType>::ActorHandle(const ActorID &id) {
ActorHandle<ActorType>::ActorHandle(const std::string &id) {
id_ = id;
}
template <typename ActorType>
const ActorID &ActorHandle<ActorType>::ID() const {
const std::string &ActorHandle<ActorType>::ID() const {
return id_;
}

View file

@ -14,7 +14,7 @@ class ActorTaskCaller {
public:
ActorTaskCaller() = default;
ActorTaskCaller(RayRuntime *runtime, ActorID id,
ActorTaskCaller(RayRuntime *runtime, std::string id,
RemoteFunctionHolder remote_function_holder)
: runtime_(runtime), id_(id), remote_function_holder_(remote_function_holder) {}
@ -23,7 +23,7 @@ class ActorTaskCaller {
private:
RayRuntime *runtime_;
ActorID id_;
std::string id_;
RemoteFunctionHolder remote_function_holder_;
std::vector<ray::api::TaskArg> args_;
};

View file

@ -41,7 +41,7 @@ struct TaskArg {
/// If the buf is initialized shows it is a value argument.
boost::optional<msgpack::sbuffer> buf;
/// If the id is initialized shows it is a reference argument.
boost::optional<ObjectID> id;
boost::optional<std::string> id;
};
} // namespace api

View file

@ -28,12 +28,14 @@ inline void CheckResult(const std::shared_ptr<msgpack::sbuffer> &packed_object)
}
}
inline void CopyAndAddRefrence(ObjectID &dest_id, const ObjectID &id) {
inline void CopyAndAddReference(std::string &dest_id, const std::string &id) {
dest_id = id;
AddLocalReference(id);
ray::internal::RayRuntime()->AddLocalReference(id);
}
inline void SubRefrence(const ObjectID &id) { RemoveLocalReference(id); }
inline void SubReference(const std::string &id) {
ray::internal::RayRuntime()->RemoveLocalReference(id);
}
/// Represents an object in the object store..
/// \param T The type of object.
@ -43,19 +45,19 @@ class ObjectRef {
ObjectRef();
~ObjectRef();
ObjectRef(const ObjectRef &rhs) { CopyAndAddRefrence(id_, rhs.id_); }
ObjectRef(const ObjectRef &rhs) { CopyAndAddReference(id_, rhs.id_); }
ObjectRef &operator=(const ObjectRef &rhs) {
CopyAndAddRefrence(id_, rhs.id_);
CopyAndAddReference(id_, rhs.id_);
return *this;
}
ObjectRef(const ObjectID &id);
ObjectRef(const std::string &id);
bool operator==(const ObjectRef<T> &object) const;
/// Get a untyped ID of the object
const ObjectID &ID() const;
const std::string &ID() const;
/// Get the object from the object store.
/// This method will be blocked until the object is ready.
@ -67,7 +69,7 @@ class ObjectRef {
MSGPACK_DEFINE(id_);
private:
ObjectID id_;
std::string id_;
};
// ---------- implementation ----------
@ -84,13 +86,13 @@ template <typename T>
ObjectRef<T>::ObjectRef() {}
template <typename T>
ObjectRef<T>::ObjectRef(const ObjectID &id) {
CopyAndAddRefrence(id_, id);
ObjectRef<T>::ObjectRef(const std::string &id) {
CopyAndAddReference(id_, id);
}
template <typename T>
ObjectRef<T>::~ObjectRef() {
SubRefrence(id_);
SubReference(id_);
}
template <typename T>
@ -99,7 +101,7 @@ inline bool ObjectRef<T>::operator==(const ObjectRef<T> &object) const {
}
template <typename T>
const ObjectID &ObjectRef<T>::ID() const {
const std::string &ObjectRef<T>::ID() const {
return id_;
}
@ -112,21 +114,21 @@ template <>
class ObjectRef<void> {
public:
ObjectRef() = default;
~ObjectRef() { SubRefrence(id_); }
~ObjectRef() { SubReference(id_); }
ObjectRef(const ObjectRef &rhs) { CopyAndAddRefrence(id_, rhs.id_); }
ObjectRef(const ObjectRef &rhs) { CopyAndAddReference(id_, rhs.id_); }
ObjectRef &operator=(const ObjectRef &rhs) {
CopyAndAddRefrence(id_, rhs.id_);
CopyAndAddReference(id_, rhs.id_);
return *this;
}
ObjectRef(const ObjectID &id) { CopyAndAddRefrence(id_, id); }
ObjectRef(const std::string &id) { CopyAndAddReference(id_, id); }
bool operator==(const ObjectRef<void> &object) const { return id_ == object.id_; }
/// Get a untyped ID of the object
const ObjectID &ID() const { return id_; }
const std::string &ID() const { return id_; }
/// Get the object from the object store.
/// This method will be blocked until the object is ready.
@ -141,7 +143,7 @@ class ObjectRef<void> {
MSGPACK_DEFINE(id_);
private:
ObjectID id_;
std::string id_;
};
} // namespace api
} // namespace ray

View file

@ -32,27 +32,24 @@ struct RemoteFunctionHolder {
class RayRuntime {
public:
virtual ObjectID Put(std::shared_ptr<msgpack::sbuffer> data) = 0;
virtual std::shared_ptr<msgpack::sbuffer> Get(const ObjectID &id) = 0;
virtual std::string Put(std::shared_ptr<msgpack::sbuffer> data) = 0;
virtual std::shared_ptr<msgpack::sbuffer> Get(const std::string &id) = 0;
virtual std::vector<std::shared_ptr<msgpack::sbuffer>> Get(
const std::vector<ObjectID> &ids) = 0;
const std::vector<std::string> &ids) = 0;
virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
virtual std::vector<bool> Wait(const std::vector<std::string> &ids, int num_objects,
int timeout_ms) = 0;
virtual ObjectID Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) = 0;
virtual ActorID CreateActor(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) = 0;
virtual ObjectID CallActor(const RemoteFunctionHolder &remote_function_holder,
const ActorID &actor,
std::vector<ray::api::TaskArg> &args) = 0;
virtual std::string Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) = 0;
virtual std::string CreateActor(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) = 0;
virtual std::string CallActor(const RemoteFunctionHolder &remote_function_holder,
const std::string &actor,
std::vector<ray::api::TaskArg> &args) = 0;
virtual void AddLocalReference(const std::string &id) = 0;
virtual void RemoveLocalReference(const std::string &id) = 0;
};
void AddLocalReference(const ObjectID &id);
void RemoveLocalReference(const ObjectID &id);
} // namespace api
} // namespace ray

View file

@ -3,7 +3,6 @@
#include "ray/common/buffer.h"
#include "ray/common/function_descriptor.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/common/task/task_common.h"
#include "ray/common/task/task_spec.h"

View file

@ -62,25 +62,34 @@ void AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data,
object_store_->Put(data, object_id);
}
ObjectID AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data) {
std::string AbstractRayRuntime::Put(std::shared_ptr<msgpack::sbuffer> data) {
ObjectID object_id =
ObjectID::FromIndex(worker_->GetCurrentTaskID(), worker_->GetNextPutIndex());
Put(data, &object_id);
return object_id;
return object_id.Binary();
}
std::shared_ptr<msgpack::sbuffer> AbstractRayRuntime::Get(const ObjectID &object_id) {
return object_store_->Get(object_id, -1);
std::shared_ptr<msgpack::sbuffer> AbstractRayRuntime::Get(const std::string &object_id) {
return object_store_->Get(ObjectID::FromBinary(object_id), -1);
}
inline static std::vector<ObjectID> StringIDsToObjectIDs(
const std::vector<std::string> &ids) {
std::vector<ObjectID> object_ids;
for (std::string id : ids) {
object_ids.push_back(ObjectID::FromBinary(id));
}
return object_ids;
}
std::vector<std::shared_ptr<msgpack::sbuffer>> AbstractRayRuntime::Get(
const std::vector<ObjectID> &ids) {
return object_store_->Get(ids, -1);
const std::vector<std::string> &ids) {
return object_store_->Get(StringIDsToObjectIDs(ids), -1);
}
std::vector<bool> AbstractRayRuntime::Wait(const std::vector<ObjectID> &ids,
std::vector<bool> AbstractRayRuntime::Wait(const std::vector<std::string> &ids,
int num_objects, int timeout_ms) {
return object_store_->Wait(ids, num_objects, timeout_ms);
return object_store_->Wait(StringIDsToObjectIDs(ids), num_objects, timeout_ms);
}
std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs(
@ -96,7 +105,8 @@ std::vector<std::unique_ptr<::ray::TaskArg>> TransformArgs(
memory_buffer, nullptr, std::vector<ObjectID>()));
} else {
RAY_CHECK(arg.id);
ray_arg = absl::make_unique<ray::TaskArgByReference>(*arg.id, ray::rpc::Address{});
ray_arg = absl::make_unique<ray::TaskArgByReference>(ObjectID::FromBinary(*arg.id),
ray::rpc::Address{});
}
ray_args.push_back(std::move(ray_arg));
}
@ -119,29 +129,30 @@ InvocationSpec BuildInvocationSpec1(TaskType task_type, std::string lib_name,
return invocation_spec;
}
ObjectID AbstractRayRuntime::Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) {
std::string AbstractRayRuntime::Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) {
auto invocation_spec =
BuildInvocationSpec1(TaskType::NORMAL_TASK, this->config_->lib_name,
remote_function_holder, args, ActorID::Nil());
return task_submitter_->SubmitTask(invocation_spec);
return task_submitter_->SubmitTask(invocation_spec).Binary();
}
ActorID AbstractRayRuntime::CreateActor(
std::string AbstractRayRuntime::CreateActor(
const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args) {
auto invocation_spec =
BuildInvocationSpec1(TaskType::ACTOR_CREATION_TASK, this->config_->lib_name,
remote_function_holder, args, ActorID::Nil());
return task_submitter_->CreateActor(invocation_spec);
return task_submitter_->CreateActor(invocation_spec).Binary();
}
ObjectID AbstractRayRuntime::CallActor(const RemoteFunctionHolder &remote_function_holder,
const ActorID &actor,
std::vector<ray::api::TaskArg> &args) {
auto invocation_spec = BuildInvocationSpec1(
TaskType::ACTOR_TASK, this->config_->lib_name, remote_function_holder, args, actor);
return task_submitter_->SubmitActorTask(invocation_spec);
std::string AbstractRayRuntime::CallActor(
const RemoteFunctionHolder &remote_function_holder, const std::string &actor,
std::vector<ray::api::TaskArg> &args) {
auto invocation_spec =
BuildInvocationSpec1(TaskType::ACTOR_TASK, this->config_->lib_name,
remote_function_holder, args, ActorID::FromBinary(actor));
return task_submitter_->SubmitActorTask(invocation_spec).Binary();
}
const TaskID &AbstractRayRuntime::GetCurrentTaskId() {
@ -154,17 +165,17 @@ const std::unique_ptr<WorkerContext> &AbstractRayRuntime::GetWorkerContext() {
return worker_;
}
void AddLocalReference(const ObjectID &id) {
void AbstractRayRuntime::AddLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.AddLocalReference(id);
core_worker.AddLocalReference(ObjectID::FromBinary(id));
}
}
void RemoveLocalReference(const ObjectID &id) {
void AbstractRayRuntime::RemoveLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.RemoveLocalReference(id);
core_worker.RemoveLocalReference(ObjectID::FromBinary(id));
}
}

View file

@ -23,23 +23,27 @@ class AbstractRayRuntime : public RayRuntime {
void Put(std::shared_ptr<msgpack::sbuffer> data, const ObjectID &object_id);
ObjectID Put(std::shared_ptr<msgpack::sbuffer> data);
std::string Put(std::shared_ptr<msgpack::sbuffer> data);
std::shared_ptr<msgpack::sbuffer> Get(const ObjectID &id);
std::shared_ptr<msgpack::sbuffer> Get(const std::string &id);
std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<ObjectID> &ids);
std::vector<std::shared_ptr<msgpack::sbuffer>> Get(const std::vector<std::string> &ids);
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
std::vector<bool> Wait(const std::vector<std::string> &ids, int num_objects,
int timeout_ms);
ObjectID Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args);
std::string Call(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args);
ActorID CreateActor(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args);
std::string CreateActor(const RemoteFunctionHolder &remote_function_holder,
std::vector<ray::api::TaskArg> &args);
ObjectID CallActor(const RemoteFunctionHolder &remote_function_holder,
const ActorID &actor, std::vector<ray::api::TaskArg> &args);
std::string CallActor(const RemoteFunctionHolder &remote_function_holder,
const std::string &actor, std::vector<ray::api::TaskArg> &args);
void AddLocalReference(const std::string &id);
void RemoveLocalReference(const std::string &id);
const TaskID &GetCurrentTaskId();

View file

@ -88,5 +88,9 @@ std::vector<bool> LocalModeObjectStore::Wait(const std::vector<ObjectID> &ids,
}
return result;
}
void LocalModeObjectStore::AddLocalReference(const std::string &id) { return; }
void LocalModeObjectStore::RemoveLocalReference(const std::string &id) { return; }
} // namespace api
} // namespace ray

View file

@ -17,6 +17,10 @@ class LocalModeObjectStore : public ObjectStore {
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms);
void AddLocalReference(const std::string &id);
void RemoveLocalReference(const std::string &id);
private:
void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id);

View file

@ -83,5 +83,19 @@ std::vector<bool> NativeObjectStore::Wait(const std::vector<ObjectID> &ids,
}
return results;
}
void NativeObjectStore::AddLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.AddLocalReference(ObjectID::FromBinary(id));
}
}
void NativeObjectStore::RemoveLocalReference(const std::string &id) {
if (CoreWorkerProcess::IsInitialized()) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
core_worker.RemoveLocalReference(ObjectID::FromBinary(id));
}
}
} // namespace api
} // namespace ray

View file

@ -17,6 +17,10 @@ class NativeObjectStore : public ObjectStore {
std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms);
void AddLocalReference(const std::string &id);
void RemoveLocalReference(const std::string &id);
private:
void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id);

View file

@ -56,6 +56,19 @@ class ObjectStore {
virtual std::vector<bool> Wait(const std::vector<ObjectID> &ids, int num_objects,
int timeout_ms) = 0;
/// Increase the reference count for this object ID.
/// Increase the local reference count for this object ID. Should be called
/// by the language frontend when a new reference is created.
///
/// \param[in] id The binary string ID to increase the reference count for.
virtual void AddLocalReference(const std::string &id) = 0;
/// Decrease the reference count for this object ID. Should be called
/// by the language frontend when a reference is destroyed.
///
/// \param[in] id The binary string ID to decrease the reference count for.
virtual void RemoveLocalReference(const std::string &id) = 0;
private:
virtual void PutRaw(std::shared_ptr<msgpack::sbuffer> data, ObjectID *object_id) = 0;

View file

@ -148,7 +148,7 @@ void TaskExecutor::Invoke(
std::vector<msgpack::sbuffer> args_buffer;
for (size_t i = 0; i < task_spec.NumArgs(); i++) {
if (task_spec.ArgByRef(i)) {
auto arg = runtime->Get(task_spec.ArgId(i));
auto arg = runtime->Get(task_spec.ArgId(i).Binary());
args_buffer.push_back(std::move(*arg));
} else {
msgpack::sbuffer sbuf;