Use strongly typed IDs in C++. (#4185)

*  Use strongly typed IDs for C++.

* Avoid heap allocation in cython.

* Fix JNI part

* Fix rebase conflict

* Refine

* Remove type check from __init__

* Remove unused constructor declarations.
This commit is contained in:
Yuhong Guo 2019-03-07 21:43:01 +08:00 committed by GitHub
parent b0332551dd
commit b9ea821d16
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 484 additions and 334 deletions

View file

@ -72,7 +72,7 @@ cdef c_vector[CObjectID] ObjectIDsToVector(object_ids):
ObjectID object_id
c_vector[CObjectID] result
for object_id in object_ids:
result.push_back(object_id.data)
result.push_back(object_id.native())
return result
@ -87,11 +87,11 @@ def compute_put_id(TaskID task_id, int64_t put_index):
if put_index < 1 or put_index > kMaxTaskPuts:
raise ValueError("The range of 'put_index' should be [1, %d]"
% kMaxTaskPuts)
return ObjectID(ComputePutId(task_id.data, put_index).binary())
return ObjectID(ComputePutId(task_id.native(), put_index).binary())
def compute_task_id(ObjectID object_id):
return TaskID(ComputeTaskId(object_id.data).binary())
return TaskID(ComputeTaskId(object_id.native()).binary())
cdef c_bool is_simple_value(value, int *num_elements_contained):
@ -225,8 +225,8 @@ cdef class RayletClient:
# parameter.
# TODO(suquark): Should we allow unicode chars in "raylet_socket"?
self.client.reset(new CRayletClient(
raylet_socket.encode("ascii"), client_id.data, is_worker,
driver_id.data, LANGUAGE_PYTHON))
raylet_socket.encode("ascii"), client_id.native(), is_worker,
driver_id.native(), LANGUAGE_PYTHON))
def disconnect(self):
check_status(self.client.get().Disconnect())
@ -252,22 +252,23 @@ cdef class RayletClient:
TaskID current_task_id=TaskID.nil()):
cdef c_vector[CObjectID] fetch_ids = ObjectIDsToVector(object_ids)
check_status(self.client.get().FetchOrReconstruct(
fetch_ids, fetch_only, current_task_id.data))
fetch_ids, fetch_only, current_task_id.native()))
def notify_unblocked(self, TaskID current_task_id):
check_status(self.client.get().NotifyUnblocked(current_task_id.data))
check_status(self.client.get().NotifyUnblocked(current_task_id.native()))
def wait(self, object_ids, int num_returns, int64_t timeout_milliseconds,
c_bool wait_local, TaskID current_task_id):
cdef:
WaitResultPair result
c_vector[CObjectID] wait_ids
CTaskID c_task_id = current_task_id.native()
wait_ids = ObjectIDsToVector(object_ids)
with nogil:
check_status(self.client.get().Wait(wait_ids, num_returns,
timeout_milliseconds,
wait_local,
current_task_id.data, &result))
c_task_id, &result))
return (VectorToObjectIDs(result.first),
VectorToObjectIDs(result.second))
@ -291,9 +292,9 @@ cdef class RayletClient:
postincrement(iterator)
return resources_dict
def push_error(self, DriverID job_id, error_type, error_message,
def push_error(self, DriverID driver_id, error_type, error_message,
double timestamp):
check_status(self.client.get().PushError(job_id.data,
check_status(self.client.get().PushError(driver_id.native(),
error_type.encode("ascii"),
error_message.encode("ascii"),
timestamp))
@ -354,7 +355,7 @@ cdef class RayletClient:
def prepare_actor_checkpoint(self, ActorID actor_id):
cdef CActorCheckpointID checkpoint_id
cdef CActorID c_actor_id = actor_id.data
cdef CActorID c_actor_id = actor_id.native()
# PrepareActorCheckpoint will wait for raylet's reply, release
# the GIL so other Python threads can run.
with nogil:
@ -365,7 +366,7 @@ cdef class RayletClient:
def notify_actor_resumed_from_checkpoint(self, ActorID actor_id,
ActorCheckpointID checkpoint_id):
check_status(self.client.get().NotifyActorResumedFromCheckpoint(
actor_id.data, checkpoint_id.data))
actor_id.native(), checkpoint_id.native()))
@property
def language(self):

View file

@ -62,7 +62,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
int num_returns, int64_t timeout_milliseconds,
c_bool wait_local, const CTaskID &current_task_id,
WaitResultPair *result)
CRayStatus PushError(const CDriverID &job_id, const c_string &type,
CRayStatus PushError(const CDriverID &driver_id, const c_string &type,
const c_string &error_message, double timestamp)
CRayStatus PushProfileEvents(
const GCSProfileTableDataT &profile_events)

View file

@ -54,7 +54,7 @@ cdef class Task:
for arg in arguments:
if isinstance(arg, ObjectID):
references = c_vector[CObjectID]()
references.push_back((<ObjectID>arg).data)
references.push_back((<ObjectID>arg).native())
task_args.push_back(
static_pointer_cast[CTaskArgument,
CTaskArgumentByReference](
@ -71,23 +71,21 @@ cdef class Task:
for new_actor_handle in new_actor_handles:
task_new_actor_handles.push_back(
(<ActorHandleID?>new_actor_handle).data)
(<ActorHandleID?>new_actor_handle).native())
self.task_spec.reset(new CTaskSpecification(
CUniqueID(driver_id.data), parent_task_id.data, parent_counter,
actor_creation_id.data, actor_creation_dummy_object_id.data,
max_actor_reconstructions, CUniqueID(actor_id.data),
CUniqueID(actor_handle_id.data), actor_counter,
task_new_actor_handles, task_args, num_returns,
required_resources, required_placement_resources,
LANGUAGE_PYTHON, c_function_descriptor))
driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(),
actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns,
required_resources, required_placement_resources, LANGUAGE_PYTHON,
c_function_descriptor))
# Set the task's execution dependencies.
self.execution_dependencies.reset(new c_vector[CObjectID]())
if execution_arguments is not None:
for execution_arg in execution_arguments:
self.execution_dependencies.get().push_back(
(<ObjectID?>execution_arg).data)
(<ObjectID?>execution_arg).native())
@staticmethod
cdef make(unique_ptr[CTaskSpecification]& task_spec):

View file

@ -5,13 +5,14 @@ from libc.stdint cimport uint8_t
cdef extern from "ray/id.h" namespace "ray" nogil:
cdef cppclass CUniqueID "ray::UniqueID":
CUniqueID()
CUniqueID(const c_string &binary)
CUniqueID(const CUniqueID &from_id)
@staticmethod
CUniqueID from_random()
@staticmethod
CUniqueID from_binary(const c_string & binary)
CUniqueID from_binary(const c_string &binary)
@staticmethod
const CUniqueID nil()
@ -26,14 +27,73 @@ cdef extern from "ray/id.h" namespace "ray" nogil:
c_string binary() const
c_string hex() const
ctypedef CUniqueID CActorCheckpointID
ctypedef CUniqueID CActorClassID
ctypedef CUniqueID CActorHandleID
ctypedef CUniqueID CActorID
ctypedef CUniqueID CClientID
ctypedef CUniqueID CConfigID
ctypedef CUniqueID CDriverID
ctypedef CUniqueID CFunctionID
ctypedef CUniqueID CObjectID
ctypedef CUniqueID CTaskID
ctypedef CUniqueID CWorkerID
cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID):
@staticmethod
CActorCheckpointID from_binary(const c_string &binary)
cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID):
@staticmethod
CActorClassID from_binary(const c_string &binary)
cdef cppclass CActorID "ray::ActorID"(CUniqueID):
@staticmethod
CActorID from_binary(const c_string &binary)
cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID):
@staticmethod
CActorHandleID from_binary(const c_string &binary)
cdef cppclass CClientID "ray::ClientID"(CUniqueID):
@staticmethod
CClientID from_binary(const c_string &binary)
cdef cppclass CConfigID "ray::ConfigID"(CUniqueID):
@staticmethod
CConfigID from_binary(const c_string &binary)
cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID):
@staticmethod
CFunctionID from_binary(const c_string &binary)
cdef cppclass CDriverID "ray::DriverID"(CUniqueID):
@staticmethod
CDriverID from_binary(const c_string &binary)
cdef cppclass CJobID "ray::JobID"(CUniqueID):
@staticmethod
CJobID from_binary(const c_string &binary)
cdef cppclass CTaskID "ray::TaskID"(CUniqueID):
@staticmethod
CTaskID from_binary(const c_string &binary)
cdef cppclass CObjectID" ray::ObjectID"(CUniqueID):
@staticmethod
CObjectID from_binary(const c_string &binary)
cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID):
@staticmethod
CWorkerID from_binary(const c_string &binary)

View file

@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport (
CConfigID,
CDriverID,
CFunctionID,
CJobID,
CObjectID,
CTaskID,
CUniqueID,
@ -45,11 +46,8 @@ cdef class UniqueID:
cdef CUniqueID data
def __init__(self, id):
if not id:
self.data = CUniqueID()
else:
check_id(id)
self.data = CUniqueID.from_binary(id)
check_id(id)
self.data = CUniqueID.from_binary(id)
@classmethod
def from_binary(cls, id_bytes):
@ -59,7 +57,7 @@ cdef class UniqueID:
@classmethod
def nil(cls):
return cls(b"")
return cls(CUniqueID.nil().binary())
def __hash__(self):
return self.data.hash()
@ -106,40 +104,93 @@ cdef class UniqueID:
cdef class ObjectID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CObjectID.from_binary(<c_string>id)
cdef CObjectID native(self):
return <CObjectID>self.data
cdef class TaskID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CTaskID.from_binary(<c_string>id)
cdef CTaskID native(self):
return <CTaskID>self.data
cdef class ClientID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CClientID.from_binary(<c_string>id)
cdef CClientID native(self):
return <CClientID>self.data
cdef class DriverID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CDriverID.from_binary(<c_string>id)
cdef CDriverID native(self):
return <CDriverID>self.data
cdef class ActorID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CActorID.from_binary(<c_string>id)
cdef CActorID native(self):
return <CActorID>self.data
cdef class ActorHandleID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CActorHandleID.from_binary(<c_string>id)
cdef CActorHandleID native(self):
return <CActorHandleID>self.data
cdef class ActorCheckpointID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CActorCheckpointID.from_binary(<c_string>id)
cdef CActorCheckpointID native(self):
return <CActorCheckpointID>self.data
cdef class FunctionID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CFunctionID.from_binary(<c_string>id)
cdef CFunctionID native(self):
return <CFunctionID>self.data
cdef class ActorClassID(UniqueID):
pass
def __init__(self, id):
check_id(id)
self.data = CActorClassID.from_binary(<c_string>id)
cdef CActorClassID native(self):
return <CActorClassID>self.data
_ID_TYPES = [
ActorCheckpointID,

View file

@ -2,74 +2,6 @@
#include "ray/util/logging.h"
flatbuffers::Offset<flatbuffers::String> to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
ray::ObjectID object_id) {
return fbb.CreateString(reinterpret_cast<const char *>(object_id.data()),
sizeof(ray::ObjectID));
}
ray::ObjectID from_flatbuf(const flatbuffers::String &string) {
ray::ObjectID object_id;
RAY_CHECK(string.size() == sizeof(ray::ObjectID));
memcpy(object_id.mutable_data(), string.data(), sizeof(ray::ObjectID));
return object_id;
}
const std::vector<ray::ObjectID> from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &vector) {
std::vector<ray::ObjectID> object_ids;
for (int64_t i = 0; i < vector.Length(); i++) {
object_ids.push_back(from_flatbuf(*vector.Get(i)));
}
return object_ids;
}
const std::vector<ray::ObjectID> object_ids_from_flatbuf(
const flatbuffers::String &string) {
const auto &object_ids = string_from_flatbuf(string);
std::vector<ray::ObjectID> ret;
RAY_CHECK(object_ids.size() % kUniqueIDSize == 0);
auto count = object_ids.size() / kUniqueIDSize;
for (size_t i = 0; i < count; ++i) {
auto pos = static_cast<size_t>(kUniqueIDSize * i);
const auto &id = object_ids.substr(pos, kUniqueIDSize);
ret.push_back(ray::ObjectID::from_binary(id));
}
return ret;
}
flatbuffers::Offset<flatbuffers::String> object_ids_to_flatbuf(
flatbuffers::FlatBufferBuilder &fbb, const std::vector<ray::ObjectID> &object_ids) {
std::string result;
for (const auto &id : object_ids) {
result += id.binary();
}
return fbb.CreateString(result);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ray::ObjectID object_ids[],
int64_t num_objects) {
std::vector<flatbuffers::Offset<flatbuffers::String>> results;
for (int64_t i = 0; i < num_objects; i++) {
results.push_back(to_flatbuf(fbb, object_ids[i]));
}
return fbb.CreateVector(results);
}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::vector<ray::ObjectID> &object_ids) {
std::vector<flatbuffers::Offset<flatbuffers::String>> results;
for (auto object_id : object_ids) {
results.push_back(to_flatbuf(fbb, object_id));
}
return fbb.CreateVector(results);
}
std::string string_from_flatbuf(const flatbuffers::String &string) {
return std::string(string.data(), string.size());
}

View file

@ -6,63 +6,68 @@
#include <unordered_map>
#include "ray/id.h"
#include "ray/util/logging.h"
/// Convert an object ID to a flatbuffer string.
/// Convert an unique ID to a flatbuffer string.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param object_id The object ID to be converted.
/// @return The flatbuffer string contining the object ID.
/// @param id The ID to be converted.
/// @return The flatbuffer string containing the ID.
template <typename ID>
flatbuffers::Offset<flatbuffers::String> to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
ray::ObjectID object_id);
ID id);
/// Convert a flatbuffer string to an object ID.
/// Convert a flatbuffer string to an unique ID.
///
/// @param string The flatbuffer string.
/// @return The object ID.
ray::ObjectID from_flatbuf(const flatbuffers::String &string);
/// @return The ID.
template <typename ID>
ID from_flatbuf(const flatbuffers::String &string);
/// Convert a flatbuffer vector of strings to a vector of object IDs.
/// Convert a flatbuffer vector of strings to a vector of unique IDs.
///
/// @param vector The flatbuffer vector.
/// @return The vector of object IDs.
const std::vector<ray::ObjectID> from_flatbuf(
/// @return The vector of IDs.
template <typename ID>
const std::vector<ID> from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &vector);
/// Convert a flatbuffer of string that concatenated
/// object IDs to a vector of object IDs.
/// unique IDs to a vector of unique IDs.
///
/// @param vector The flatbuffer vector.
/// @return The vector of object IDs.
const std::vector<ray::ObjectID> object_ids_from_flatbuf(
const flatbuffers::String &string);
/// @return The vector of IDs.
template <typename ID>
const std::vector<ID> ids_from_flatbuf(const flatbuffers::String &string);
/// Convert a vector of object IDs to a flatbuffer string.
/// Convert a vector of unique IDs to a flatbuffer string.
/// The IDs are concatenated to a string with binary.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param object_ids The vector of object IDs.
/// @param ids The vector of IDs.
/// @return Flatbuffer string of concatenated IDs.
flatbuffers::Offset<flatbuffers::String> object_ids_to_flatbuf(
flatbuffers::FlatBufferBuilder &fbb, const std::vector<ray::ObjectID> &object_ids);
template <typename ID>
flatbuffers::Offset<flatbuffers::String> ids_to_flatbuf(
flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids);
/// Convert an array of object IDs to a flatbuffer vector of strings.
/// Convert an array of unique IDs to a flatbuffer vector of strings.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param object_ids Array of object IDs.
/// @param num_objects Number of elements in the array.
/// @param ids Array of unique IDs.
/// @param num_ids Number of elements in the array.
/// @return Flatbuffer vector of strings.
template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ray::ObjectID object_ids[],
int64_t num_objects);
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID ids[], int64_t num_ids);
/// Convert a vector of object IDs to a flatbuffer vector of strings.
/// Convert a vector of unique IDs to a flatbuffer vector of strings.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param object_ids Vector of object IDs.
/// @param ids Vector of IDs.
/// @return Flatbuffer vector of strings.
template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::vector<ray::ObjectID> &object_ids);
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids);
/// Convert a flatbuffer string to a std::string.
///
@ -95,4 +100,76 @@ std::vector<std::string> string_vec_from_flatbuf(
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
const std::vector<std::string> &string_vector);
template <typename ID>
flatbuffers::Offset<flatbuffers::String> to_flatbuf(flatbuffers::FlatBufferBuilder &fbb,
ID id) {
return fbb.CreateString(reinterpret_cast<const char *>(id.data()), sizeof(ID));
}
template <typename ID>
ID from_flatbuf(const flatbuffers::String &string) {
ID id;
RAY_CHECK(string.size() == sizeof(ID));
memcpy(id.mutable_data(), string.data(), sizeof(ID));
return id;
}
template <typename ID>
const std::vector<ID> from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &vector) {
std::vector<ID> ids;
for (int64_t i = 0; i < vector.Length(); i++) {
ids.push_back(from_flatbuf<ID>(*vector.Get(i)));
}
return ids;
}
template <typename ID>
const std::vector<ID> ids_from_flatbuf(const flatbuffers::String &string) {
const auto &ids = string_from_flatbuf(string);
std::vector<ID> ret;
RAY_CHECK(ids.size() % kUniqueIDSize == 0);
auto count = ids.size() / kUniqueIDSize;
for (size_t i = 0; i < count; ++i) {
auto pos = static_cast<size_t>(kUniqueIDSize * i);
const auto &id = ids.substr(pos, kUniqueIDSize);
ret.push_back(ID::from_binary(id));
}
return ret;
}
template <typename ID>
flatbuffers::Offset<flatbuffers::String> ids_to_flatbuf(
flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids) {
std::string result;
for (const auto &id : ids) {
result += id.binary();
}
return fbb.CreateString(result);
}
template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID ids[], int64_t num_ids) {
std::vector<flatbuffers::Offset<flatbuffers::String>> results;
for (int64_t i = 0; i < num_ids; i++) {
results.push_back(to_flatbuf(fbb, ids[i]));
}
return fbb.CreateVector(results);
}
template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids) {
std::vector<flatbuffers::Offset<flatbuffers::String>> results;
for (auto id : ids) {
results.push_back(to_flatbuf(fbb, id));
}
return fbb.CreateVector(results);
}
#endif

View file

@ -814,7 +814,7 @@ void TestClientTableConnect(const JobID &job_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
[](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, true);
test->Stop();
});
@ -839,14 +839,14 @@ void TestClientTableDisconnect(const JobID &job_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
[](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, /*is_insertion=*/true);
// Disconnect from the client table. We should receive a notification
// for the removal of our own entry.
RAY_CHECK_OK(client->client_table().Disconnect());
});
client->client_table().RegisterClientRemovedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
[](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, /*is_insertion=*/false);
test->Stop();
});
@ -870,11 +870,11 @@ void TestClientTableImmediateDisconnect(const JobID &job_id,
// Register callbacks for when a client gets added and removed. The latter
// event will stop the event loop.
client->client_table().RegisterClientAddedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
[](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, true);
});
client->client_table().RegisterClientRemovedCallback(
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
[](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientTableNotification(client, id, data, false);
test->Stop();
});

View file

@ -91,7 +91,7 @@ Status Log<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &
std::vector<DataT> results;
if (!data.empty()) {
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
RAY_CHECK(from_flatbuf(*root->id()) == id);
RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
for (size_t i = 0; i < root->entries()->size(); i++) {
DataT result;
auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
@ -128,7 +128,7 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
ID id;
if (root->id()->size() > 0) {
id = from_flatbuf(*root->id());
id = from_flatbuf<ID>(*root->id());
}
std::vector<DataT> results;
for (size_t i = 0; i < root->entries()->size(); i++) {
@ -274,18 +274,18 @@ std::string Table<ID, Data>::DebugString() const {
return result.str();
}
Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type,
Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type,
const std::string &error_message, double timestamp) {
auto data = std::make_shared<ErrorTableDataT>();
data->job_id = job_id.binary();
data->job_id = driver_id.binary();
data->type = type;
data->error_message = error_message;
data->timestamp = timestamp;
return Append(job_id, job_id, data, /*done_callback=*/nullptr);
return Append(JobID(driver_id), driver_id, data, /*done_callback=*/nullptr);
}
std::string ErrorTable::DebugString() const {
return Log<JobID, ErrorTableData>::DebugString();
return Log<DriverID, ErrorTableData>::DebugString();
}
Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) {
@ -302,11 +302,11 @@ std::string ProfileTable::DebugString() const {
return Log<UniqueID, ProfileTableData>::DebugString();
}
Status DriverTable::AppendDriverData(const JobID &driver_id, bool is_dead) {
Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) {
auto data = std::make_shared<DriverTableDataT>();
data->driver_id = driver_id.binary();
data->is_dead = is_dead;
return Append(driver_id, driver_id, data, /*done_callback=*/nullptr);
return Append(JobID(driver_id), driver_id, data, /*done_callback=*/nullptr);
}
void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
@ -492,7 +492,7 @@ Status ClientTable::Lookup(const Callback &lookup) {
std::string ClientTable::DebugString() const {
std::stringstream result;
result << Log<UniqueID, ClientTableData>::DebugString();
result << Log<ClientID, ClientTableData>::DebugString();
result << ", cache size: " << client_cache_.size()
<< ", num removed: " << removed_clients_.size();
return result.str();
@ -500,7 +500,7 @@ std::string ClientTable::DebugString() const {
Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id,
const ActorID &actor_id,
const UniqueID &checkpoint_id) {
const ActorCheckpointID &checkpoint_id) {
auto lookup_callback = [this, checkpoint_id, job_id, actor_id](
ray::gcs::AsyncGcsClient *client, const UniqueID &id,
const ActorCheckpointIdDataT &data) {
@ -512,7 +512,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id,
while (copy->timestamps.size() > num_to_keep) {
// Delete the checkpoint from actor checkpoint table.
const auto &checkpoint_id =
UniqueID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize));
ActorCheckpointID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize));
RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor "
<< actor_id;
copy->timestamps.erase(copy->timestamps.begin());
@ -542,9 +542,9 @@ template class Log<TaskID, TaskReconstructionData>;
template class Table<TaskID, TaskLeaseData>;
template class Table<ClientID, HeartbeatTableData>;
template class Table<ClientID, HeartbeatBatchTableData>;
template class Log<JobID, ErrorTableData>;
template class Log<UniqueID, ClientTableData>;
template class Log<JobID, DriverTableData>;
template class Log<DriverID, ErrorTableData>;
template class Log<ClientID, ClientTableData>;
template class Log<DriverID, DriverTableData>;
template class Log<UniqueID, ProfileTableData>;
template class Table<ActorCheckpointID, ActorCheckpointData>;
template class Table<ActorID, ActorCheckpointIdData>;

View file

@ -382,7 +382,7 @@ class HeartbeatBatchTable : public Table<ClientID, HeartbeatBatchTableData> {
virtual ~HeartbeatBatchTable() {}
};
class DriverTable : public Log<JobID, DriverTableData> {
class DriverTable : public Log<DriverID, DriverTableData> {
public:
DriverTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
@ -398,7 +398,7 @@ class DriverTable : public Log<JobID, DriverTableData> {
/// \param driver_id The driver id.
/// \param is_dead Whether the driver is dead.
/// \return The return status.
Status AppendDriverData(const JobID &driver_id, bool is_dead);
Status AppendDriverData(const DriverID &driver_id, bool is_dead);
};
class FunctionTable : public Table<ObjectID, FunctionTableData> {
@ -488,7 +488,7 @@ class ActorCheckpointIdTable : public Table<ActorID, ActorCheckpointIdData> {
/// \param checkpoint_id ID of the checkpoint.
/// \return Status.
Status AddCheckpointId(const JobID &job_id, const ActorID &actor_id,
const UniqueID &checkpoint_id);
const ActorCheckpointID &checkpoint_id);
};
namespace raylet {
@ -511,7 +511,7 @@ class TaskTable : public Table<TaskID, ray::protocol::Task> {
} // namespace raylet
class ErrorTable : private Log<JobID, ErrorTableData> {
class ErrorTable : private Log<DriverID, ErrorTableData> {
public:
ErrorTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
@ -532,7 +532,7 @@ class ErrorTable : private Log<JobID, ErrorTableData> {
/// \param error_message The error message to push.
/// \param timestamp The timestamp of the error.
/// \return Status.
Status PushErrorToDriver(const JobID &job_id, const std::string &type,
Status PushErrorToDriver(const DriverID &driver_id, const std::string &type,
const std::string &error_message, double timestamp);
/// Returns debug string for class.
@ -574,7 +574,7 @@ using ConfigTable = Table<ConfigID, ConfigTableData>;
/// it should append an entry to the log indicating that it is dead. A client
/// that is marked as dead should never again be marked as alive; if it needs
/// to reconnect, it must connect with a different ClientID.
class ClientTable : private Log<UniqueID, ClientTableData> {
class ClientTable : private Log<ClientID, ClientTableData> {
public:
using ClientTableCallback = std::function<void(
AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
@ -678,7 +678,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
/// The key at which the log of client information is stored. This key must
/// be kept the same across all instances of the ClientTable, so that all
/// clients append and read from the same key.
UniqueID client_log_key_;
ClientID client_log_key_;
/// Whether this client has called Disconnect().
bool disconnected_;
/// This client's ID.

View file

@ -165,7 +165,7 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id) {
const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) {
RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts);
ObjectID return_id = task_id;
ObjectID return_id = ObjectID(task_id);
int64_t *first_bytes = reinterpret_cast<int64_t *>(&return_id);
// Zero out the lowest kObjectIdIndexSize bits of the first byte of the
// object ID.
@ -176,7 +176,9 @@ const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) {
return return_id;
}
const TaskID FinishTaskId(const TaskID &task_id) { return ComputeObjectId(task_id, 0); }
const TaskID FinishTaskId(const TaskID &task_id) {
return TaskID(ComputeObjectId(task_id, 0));
}
const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) {
RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns);
@ -190,7 +192,7 @@ const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) {
}
const TaskID ComputeTaskId(const ObjectID &object_id) {
TaskID task_id = object_id;
TaskID task_id = TaskID(object_id);
int64_t *first_bytes = reinterpret_cast<int64_t *>(&task_id);
// Zero out the lowest kObjectIdIndexSize bits of the first byte of the
// object ID.

View file

@ -30,7 +30,7 @@ class RAY_EXPORT UniqueID {
std::string hex() const;
plasma::UniqueID to_plasma_id() const;
private:
protected:
uint8_t id_[kUniqueIDSize];
};
@ -38,18 +38,24 @@ static_assert(std::is_standard_layout<UniqueID>::value, "UniqueID must be standa
std::ostream &operator<<(std::ostream &os, const UniqueID &id);
typedef UniqueID TaskID;
typedef UniqueID JobID;
typedef UniqueID ObjectID;
typedef UniqueID FunctionID;
typedef UniqueID ActorClassID;
typedef UniqueID ActorID;
typedef UniqueID ActorHandleID;
typedef UniqueID ActorCheckpointID;
typedef UniqueID WorkerID;
typedef UniqueID DriverID;
typedef UniqueID ConfigID;
typedef UniqueID ClientID;
#define DEFINE_UNIQUE_ID(type) \
class RAY_EXPORT type : public UniqueID { \
public: \
explicit type(const UniqueID &from) { \
std::memcpy(&id_, from.data(), kUniqueIDSize); \
} \
type() : UniqueID() {} \
static type from_random() { return type(UniqueID::from_random()); } \
static type from_binary(const std::string &binary) { return type(binary); } \
static type nil() { return type(UniqueID::nil()); } \
\
private: \
type(const std::string &binary) { std::memcpy(id_, binary.data(), kUniqueIDSize); } \
};
#include "id_def.h"
#undef DEFINE_UNIQUE_ID
// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we
// can make these methods of the derived classes.
@ -101,14 +107,20 @@ int64_t ComputeObjectIndex(const ObjectID &object_id);
} // namespace ray
namespace std {
template <>
struct hash<::ray::UniqueID> {
size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); }
};
template <>
struct hash<const ::ray::UniqueID> {
size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); }
};
}
#define DEFINE_UNIQUE_ID(type) \
template <> \
struct hash<::ray::type> { \
size_t operator()(const ::ray::type &id) const { return id.hash(); } \
}; \
template <> \
struct hash<const ::ray::type> { \
size_t operator()(const ::ray::type &id) const { return id.hash(); } \
};
DEFINE_UNIQUE_ID(UniqueID);
#include "id_def.h"
#undef DEFINE_UNIQUE_ID
} // namespace std
#endif // RAY_ID_H_

18
src/ray/id_def.h Normal file
View file

@ -0,0 +1,18 @@
// This header file is used to avoid code duplication.
// It can be included multiple times in id.h, and each inclusion
// could use a different definition of the DEFINE_UNIQUE_ID macro.
// Macro definition format: DEFINE_UNIQUE_ID(id_type).
// NOTE: This file should NOT be included in any file other than id.h.
DEFINE_UNIQUE_ID(TaskID);
DEFINE_UNIQUE_ID(JobID);
DEFINE_UNIQUE_ID(ObjectID);
DEFINE_UNIQUE_ID(FunctionID);
DEFINE_UNIQUE_ID(ActorClassID);
DEFINE_UNIQUE_ID(ActorID);
DEFINE_UNIQUE_ID(ActorHandleID);
DEFINE_UNIQUE_ID(ActorCheckpointID);
DEFINE_UNIQUE_ID(WorkerID);
DEFINE_UNIQUE_ID(DriverID);
DEFINE_UNIQUE_ID(ConfigID);
DEFINE_UNIQUE_ID(ClientID);

View file

@ -78,7 +78,7 @@ void ObjectDirectory::RegisterBackend() {
}
};
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(),
JobID::nil(), gcs_client_->client_table().GetLocalClientId(),
object_notification_callback, nullptr));
}

View file

@ -767,7 +767,7 @@ void ObjectManager::ConnectClient(std::shared_ptr<TcpClientConnection> &conn,
// TODO: trash connection on failure.
auto info =
flatbuffers::GetRoot<object_manager_protocol::ConnectClientMessage>(message);
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
ClientID client_id = ClientID::from_binary(info->client_id()->str());
bool is_transfer = info->is_transfer();
conn->SetClientID(client_id);
if (is_transfer) {
@ -885,7 +885,7 @@ void ObjectManager::ReceiveFreeRequest(std::shared_ptr<TcpClientConnection> &con
const uint8_t *message) {
auto free_request =
flatbuffers::GetRoot<object_manager_protocol::FreeRequestMessage>(message);
std::vector<ObjectID> object_ids = from_flatbuf(*free_request->object_ids());
std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*free_request->object_ids());
// This RPC should come from another Object Manager.
// Keep this request local.
bool local_only = true;

View file

@ -58,7 +58,7 @@ void ObjectStoreNotificationManager::ProcessStoreNotification(
const auto &object_info =
flatbuffers::GetRoot<object_manager::protocol::ObjectInfo>(notification_.data());
const auto &object_id = from_flatbuf(*object_info->object_id());
const auto &object_id = from_flatbuf<ObjectID>(*object_info->object_id());
if (object_info->is_deletion()) {
ProcessStoreRemove(object_id);
} else {

View file

@ -196,7 +196,7 @@ table WaitReply {
// This struct is the same as ErrorTableData.
table PushErrorRequest {
// The ID of the job that the error is for.
job_id: string;
driver_id: string;
// The type of the error.
type: string;
// The error message.

View file

@ -6,31 +6,30 @@
#include "ray/raylet/raylet_client.h"
#include "ray/util/logging.h"
#ifdef __cplusplus
extern "C" {
#endif
template <typename ID>
class UniqueIdFromJByteArray {
private:
JNIEnv *_env;
jbyteArray _bytes;
public:
UniqueID *PID;
const ID &GetId() const { return *id_pointer_; }
UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) {
_env = env;
_bytes = wid;
jbyte *b = reinterpret_cast<jbyte *>(_env->GetByteArrayElements(_bytes, nullptr));
PID = reinterpret_cast<UniqueID *>(b);
UniqueIdFromJByteArray(JNIEnv *env, jbyteArray bytes) : env_(env), bytes_(bytes) {
jbyte *b = reinterpret_cast<jbyte *>(env_->GetByteArrayElements(bytes_, nullptr));
id_pointer_ = reinterpret_cast<ID *>(b);
}
~UniqueIdFromJByteArray() {
_env->ReleaseByteArrayElements(_bytes, reinterpret_cast<jbyte *>(PID), 0);
env_->ReleaseByteArrayElements(bytes_, reinterpret_cast<jbyte *>(id_pointer_), 0);
}
private:
JNIEnv *env_;
jbyteArray bytes_;
ID *id_pointer_;
};
#ifdef __cplusplus
extern "C" {
#endif
inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
if (!status.ok()) {
jclass exception_class = env->FindClass("org/ray/api/exception/RayException");
@ -49,11 +48,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
jbyteArray driverId) {
UniqueIdFromJByteArray worker_id(env, workerId);
UniqueIdFromJByteArray driver_id(env, driverId);
UniqueIdFromJByteArray<ClientID> worker_id(env, workerId);
UniqueIdFromJByteArray<DriverID> driver_id(env, driverId);
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
auto raylet_client = new RayletClient(nativeString, *worker_id.PID, isWorker,
*driver_id.PID, Language::JAVA);
auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker,
driver_id.GetId(), Language::JAVA);
env->ReleaseStringUTFChars(sockName, nativeString);
return reinterpret_cast<jlong>(raylet_client);
}
@ -70,8 +69,8 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit
std::vector<ObjectID> execution_dependencies;
if (cursorId != nullptr) {
UniqueIdFromJByteArray cursor_id(env, cursorId);
execution_dependencies.push_back(*cursor_id.PID);
UniqueIdFromJByteArray<ObjectID> cursor_id(env, cursorId);
execution_dependencies.push_back(cursor_id.GetId());
}
auto data = reinterpret_cast<char *>(env->GetDirectBufferAddress(taskBuff)) + pos;
@ -143,14 +142,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
UniqueIdFromJByteArray object_id(env, object_id_bytes);
object_ids.push_back(*object_id.PID);
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
object_ids.push_back(object_id.GetId());
env->DeleteLocalRef(object_id_bytes);
}
UniqueIdFromJByteArray current_task_id(env, currentTaskId);
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto status =
raylet_client->FetchOrReconstruct(object_ids, fetchOnly, *current_task_id.PID);
raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId());
ThrowRayExceptionIfNotOK(env, status);
}
@ -161,9 +160,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(
*/
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked(
JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) {
UniqueIdFromJByteArray current_task_id(env, currentTaskId);
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
auto raylet_client = reinterpret_cast<RayletClient *>(client);
auto status = raylet_client->NotifyUnblocked(*current_task_id.PID);
auto status = raylet_client->NotifyUnblocked(current_task_id.GetId());
ThrowRayExceptionIfNotOK(env, status);
}
@ -181,19 +180,19 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
UniqueIdFromJByteArray object_id(env, object_id_bytes);
object_ids.push_back(*object_id.PID);
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
object_ids.push_back(object_id.GetId());
env->DeleteLocalRef(object_id_bytes);
}
UniqueIdFromJByteArray current_task_id(env, currentTaskId);
UniqueIdFromJByteArray<TaskID> current_task_id(env, currentTaskId);
auto raylet_client = reinterpret_cast<RayletClient *>(client);
// Invoke wait.
WaitResultPair result;
auto status =
raylet_client->Wait(object_ids, numReturns, timeoutMillis,
static_cast<bool>(isWaitLocal), *current_task_id.PID, &result);
auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis,
static_cast<bool>(isWaitLocal),
current_task_id.GetId(), &result);
if (ThrowRayExceptionIfNotOK(env, status)) {
return nullptr;
}
@ -231,15 +230,12 @@ JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(
JNIEnv *env, jclass, jbyteArray driverId, jbyteArray parentTaskId,
jint parent_task_counter) {
UniqueIdFromJByteArray object_id1(env, driverId);
ray::DriverID driver_id = *object_id1.PID;
UniqueIdFromJByteArray<DriverID> driver_id(env, driverId);
UniqueIdFromJByteArray<TaskID> parent_task_id(env, parentTaskId);
UniqueIdFromJByteArray object_id2(env, parentTaskId);
ray::TaskID parent_task_id = *object_id2.PID;
ray::TaskID task_id =
ray::GenerateTaskId(driver_id, parent_task_id, parent_task_counter);
jbyteArray result = env->NewByteArray(sizeof(ray::TaskID));
TaskID task_id =
ray::GenerateTaskId(driver_id.GetId(), parent_task_id.GetId(), parent_task_counter);
jbyteArray result = env->NewByteArray(sizeof(TaskID));
if (nullptr == result) {
return nullptr;
}
@ -261,8 +257,8 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
for (int i = 0; i < len; i++) {
jbyteArray object_id_bytes =
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
UniqueIdFromJByteArray object_id(env, object_id_bytes);
object_ids.push_back(*object_id.PID);
UniqueIdFromJByteArray<ObjectID> object_id(env, object_id_bytes);
object_ids.push_back(object_id.GetId());
env->DeleteLocalRef(object_id_bytes);
}
auto raylet_client = reinterpret_cast<RayletClient *>(client);
@ -280,9 +276,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env
jlong client,
jbyteArray actorId) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
UniqueIdFromJByteArray actor_id(env, actorId);
UniqueIdFromJByteArray<ActorID> actor_id(env, actorId);
ActorCheckpointID checkpoint_id;
auto status = raylet_client->PrepareActorCheckpoint(*actor_id.PID, checkpoint_id);
auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), checkpoint_id);
if (ThrowRayExceptionIfNotOK(env, status)) {
return nullptr;
}
@ -301,10 +297,10 @@ JNIEXPORT void JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
auto raylet_client = reinterpret_cast<RayletClient *>(client);
UniqueIdFromJByteArray actor_id(env, actorId);
UniqueIdFromJByteArray checkpoint_id(env, checkpointId);
auto status =
raylet_client->NotifyActorResumedFromCheckpoint(*actor_id.PID, *checkpoint_id.PID);
UniqueIdFromJByteArray<ActorID> actor_id(env, actorId);
UniqueIdFromJByteArray<ActorCheckpointID> checkpoint_id(env, checkpointId);
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(),
checkpoint_id.GetId());
ThrowRayExceptionIfNotOK(env, status);
}

View file

@ -358,8 +358,9 @@ void LineageCache::FlushTask(const TaskID &task_id) {
auto task_data = std::make_shared<protocol::TaskT>();
auto root = flatbuffers::GetRoot<protocol::Task>(fbb.GetBufferPointer());
root->UnPackTo(task_data.get());
RAY_CHECK_OK(task_storage_.Add(task->TaskData().GetTaskSpecification().DriverId(),
task_id, task_data, task_callback));
RAY_CHECK_OK(
task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().DriverId()),
task_id, task_data, task_callback));
// We successfully wrote the task, so mark it as committing.
// TODO(swang): Use a batched interface and write with all object entries.

View file

@ -113,9 +113,9 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
}
std::vector<std::string> function_descriptor(3);
auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0,
task_arguments, num_returns, required_resources,
Language::PYTHON, function_descriptor);
auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments,
num_returns, required_resources, Language::PYTHON,
function_descriptor);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);

View file

@ -35,7 +35,7 @@ void Monitor::Start() {
HandleHeartbeat(id, heartbeat_data);
};
RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), heartbeat_callback, nullptr, nullptr));
JobID::nil(), ClientID::nil(), heartbeat_callback, nullptr, nullptr));
Tick();
}
@ -69,7 +69,7 @@ void Monitor::Tick() {
<< " has missed too many heartbeats from it.";
// We use the nil JobID to broadcast the message to all drivers.
RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver(
JobID::nil(), type, error_message.str(), current_time_ms()));
DriverID::nil(), type, error_message.str(), current_time_ms()));
}
};
RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback));
@ -88,7 +88,7 @@ void Monitor::Tick() {
batch->batch.push_back(std::unique_ptr<HeartbeatTableDataT>(
new HeartbeatTableDataT(heartbeat.second)));
}
RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(UniqueID::nil(), UniqueID::nil(),
RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(JobID::nil(), ClientID::nil(),
batch, nullptr));
heartbeat_buffer_.clear();
}

View file

@ -145,7 +145,7 @@ ray::Status NodeManager::RegisterGcs() {
};
RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), actor_notification_callback, nullptr));
JobID::nil(), ClientID::nil(), actor_notification_callback, nullptr));
// Register a callback on the client table for new clients.
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id,
@ -167,17 +167,17 @@ ray::Status NodeManager::RegisterGcs() {
HeartbeatBatchAdded(heartbeat_batch);
};
RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), heartbeat_batch_added,
JobID::nil(), ClientID::nil(), heartbeat_batch_added,
/*subscribe_callback=*/nullptr,
/*done_callback=*/nullptr));
// Subscribe to driver table updates.
const auto driver_table_handler = [this](
gcs::AsyncGcsClient *client, const ClientID &client_id,
gcs::AsyncGcsClient *client, const DriverID &client_id,
const std::vector<DriverTableDataT> &driver_data) {
HandleDriverTableUpdate(client_id, driver_data);
};
RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), UniqueID::nil(),
RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), ClientID::nil(),
driver_table_handler, nullptr));
// Start sending heartbeats to the GCS.
@ -210,12 +210,12 @@ void NodeManager::KillWorker(std::shared_ptr<Worker> worker) {
}
void NodeManager::HandleDriverTableUpdate(
const ClientID &id, const std::vector<DriverTableDataT> &driver_data) {
const DriverID &id, const std::vector<DriverTableDataT> &driver_data) {
for (const auto &entry : driver_data) {
RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id)
<< " " << entry.is_dead;
if (entry.is_dead) {
auto driver_id = UniqueID::from_binary(entry.driver_id);
auto driver_id = DriverID::from_binary(entry.driver_id);
auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id);
// Kill all the workers. The actual cleanup for these workers is done
@ -270,7 +270,7 @@ void NodeManager::Heartbeat() {
}
ray::Status status = heartbeat_table.Add(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
JobID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
/*success_callback=*/nullptr);
RAY_CHECK_OK_PREPEND(status, "Heartbeat failed");
@ -351,7 +351,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
<< ". This may be since the node was recently removed.";
// We use the nil JobID to broadcast the message to all drivers.
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
JobID::nil(), type, error_message.str(), current_time_ms()));
DriverID::nil(), type, error_message.str(), current_time_ms()));
return;
}
@ -684,7 +684,7 @@ void NodeManager::ProcessClientMessage(
} break;
case protocol::MessageType::NotifyUnblocked: {
auto message = flatbuffers::GetRoot<protocol::NotifyUnblocked>(message_data);
HandleTaskUnblocked(client, from_flatbuf(*message->task_id()));
HandleTaskUnblocked(client, from_flatbuf<TaskID>(*message->task_id()));
} break;
case protocol::MessageType::WaitRequest: {
ProcessWaitRequestMessage(client, message_data);
@ -698,7 +698,7 @@ void NodeManager::ProcessClientMessage(
} break;
case protocol::MessageType::FreeObjectsInObjectStoreRequest: {
auto message = flatbuffers::GetRoot<protocol::FreeObjectsRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*message->object_ids());
object_manager_.FreeObjects(object_ids, message->local_only());
} break;
case protocol::MessageType::PrepareActorCheckpointRequest: {
@ -719,7 +719,7 @@ void NodeManager::ProcessClientMessage(
void NodeManager::ProcessRegisterClientRequestMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
client->SetClientID(from_flatbuf(*message->client_id()));
client->SetClientID(from_flatbuf<ClientID>(*message->client_id()));
auto worker =
std::make_shared<Worker>(message->worker_pid(), message->language(), client);
if (message->is_worker()) {
@ -731,11 +731,11 @@ void NodeManager::ProcessRegisterClientRequestMessage(
// message is actually the ID of the driver task, while client_id represents the
// real driver ID, which can associate all the tasks/actors for a given driver,
// which is set to the worker ID.
const JobID driver_task_id = from_flatbuf(*message->driver_id());
worker->AssignTaskId(driver_task_id);
worker->AssignDriverId(from_flatbuf(*message->client_id()));
const JobID driver_task_id = from_flatbuf<JobID>(*message->driver_id());
worker->AssignTaskId(TaskID(driver_task_id));
worker->AssignDriverId(from_flatbuf<DriverID>(*message->client_id()));
worker_pool_.RegisterDriver(std::move(worker));
local_queues_.AddDriverTaskId(driver_task_id);
local_queues_.AddDriverTaskId(TaskID(driver_task_id));
}
}
@ -865,14 +865,14 @@ void NodeManager::ProcessDisconnectClientMessage(
if (!intentional_disconnect) {
// Push the error to driver.
const JobID &job_id = worker->GetAssignedDriverId();
const DriverID &driver_id = worker->GetAssignedDriverId();
// TODO(rkn): Define this constant somewhere else.
std::string type = "worker_died";
std::ostringstream error_message;
error_message << "A worker died or was killed while executing task " << task_id
<< ".";
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
job_id, type, error_message.str(), current_time_ms()));
driver_id, type, error_message.str(), current_time_ms()));
}
}
@ -899,8 +899,9 @@ void NodeManager::ProcessDisconnectClientMessage(
DispatchTasks(local_queues_.GetReadyTasksWithResources());
} else if (is_driver) {
// The client is a driver.
RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientId(),
/*is_dead=*/true));
RAY_CHECK_OK(
gcs_client_->driver_table().AppendDriverData(DriverID(client->GetClientId()),
/*is_dead=*/true));
auto driver_id = worker->GetAssignedTaskId();
RAY_CHECK(!driver_id.is_nil());
local_queues_.RemoveDriverTaskId(driver_id);
@ -919,7 +920,7 @@ void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) {
// Read the task submitted by the client.
auto message = flatbuffers::GetRoot<protocol::SubmitTaskRequest>(message_data);
TaskExecutionSpecification task_execution_spec(
from_flatbuf(*message->execution_dependencies()));
from_flatbuf<ObjectID>(*message->execution_dependencies()));
TaskSpecification task_spec(*message->task_spec());
Task task(task_execution_spec, task_spec);
// Submit the task to the local scheduler. Since the task was submitted
@ -932,7 +933,7 @@ void NodeManager::ProcessFetchOrReconstructMessage(
auto message = flatbuffers::GetRoot<protocol::FetchOrReconstruct>(message_data);
std::vector<ObjectID> required_object_ids;
for (size_t i = 0; i < message->object_ids()->size(); ++i) {
ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i));
ObjectID object_id = from_flatbuf<ObjectID>(*message->object_ids()->Get(i));
if (message->fetch_only()) {
// If only a fetch is required, then do not subscribe to the
// dependencies to the task dependency manager.
@ -950,7 +951,7 @@ void NodeManager::ProcessFetchOrReconstructMessage(
}
if (!required_object_ids.empty()) {
const TaskID task_id = from_flatbuf(*message->task_id());
const TaskID task_id = from_flatbuf<TaskID>(*message->task_id());
HandleTaskBlocked(client, required_object_ids, task_id);
}
}
@ -959,7 +960,7 @@ void NodeManager::ProcessWaitRequestMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
// Read the data.
auto message = flatbuffers::GetRoot<protocol::WaitRequest>(message_data);
std::vector<ObjectID> object_ids = from_flatbuf(*message->object_ids());
std::vector<ObjectID> object_ids = from_flatbuf<ObjectID>(*message->object_ids());
int64_t wait_ms = message->timeout();
uint64_t num_required_objects = static_cast<uint64_t>(message->num_ready_objects());
bool wait_local = message->wait_local();
@ -974,7 +975,7 @@ void NodeManager::ProcessWaitRequestMessage(
}
}
const TaskID &current_task_id = from_flatbuf(*message->task_id());
const TaskID &current_task_id = from_flatbuf<TaskID>(*message->task_id());
bool client_blocked = !required_object_ids.empty();
if (client_blocked) {
HandleTaskBlocked(client, required_object_ids, current_task_id);
@ -1012,20 +1013,20 @@ void NodeManager::ProcessWaitRequestMessage(
void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::PushErrorRequest>(message_data);
JobID job_id = from_flatbuf(*message->job_id());
DriverID driver_id = from_flatbuf<DriverID>(*message->driver_id());
auto const &type = string_from_flatbuf(*message->type());
auto const &error_message = string_from_flatbuf(*message->error_message());
double timestamp = message->timestamp();
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
timestamp));
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(driver_id, type,
error_message, timestamp));
}
void NodeManager::ProcessPrepareActorCheckpointRequest(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message =
flatbuffers::GetRoot<protocol::PrepareActorCheckpointRequest>(message_data);
ActorID actor_id = from_flatbuf(*message->actor_id());
ActorID actor_id = from_flatbuf<ActorID>(*message->actor_id());
RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id;
const auto &actor_entry = actor_registry_.find(actor_id);
RAY_CHECK(actor_entry != actor_registry_.end());
@ -1037,15 +1038,15 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
const auto task_id = worker->GetAssignedTaskId();
const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING);
// Generate checkpoint id and data.
ActorCheckpointID checkpoint_id = UniqueID::from_random();
ActorCheckpointID checkpoint_id = ActorCheckpointID::from_random();
auto checkpoint_data =
actor_entry->second.GenerateCheckpointData(actor_entry->first, task);
// Write checkpoint data to GCS.
RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add(
UniqueID::nil(), checkpoint_id, checkpoint_data,
JobID::nil(), checkpoint_id, checkpoint_data,
[worker, actor_id, this](ray::gcs::AsyncGcsClient *client,
const UniqueID &checkpoint_id,
const ActorCheckpointID &checkpoint_id,
const ActorCheckpointDataT &data) {
RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor "
<< worker->GetActorId();
@ -1072,8 +1073,9 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data) {
auto message =
flatbuffers::GetRoot<protocol::NotifyActorResumedFromCheckpoint>(message_data);
ActorID actor_id = from_flatbuf(*message->actor_id());
ActorCheckpointID checkpoint_id = from_flatbuf(*message->checkpoint_id());
ActorID actor_id = from_flatbuf<ActorID>(*message->actor_id());
ActorCheckpointID checkpoint_id =
from_flatbuf<ActorCheckpointID>(*message->checkpoint_id());
RAY_LOG(DEBUG) << "Actor " << actor_id << " was resumed from checkpoint "
<< checkpoint_id;
checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id);
@ -1093,12 +1095,12 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl
switch (message_type_value) {
case protocol::MessageType::ConnectClient: {
auto message = flatbuffers::GetRoot<protocol::ConnectClient>(message_data);
auto client_id = from_flatbuf(*message->client_id());
auto client_id = from_flatbuf<ClientID>(*message->client_id());
node_manager_client.SetClientID(client_id);
} break;
case protocol::MessageType::ForwardTaskRequest: {
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
TaskID task_id = from_flatbuf(*message->task_id());
TaskID task_id = from_flatbuf<TaskID>(*message->task_id());
Lineage uncommitted_lineage(*message);
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
@ -1589,7 +1591,7 @@ bool NodeManager::AssignTask(const Task &task) {
const std::string warning_message = worker_pool_.WarningAboutSize();
if (warning_message != "") {
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
JobID::nil(), "worker_pool_large", warning_message, current_time_ms()));
DriverID::nil(), "worker_pool_large", warning_message, current_time_ms()));
}
}
// We couldn't assign this task, as no worker available.
@ -1902,7 +1904,6 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
// Use a copy of the cached task spec to re-execute the task.
const Task task = lineage_cache_.GetTaskOrDie(task_id);
ResubmitTask(task);
}));
}

View file

@ -326,7 +326,7 @@ class NodeManager {
/// \param id An unused value. TODO(rkn): Should this be removed?
/// \param driver_data Data associated with a driver table event.
/// \return Void.
void HandleDriverTableUpdate(const ClientID &id,
void HandleDriverTableUpdate(const DriverID &id,
const std::vector<DriverTableDataT> &driver_data);
/// Check if certain invariants associated with the task dependency manager

View file

@ -201,8 +201,8 @@ ray::Status RayletConnection::AtomicRequestReply(
return ReadMessage(reply_type, reply_message);
}
RayletClient::RayletClient(const std::string &raylet_socket, const UniqueID &client_id,
bool is_worker, const JobID &driver_id,
RayletClient::RayletClient(const std::string &raylet_socket, const ClientID &client_id,
bool is_worker, const DriverID &driver_id,
const Language &language)
: client_id_(client_id),
is_worker_(is_worker),
@ -323,11 +323,11 @@ ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_
return ray::Status::OK();
}
ray::Status RayletClient::PushError(const JobID &job_id, const std::string &type,
ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string &type,
const std::string &error_message, double timestamp) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreatePushErrorRequest(
fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type),
fbb, to_flatbuf(fbb, driver_id), fbb.CreateString(type),
fbb.CreateString(error_message), timestamp);
fbb.Finish(message);
@ -373,7 +373,7 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
if (!status.ok()) return status;
auto reply_message =
flatbuffers::GetRoot<ray::protocol::PrepareActorCheckpointReply>(reply.get());
checkpoint_id = ObjectID::from_binary(reply_message->checkpoint_id()->str());
checkpoint_id = ActorCheckpointID::from_binary(reply_message->checkpoint_id()->str());
return ray::Status::OK();
}

View file

@ -9,13 +9,14 @@
#include "ray/raylet/task_spec.h"
#include "ray/status.h"
using ray::ActorID;
using ray::ActorCheckpointID;
using ray::ActorID;
using ray::ClientID;
using ray::DriverID;
using ray::JobID;
using ray::ObjectID;
using ray::TaskID;
using ray::UniqueID;
using ray::ClientID;
using MessageType = ray::protocol::MessageType;
using ResourceMappingType =
@ -68,8 +69,8 @@ class RayletClient {
/// additional message will be sent to register as one.
/// \param driver_id The ID of the driver. This is non-nil if the client is a driver.
/// \return The connection information.
RayletClient(const std::string &raylet_socket, const UniqueID &client_id,
bool is_worker, const JobID &driver_id, const Language &language);
RayletClient(const std::string &raylet_socket, const ClientID &client_id,
bool is_worker, const DriverID &driver_id, const Language &language);
ray::Status Disconnect() { return conn_->Disconnect(); };
@ -130,7 +131,7 @@ class RayletClient {
/// \param The error message.
/// \param The timestamp of the error.
/// \return ray::Status.
ray::Status PushError(const JobID &job_id, const std::string &type,
ray::Status PushError(const DriverID &driver_id, const std::string &type,
const std::string &error_message, double timestamp);
/// Store some profile events in the GCS.

View file

@ -322,7 +322,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) {
task_lease_data->node_manager_id = ClientID::from_random().binary();
task_lease_data->acquired_at = current_sys_time_ms();
task_lease_data->timeout = 2 * test_period;
mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data);
mock_gcs_.Add(JobID::nil(), task_id, task_lease_data);
// Listen for an object.
reconstruction_policy_->ListenAndMaybeReconstruct(object_id);
@ -350,7 +350,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) {
task_lease_data->node_manager_id = ClientID::from_random().binary();
task_lease_data->acquired_at = current_sys_time_ms();
task_lease_data->timeout = reconstruction_timeout_ms_;
mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data);
mock_gcs_.Add(JobID::nil(), task_id, task_lease_data);
});
// Run the test for much longer than the reconstruction timeout.
Run(reconstruction_timeout_ms_ * 2);
@ -404,7 +404,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) {
task_reconstruction_data->node_manager_id = ClientID::from_random().binary();
task_reconstruction_data->num_reconstructions = 0;
RAY_CHECK_OK(
mock_gcs_.AppendAt(DriverID::nil(), task_id, task_reconstruction_data, nullptr,
mock_gcs_.AppendAt(JobID::nil(), task_id, task_reconstruction_data, nullptr,
/*failure_callback=*/
[](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
const TaskReconstructionDataT &data) { ASSERT_TRUE(false); },

View file

@ -263,7 +263,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) {
task_lease_data->node_manager_id = client_id_.hex();
task_lease_data->acquired_at = current_sys_time_ms();
task_lease_data->timeout = it->second.lease_period;
RAY_CHECK_OK(task_lease_table_.Add(DriverID::nil(), task_id, task_lease_data, nullptr));
RAY_CHECK_OK(task_lease_table_.Add(JobID::nil(), task_id, task_lease_data, nullptr));
auto period = boost::posix_time::milliseconds(it->second.lease_period / 2);
it->second.lease_timer->expires_from_now(period);

View file

@ -75,9 +75,9 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
}
std::vector<std::string> function_descriptor(3);
auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0,
task_arguments, num_returns, required_resources,
Language::PYTHON, function_descriptor);
auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments,
num_returns, required_resources, Language::PYTHON,
function_descriptor);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);

View file

@ -17,7 +17,7 @@ TaskArgumentByReference::TaskArgumentByReference(const std::vector<ObjectID> &re
flatbuffers::Offset<Arg> TaskArgumentByReference::ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const {
return CreateArg(fbb, object_ids_to_flatbuf(fbb, references_));
return CreateArg(fbb, ids_to_flatbuf(fbb, references_));
}
TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) {
@ -57,7 +57,7 @@ TaskSpecification::TaskSpecification(const std::string &string) {
}
TaskSpecification::TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources,
const Language &language, const std::vector<std::string> &function_descriptor)
@ -68,7 +68,7 @@ TaskSpecification::TaskSpecification(
function_descriptor) {}
TaskSpecification::TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
const int64_t max_actor_reconstructions, const ActorID &actor_id,
const ActorHandleID &actor_handle_id, int64_t actor_counter,
@ -100,8 +100,8 @@ TaskSpecification::TaskSpecification(
to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id),
to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions,
to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter,
object_ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments),
object_ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources),
ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments),
ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources),
map_to_flatbuf(fbb, required_placement_resources), language,
string_vec_to_flatbuf(fbb, function_descriptor));
fbb.Finish(spec);
@ -122,15 +122,15 @@ size_t TaskSpecification::size() const { return spec_.size(); }
// Task specification getter methods.
TaskID TaskSpecification::TaskId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->task_id());
return from_flatbuf<TaskID>(*message->task_id());
}
UniqueID TaskSpecification::DriverId() const {
DriverID TaskSpecification::DriverId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->driver_id());
return from_flatbuf<DriverID>(*message->driver_id());
}
TaskID TaskSpecification::ParentTaskId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->parent_task_id());
return from_flatbuf<TaskID>(*message->parent_task_id());
}
int64_t TaskSpecification::ParentCounter() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
@ -168,7 +168,7 @@ int64_t TaskSpecification::NumReturns() const {
ObjectID TaskSpecification::ReturnId(int64_t return_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return object_ids_from_flatbuf(*message->returns())[return_index];
return ids_from_flatbuf<ObjectID>(*message->returns())[return_index];
}
bool TaskSpecification::ArgByRef(int64_t arg_index) const {
@ -184,7 +184,7 @@ int TaskSpecification::ArgIdCount(int64_t arg_index) const {
ObjectID TaskSpecification::ArgId(int64_t arg_index, int64_t id_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
const auto &object_ids =
object_ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids());
ids_from_flatbuf<ObjectID>(*message->args()->Get(arg_index)->object_ids());
return object_ids[id_index];
}
@ -232,12 +232,12 @@ bool TaskSpecification::IsActorTask() const { return !ActorId().is_nil(); }
ActorID TaskSpecification::ActorCreationId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_creation_id());
return from_flatbuf<ActorID>(*message->actor_creation_id());
}
ObjectID TaskSpecification::ActorCreationDummyObjectId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_creation_dummy_object_id());
return from_flatbuf<ObjectID>(*message->actor_creation_dummy_object_id());
}
int64_t TaskSpecification::MaxActorReconstructions() const {
@ -247,12 +247,12 @@ int64_t TaskSpecification::MaxActorReconstructions() const {
ActorID TaskSpecification::ActorId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_id());
return from_flatbuf<ActorID>(*message->actor_id());
}
ActorHandleID TaskSpecification::ActorHandleId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_handle_id());
return from_flatbuf<ActorHandleID>(*message->actor_handle_id());
}
int64_t TaskSpecification::ActorCounter() const {
@ -267,7 +267,7 @@ ObjectID TaskSpecification::ActorDummyObject() const {
std::vector<ActorHandleID> TaskSpecification::NewActorHandles() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return object_ids_from_flatbuf(*message->new_actor_handles());
return ids_from_flatbuf<ActorHandleID>(*message->new_actor_handles());
}
} // namespace raylet

View file

@ -96,7 +96,7 @@ class TaskSpecification {
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
/// \param language The language of the worker that must execute the function.
TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id,
TaskSpecification(const DriverID &driver_id, const TaskID &parent_task_id,
int64_t parent_counter,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
int64_t num_returns,
@ -129,7 +129,7 @@ class TaskSpecification {
/// \param language The language of the worker that must execute the function.
/// \param function_descriptor The function descriptor.
TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
int64_t max_actor_reconstructions, const ActorID &actor_id,
const ActorHandleID &actor_handle_id, int64_t actor_counter,
@ -164,7 +164,7 @@ class TaskSpecification {
// TODO(swang): Finalize and document these methods.
TaskID TaskId() const;
UniqueID DriverId() const;
DriverID DriverId() const;
TaskID ParentTaskId() const;
int64_t ParentCounter() const;
std::vector<std::string> FunctionDescriptor() const;

View file

@ -75,7 +75,7 @@ static inline TaskSpecification ExampleTaskSpec(
const ActorID actor_id = ActorID::nil(),
const Language &language = Language::PYTHON) {
std::vector<std::string> function_descriptor(3);
return TaskSpecification(UniqueID::nil(), TaskID::nil(), 0, ActorID::nil(),
return TaskSpecification(DriverID::nil(), TaskID::nil(), 0, ActorID::nil(),
ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, {},
0, {{}}, {{}}, language, function_descriptor);
}