mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Handle exchange of direct call objects between tasks and actors (#6147)
This commit is contained in:
parent
385783fcec
commit
8ff393a7bd
23 changed files with 426 additions and 202 deletions
|
@ -1015,6 +1015,12 @@ cdef class CoreWorker:
|
|||
# Note: faster to not release GIL for short-running op.
|
||||
self.core_worker.get().RemoveObjectIDReference(c_object_id)
|
||||
|
||||
def promote_object_to_plasma(self, ObjectID object_id):
    """Ask the core worker to promote the given object to plasma.

    Args:
        object_id: ID of the object to promote.

    Returns:
        The result of object_id.with_plasma_transport_type().
    """
    cdef CObjectID native_id = object_id.native()
    self.core_worker.get().PromoteObjectToPlasma(native_id)
    return object_id.with_plasma_transport_type()
|
||||
|
||||
# TODO: handle noreturn better
|
||||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
|
|
|
@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
|||
*bytes)
|
||||
void AddObjectIDReference(const CObjectID &object_id)
|
||||
void RemoveObjectIDReference(const CObjectID &object_id)
|
||||
void PromoteObjectToPlasma(const CObjectID &object_id)
|
||||
|
||||
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
|
||||
CRayStatus Put(const CRayObject &object, CObjectID *object_id)
|
||||
|
|
|
@ -95,8 +95,10 @@ cdef class TaskSpec:
|
|||
:self.task_spec.get().ArgMetadataSize(i)]
|
||||
if metadata == RAW_BUFFER_METADATA:
|
||||
obj = data
|
||||
else:
|
||||
elif metadata == PICKLE_BUFFER_METADATA:
|
||||
obj = pickle.loads(data)
|
||||
else:
|
||||
obj = data
|
||||
arg_list.append(obj)
|
||||
elif lang == <int32_t>LANGUAGE_JAVA:
|
||||
arg_list = num_args * ["<java-argument>"]
|
||||
|
|
|
@ -150,6 +150,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
|||
|
||||
c_bool IsDirectCallType()
|
||||
|
||||
CObjectID WithPlasmaTransportType()
|
||||
|
||||
int64_t ObjectIndex() const
|
||||
|
||||
CTaskID TaskId() const
|
||||
|
|
|
@ -179,6 +179,9 @@ cdef class ObjectID(BaseID):
|
|||
def is_direct_call_type(self):
|
||||
return self.data.IsDirectCallType()
|
||||
|
||||
def with_plasma_transport_type(self):
    """Return a new ObjectID built from this ID with the plasma transport type."""
    plasma_binary = self.data.WithPlasmaTransportType().Binary()
    return ObjectID(plasma_binary)
|
||||
|
||||
def is_nil(self):
|
||||
return self.data.IsNil()
|
||||
|
||||
|
|
|
@ -158,9 +158,7 @@ class SerializationContext(object):
|
|||
|
||||
def id_serializer(obj):
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
obj = self.worker.core_worker.promote_object_to_plasma(obj)
|
||||
return pickle.dumps(obj)
|
||||
|
||||
def id_deserializer(serialized_obj):
|
||||
|
@ -191,9 +189,7 @@ class SerializationContext(object):
|
|||
|
||||
def id_serializer(obj):
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
obj = self.worker.core_worker.promote_object_to_plasma(obj)
|
||||
return obj.__reduce__()
|
||||
|
||||
def id_deserializer(serialized_obj):
|
||||
|
|
|
@ -1218,6 +1218,71 @@ def test_direct_call_simple(ray_start_regular):
|
|||
range(1, 101))
|
||||
|
||||
|
||||
def test_direct_call_matrix(shutdown_only):
    """Pass direct call results to direct calls in every supported combination.

    The matrix covers the producer (task or actor), the consumer (task or
    actor), the value size (small or large), and whether the object ID is
    passed as a plain argument ("in_band") or wrapped in a list
    ("out_of_band").
    """
    ray.init(object_store_memory=1000 * 1024 * 1024)

    @ray.remote
    class Actor(object):
        def small_value(self):
            return 0

        def large_value(self):
            return np.zeros(10 * 1024 * 1024)

        def echo(self, x):
            # A list argument carries the object ID in its first element, so
            # fetch the value explicitly before echoing it back.
            if isinstance(x, list):
                x = ray.get(x[0])
            return x

    @ray.remote
    def small_value():
        return 0

    @ray.remote
    def large_value():
        return np.zeros(10 * 1024 * 1024)

    @ray.remote
    def echo(x):
        # Same unwrapping as Actor.echo for list arguments.
        if isinstance(x, list):
            x = ray.get(x[0])
        return x

    def check(source_actor, dest_actor, is_large, out_of_band):
        source_desc = "actor" if source_actor else "task"
        dest_desc = "actor" if dest_actor else "task"
        size_desc = "large_object" if is_large else "small_object"
        band_desc = "out_of_band" if out_of_band else "in_band"
        print("CHECKING", source_desc, "to", dest_desc, size_desc, band_desc)

        # Produce the object with a direct call actor method or task.
        if source_actor:
            a = Actor.options(is_direct_call=True).remote()
            producer = a.large_value if is_large else a.small_value
            x_id = producer.remote()
        else:
            producer = large_value if is_large else small_value
            x_id = producer.options(is_direct_call=True).remote()

        if out_of_band:
            x_id = [x_id]

        # Consume the object with a direct call actor method or task.
        if dest_actor:
            b = Actor.options(is_direct_call=True).remote()
            x = ray.get(b.echo.remote(x_id))
        else:
            x = ray.get(echo.options(is_direct_call=True).remote(x_id))

        expected_type = np.ndarray if is_large else int
        assert isinstance(x, expected_type)

    flags = [False, True]
    for is_large in flags:
        for source_actor in flags:
            for dest_actor in flags:
                for out_of_band in flags:
                    check(source_actor, dest_actor, is_large, out_of_band)
|
||||
|
||||
|
||||
def test_direct_call_chain(ray_start_regular):
|
||||
@ray.remote
|
||||
def g(x):
|
||||
|
@ -1265,26 +1330,6 @@ def test_direct_actor_large_objects(ray_start_regular):
|
|||
assert isinstance(ray.get(obj_id), np.ndarray)
|
||||
|
||||
|
||||
def test_direct_actor_errors(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def f(self, x):
|
||||
return x * 2
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return 1
|
||||
|
||||
a = Actor._remote(is_direct_call=True)
|
||||
|
||||
# cannot pass returns to other methods even in a list
|
||||
with pytest.raises(Exception):
|
||||
ray.get(f.remote([a.f.remote(2)]))
|
||||
|
||||
|
||||
def test_direct_actor_pass_by_ref(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
|
|
|
@ -158,6 +158,14 @@ ObjectID ObjectID::WithTransportType(TaskTransportType transport_type) const {
|
|||
return copy;
|
||||
}
|
||||
|
||||
// The plasma transport type is represented by TaskTransportType::RAYLET.
ObjectID ObjectID::WithPlasmaTransportType() const {
  ObjectID plasma_id = this->WithTransportType(TaskTransportType::RAYLET);
  return plasma_id;
}
|
||||
|
||||
// The direct call transport type is represented by TaskTransportType::DIRECT.
ObjectID ObjectID::WithDirectTransportType() const {
  ObjectID direct_id = this->WithTransportType(TaskTransportType::DIRECT);
  return direct_id;
}
|
||||
|
||||
uint8_t ObjectID::GetTransportType() const {
|
||||
return ::ray::GetTransportType(this->GetFlags());
|
||||
}
|
||||
|
@ -290,10 +298,6 @@ TaskID TaskID::ComputeDriverTaskId(const WorkerID &driver_id) {
|
|||
}
|
||||
|
||||
TaskID ObjectID::TaskId() const {
|
||||
if (!CreatedByTask()) {
|
||||
// TODO(qwang): Should be RAY_CHECK here.
|
||||
RAY_LOG(WARNING) << "Shouldn't call this on a non-task object id: " << this->Hex();
|
||||
}
|
||||
return TaskID::FromBinary(
|
||||
std::string(reinterpret_cast<const char *>(id_), TaskID::Size()));
|
||||
}
|
||||
|
|
|
@ -299,6 +299,16 @@ class ObjectID : public BaseID<ObjectID> {
|
|||
/// \return Copy of this object id with the specified transport type.
|
||||
ObjectID WithTransportType(TaskTransportType transport_type) const;
|
||||
|
||||
/// Return this object id with the plasma transport type.
|
||||
///
|
||||
/// \return Copy of this object id with the plasma transport type.
|
||||
ObjectID WithPlasmaTransportType() const;
|
||||
|
||||
/// Return this object id with the direct call transport type.
|
||||
///
|
||||
/// \return Copy of this object id with the direct call transport type.
|
||||
ObjectID WithDirectTransportType() const;
|
||||
|
||||
/// Get the transport type of this object.
|
||||
///
|
||||
/// \return The type of the transport which is used to transfer this object.
|
||||
|
|
|
@ -95,7 +95,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
|
|||
SetCurrentJobId(task_spec.JobId());
|
||||
RAY_CHECK(current_actor_id_.IsNil());
|
||||
current_actor_id_ = task_spec.ActorCreationId();
|
||||
current_task_is_direct_call_ = task_spec.IsDirectCall();
|
||||
current_actor_is_direct_call_ = task_spec.IsDirectCall();
|
||||
current_actor_max_concurrency_ = task_spec.MaxActorConcurrency();
|
||||
} else if (task_spec.IsActorTask()) {
|
||||
RAY_CHECK(current_job_id_ == task_spec.JobId());
|
||||
|
@ -118,8 +118,12 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
|
|||
|
||||
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
|
||||
|
||||
bool WorkerContext::CurrentActorIsDirectCall() const {
|
||||
return current_actor_is_direct_call_;
|
||||
}
|
||||
|
||||
bool WorkerContext::CurrentTaskIsDirectCall() const {
|
||||
return current_task_is_direct_call_;
|
||||
return current_task_is_direct_call_ || current_actor_is_direct_call_;
|
||||
}
|
||||
|
||||
int WorkerContext::CurrentActorMaxConcurrency() const {
|
||||
|
|
|
@ -34,6 +34,11 @@ class WorkerContext {
|
|||
|
||||
const ActorID &GetCurrentActorID() const;
|
||||
|
||||
/// Returns whether we are in a direct call actor.
|
||||
bool CurrentActorIsDirectCall() const;
|
||||
|
||||
/// Returns whether we are in a direct call task. This encompasses both direct
|
||||
/// actor and normal tasks.
|
||||
bool CurrentTaskIsDirectCall() const;
|
||||
|
||||
int CurrentActorMaxConcurrency() const;
|
||||
|
@ -47,6 +52,7 @@ class WorkerContext {
|
|||
const WorkerID worker_id_;
|
||||
JobID current_job_id_;
|
||||
ActorID current_actor_id_;
|
||||
bool current_actor_is_direct_call_ = false;
|
||||
bool current_task_is_direct_call_ = false;
|
||||
int current_actor_max_concurrency_ = 1;
|
||||
|
||||
|
|
|
@ -25,9 +25,10 @@ void BuildCommonTaskSpec(
|
|||
// Set task arguments.
|
||||
for (const auto &arg : args) {
|
||||
if (arg.IsPassedByReference()) {
|
||||
// TODO(ekl) remove this check once we deprecate TaskTransportType::RAYLET
|
||||
if (transport_type == ray::TaskTransportType::RAYLET) {
|
||||
RAY_CHECK(!arg.GetReference().IsDirectCallType())
|
||||
<< "NotImplemented: passing direct call objects to other tasks";
|
||||
<< "Passing direct call objects to non-direct tasks is not allowed.";
|
||||
}
|
||||
builder.AddByRefArg(arg.GetReference());
|
||||
} else {
|
||||
|
@ -61,6 +62,37 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
|
|||
|
||||
namespace ray {
|
||||
|
||||
// Prepare direct call args for sending to a direct call *actor*. Direct call actors
|
||||
// always resolve their dependencies remotely, so we need some client-side preprocessing
|
||||
// to ensure they don't try to resolve a direct call object ID remotely (which is
|
||||
// impossible).
|
||||
// - Direct call args that are local and small will be inlined.
|
||||
// - Direct call args that are non-local or large will be promoted to plasma.
|
||||
// Note that args for direct call *tasks* are handled by LocalDependencyResolver.
|
||||
std::vector<TaskArg> PrepareDirectActorCallArgs(
|
||||
const std::vector<TaskArg> &args,
|
||||
std::shared_ptr<CoreWorkerMemoryStore> memory_store) {
|
||||
std::vector<TaskArg> out;
|
||||
for (const auto &arg : args) {
|
||||
if (arg.IsPassedByReference() && arg.GetReference().IsDirectCallType()) {
|
||||
const ObjectID &obj_id = arg.GetReference();
|
||||
// TODO(ekl) we should consider resolving these dependencies on the client side
|
||||
// for actor calls. It is a little tricky since we have to also preserve the
|
||||
// task ordering so we can't simply use LocalDependencyResolver.
|
||||
std::shared_ptr<RayObject> obj = memory_store->GetOrPromoteToPlasma(obj_id);
|
||||
if (obj != nullptr) {
|
||||
out.push_back(TaskArg::PassByValue(obj));
|
||||
} else {
|
||||
out.push_back(TaskArg::PassByReference(
|
||||
obj_id.WithTransportType(TaskTransportType::RAYLET)));
|
||||
}
|
||||
} else {
|
||||
out.push_back(arg);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
|
||||
|
@ -78,8 +110,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
client_call_manager_(new rpc::ClientCallManager(io_service_)),
|
||||
heartbeat_timer_(io_service_),
|
||||
core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
|
||||
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
|
||||
memory_store_provider_(memory_store_),
|
||||
task_execution_service_work_(task_execution_service_),
|
||||
task_execution_callback_(task_execution_callback),
|
||||
grpc_service_(io_service_, *this) {
|
||||
|
@ -159,6 +189,11 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
|||
|
||||
plasma_store_provider_.reset(
|
||||
new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_));
|
||||
memory_store_.reset(
|
||||
new CoreWorkerMemoryStore([this](const RayObject &obj, const ObjectID &obj_id) {
|
||||
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
|
||||
}));
|
||||
memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_));
|
||||
|
||||
// Create an entry for the driver task in the task table. This task is
|
||||
// added immediately with status RUNNING. This allows us to push errors
|
||||
|
@ -267,6 +302,15 @@ void CoreWorker::ReportActiveObjectIDs() {
|
|||
heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this));
|
||||
}
|
||||
|
||||
void CoreWorker::PromoteObjectToPlasma(const ObjectID &object_id) {
|
||||
RAY_CHECK(object_id.IsDirectCallType());
|
||||
auto value = memory_store_->GetOrPromoteToPlasma(object_id);
|
||||
if (value != nullptr) {
|
||||
RAY_CHECK_OK(
|
||||
plasma_store_provider_->Put(*value, object_id.WithPlasmaTransportType()));
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorker::SetClientOptions(std::string name, int64_t limit_bytes) {
|
||||
// Currently only the Plasma store supports client options.
|
||||
return plasma_store_provider_->SetClientOptions(name, limit_bytes);
|
||||
|
@ -324,7 +368,7 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
|
|||
local_timeout_ms = std::max(static_cast<int64_t>(0),
|
||||
timeout_ms - (current_time_ms() - start_time));
|
||||
}
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Get(memory_object_ids, local_timeout_ms,
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(),
|
||||
&result_map, &got_exception));
|
||||
}
|
||||
|
@ -379,7 +423,7 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) {
|
|||
if (object_id.IsDirectCallType()) {
|
||||
// Note that the memory store returns false if the object value is
|
||||
// ErrorType::OBJECT_IN_PLASMA.
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Contains(object_id, &found));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found));
|
||||
}
|
||||
if (!found) {
|
||||
// We check plasma as a fallback in all cases, since a direct call object
|
||||
|
@ -430,7 +474,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
|||
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
|
||||
// TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should
|
||||
// consider waiting on them in plasma as well to ensure they are local.
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Wait(
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
|
||||
memory_object_ids, num_objects - static_cast<int>(ready.size()),
|
||||
/*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready));
|
||||
}
|
||||
|
@ -453,7 +497,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
|||
std::max(0, static_cast<int>(timeout_ms - (current_time_ms() - start_time)));
|
||||
}
|
||||
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Wait(
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
|
||||
memory_object_ids, num_objects - static_cast<int>(ready.size()), timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(), &ready));
|
||||
}
|
||||
|
@ -477,7 +521,7 @@ Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_on
|
|||
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only,
|
||||
delete_creating_tasks));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Delete(memory_object_ids));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -602,10 +646,11 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
|||
const TaskID actor_task_id = TaskID::ForActorTask(
|
||||
worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
|
||||
next_task_index, actor_handle->GetActorID());
|
||||
BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id,
|
||||
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
|
||||
rpc_address_, function, args, num_returns, task_options.resources,
|
||||
{}, transport_type, return_ids);
|
||||
BuildCommonTaskSpec(
|
||||
builder, actor_handle->CreationJobID(), actor_task_id,
|
||||
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_,
|
||||
function, is_direct_call ? PrepareDirectActorCallArgs(args, memory_store_) : args,
|
||||
num_returns, task_options.resources, {}, transport_type, return_ids);
|
||||
|
||||
const ObjectID new_cursor = return_ids->back();
|
||||
actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor);
|
||||
|
@ -848,7 +893,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
|||
void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
rpc::AssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (worker_context_.CurrentTaskIsDirectCall()) {
|
||||
if (worker_context_.CurrentActorIsDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
|
||||
nullptr);
|
||||
return;
|
||||
|
|
|
@ -113,6 +113,15 @@ class CoreWorker {
|
|||
reference_counter_.RemoveReference(object_id);
|
||||
}
|
||||
|
||||
/// Promote an object to plasma. If it already exists locally, it will be
|
||||
/// put into the plasma store. If it doesn't yet exist, it will be spilled to
|
||||
/// plasma once available.
|
||||
///
|
||||
/// Postcondition: Get(object_id.WithPlasmaTransportType()) is valid.
|
||||
///
|
||||
/// \param[in] object_id The object ID to promote to plasma.
|
||||
void PromoteObjectToPlasma(const ObjectID &object_id);
|
||||
|
||||
///
|
||||
/// Public methods related to storing and retrieving objects.
|
||||
///
|
||||
|
@ -482,10 +491,10 @@ class CoreWorker {
|
|||
std::shared_ptr<CoreWorkerMemoryStore> memory_store_;
|
||||
|
||||
/// Plasma store interface.
|
||||
std::unique_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider_;
|
||||
std::shared_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider_;
|
||||
|
||||
/// In-memory store interface.
|
||||
CoreWorkerMemoryStoreProvider memory_store_provider_;
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> memory_store_provider_;
|
||||
|
||||
///
|
||||
/// Fields related to task submission.
|
||||
|
|
|
@ -107,7 +107,9 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
CoreWorkerMemoryStore::CoreWorkerMemoryStore() {}
|
||||
CoreWorkerMemoryStore::CoreWorkerMemoryStore(
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma)
|
||||
: store_in_plasma_(store_in_plasma) {}
|
||||
|
||||
void CoreWorkerMemoryStore::GetAsync(
|
||||
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
|
||||
|
@ -127,7 +129,25 @@ void CoreWorkerMemoryStore::GetAsync(
|
|||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetOrPromoteToPlasma(
|
||||
const ObjectID &object_id) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
auto obj = iter->second;
|
||||
if (obj->IsInPlasmaError()) {
|
||||
return nullptr;
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
RAY_CHECK(store_in_plasma_ != nullptr)
|
||||
<< "Cannot promote object without plasma provider callback.";
|
||||
promoted_to_plasma_.insert(object_id);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
|
||||
RAY_CHECK(object_id.IsDirectCallType());
|
||||
std::vector<std::function<void(std::shared_ptr<RayObject>)>> async_callbacks;
|
||||
auto object_entry =
|
||||
std::make_shared<RayObject>(object.GetData(), object.GetMetadata(), true);
|
||||
|
@ -146,6 +166,13 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob
|
|||
object_async_get_requests_.erase(async_callback_it);
|
||||
}
|
||||
|
||||
auto promoted_it = promoted_to_plasma_.find(object_id);
|
||||
if (promoted_it != promoted_to_plasma_.end()) {
|
||||
RAY_CHECK(store_in_plasma_ != nullptr);
|
||||
store_in_plasma_(object, object_id.WithTransportType(TaskTransportType::RAYLET));
|
||||
promoted_to_plasma_.erase(promoted_it);
|
||||
}
|
||||
|
||||
bool should_add_entry = true;
|
||||
auto object_request_iter = object_get_requests_.find(object_id);
|
||||
if (object_request_iter != object_get_requests_.end()) {
|
||||
|
|
|
@ -18,7 +18,8 @@ class CoreWorkerMemoryStore;
|
|||
/// actor call (see direct_actor_transport.cc).
|
||||
class CoreWorkerMemoryStore {
|
||||
public:
|
||||
CoreWorkerMemoryStore();
|
||||
CoreWorkerMemoryStore(
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr);
|
||||
~CoreWorkerMemoryStore(){};
|
||||
|
||||
/// Put an object with specified ID into object store.
|
||||
|
@ -49,6 +50,14 @@ class CoreWorkerMemoryStore {
|
|||
void GetAsync(const ObjectID &object_id,
|
||||
std::function<void(std::shared_ptr<RayObject>)> callback);
|
||||
|
||||
/// Get a single object if available. If the object is not local yet, or if the object
|
||||
/// is local but is ErrorType::OBJECT_IN_PLASMA, then nullptr will be returned, and
|
||||
/// the store will ensure the object is promoted to plasma once available.
|
||||
///
|
||||
/// \param[in] object_id The object id to get.
|
||||
/// \return pointer to the local object, or nullptr if promoted to plasma.
|
||||
std::shared_ptr<RayObject> GetOrPromoteToPlasma(const ObjectID &object_id);
|
||||
|
||||
/// Delete a list of objects from the object store.
|
||||
///
|
||||
/// \param[in] object_ids IDs of the objects to delete.
|
||||
|
@ -62,6 +71,15 @@ class CoreWorkerMemoryStore {
|
|||
bool Contains(const ObjectID &object_id);
|
||||
|
||||
private:
|
||||
/// Optional callback for putting objects into the plasma store.
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
|
||||
|
||||
/// Protects the data structures below.
|
||||
absl::Mutex mu_;
|
||||
|
||||
/// Set of objects that should be promoted to plasma once available.
|
||||
absl::flat_hash_set<ObjectID> promoted_to_plasma_ GUARDED_BY(mu_);
|
||||
|
||||
/// Map from object ID to `RayObject`.
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_ GUARDED_BY(mu_);
|
||||
|
||||
|
@ -73,9 +91,6 @@ class CoreWorkerMemoryStore {
|
|||
absl::flat_hash_map<ObjectID,
|
||||
std::vector<std::function<void(std::shared_ptr<RayObject>)>>>
|
||||
object_async_get_requests_ GUARDED_BY(mu_);
|
||||
|
||||
/// Protect the two maps above.
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -14,6 +14,7 @@ CoreWorkerMemoryStoreProvider::CoreWorkerMemoryStoreProvider(
|
|||
|
||||
Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object,
|
||||
const ObjectID &object_id) {
|
||||
RAY_CHECK(object_id.IsDirectCallType());
|
||||
Status status = store_->Put(object_id, object);
|
||||
if (status.IsObjectExists()) {
|
||||
// Object already exists in store, treat it as ok.
|
||||
|
|
|
@ -27,6 +27,7 @@ Status CoreWorkerPlasmaStoreProvider::SetClientOptions(std::string name,
|
|||
|
||||
Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object,
|
||||
const ObjectID &object_id) {
|
||||
RAY_CHECK(!object_id.IsDirectCallType());
|
||||
std::shared_ptr<Buffer> data;
|
||||
RAY_RETURN_NOT_OK(Create(object.GetMetadata(),
|
||||
object.HasData() ? object.GetData()->Size() : 0, object_id,
|
||||
|
|
|
@ -637,14 +637,14 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
|||
|
||||
std::vector<ObjectID> ids(buffers.size());
|
||||
for (size_t i = 0; i < ids.size(); i++) {
|
||||
ids[i] = ObjectID::FromRandom();
|
||||
ids[i] = ObjectID::FromRandom().WithDirectTransportType();
|
||||
RAY_CHECK_OK(provider.Put(buffers[i], ids[i]));
|
||||
}
|
||||
|
||||
absl::flat_hash_set<ObjectID> wait_ids(ids.begin(), ids.end());
|
||||
absl::flat_hash_set<ObjectID> wait_results;
|
||||
|
||||
ObjectID nonexistent_id = ObjectID::FromRandom();
|
||||
ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType();
|
||||
wait_ids.insert(nonexistent_id);
|
||||
RAY_CHECK_OK(
|
||||
provider.Wait(wait_ids, ids.size() + 1, 100, RandomTaskId(), &wait_results));
|
||||
|
@ -693,9 +693,9 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
|||
std::vector<ObjectID> ready_ids(buffers.size());
|
||||
std::vector<ObjectID> unready_ids(buffers.size());
|
||||
for (size_t i = 0; i < unready_ids.size(); i++) {
|
||||
ready_ids[i] = ObjectID::FromRandom();
|
||||
ready_ids[i] = ObjectID::FromRandom().WithDirectTransportType();
|
||||
RAY_CHECK_OK(provider.Put(buffers[i], ready_ids[i]));
|
||||
unready_ids[i] = ObjectID::FromRandom();
|
||||
unready_ids[i] = ObjectID::FromRandom().WithDirectTransportType();
|
||||
}
|
||||
|
||||
auto thread_func = [&unready_ids, &provider, &buffers]() {
|
||||
|
|
|
@ -38,9 +38,33 @@ class MockRayletClient : public WorkerLeaseInterface {
|
|||
int num_workers_returned = 0;
|
||||
};
|
||||
|
||||
// Checks that GetOrPromoteToPlasma returns objects that are already stored
// without invoking the plasma callback, and that an object requested before it
// exists triggers exactly one plasma put once it is stored.
TEST(TestMemoryStore, TestPromoteToPlasma) {
  // Must be an integer: a bool counter saturates at `true`, which would hide
  // duplicate invocations of the plasma callback.
  int num_plasma_puts = 0;
  auto mem = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore(
      [&](const RayObject &obj, const ObjectID &obj_id) { num_plasma_puts += 1; }));
  ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
  auto data = GenerateRandomObject();
  ASSERT_TRUE(mem->Put(obj1, *data).ok());

  // Test getting an already existing object.
  ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr);
  ASSERT_EQ(num_plasma_puts, 0);

  // Test that getting an object that doesn't exist yet causes promotion.
  ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr);
  ASSERT_EQ(num_plasma_puts, 0);
  ASSERT_TRUE(mem->Put(obj2, *data).ok());
  ASSERT_EQ(num_plasma_puts, 1);

  // The next time you get it, it's already there so no need to promote.
  ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) != nullptr);
  ASSERT_EQ(num_plasma_puts, 1);
}
|
||||
|
||||
TEST(LocalDependencyResolverTest, TestNoDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
TaskSpecification task;
|
||||
bool ok = false;
|
||||
|
@ -50,7 +74,7 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) {
|
|||
|
||||
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
|
||||
TaskSpecification task;
|
||||
|
@ -62,16 +86,38 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
|
|||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||
}
|
||||
|
||||
// Checks that a direct call dependency whose stored value is the
// OBJECT_IN_PLASMA marker resolves immediately, stays a by-reference argument,
// and has its ID rewritten to a non-direct-call ID.
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
  auto memory_store = std::make_shared<CoreWorkerMemoryStore>();
  auto provider = std::make_shared<CoreWorkerMemoryStoreProvider>(memory_store);
  LocalDependencyResolver resolver(provider);

  // Store an object whose metadata is the OBJECT_IN_PLASMA error code.
  ObjectID direct_id = ObjectID::FromRandom().WithDirectTransportType();
  std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
  auto meta_bytes = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
  auto meta_buffer = std::make_shared<LocalMemoryBuffer>(meta_bytes, meta.size());
  auto in_plasma_marker = RayObject(nullptr, meta_buffer);
  ASSERT_TRUE(provider->Put(in_plasma_marker, direct_id).ok());

  // Build a task that depends on the direct call ID.
  TaskSpecification task;
  task.GetMutableMessage().add_args()->add_object_ids(direct_id.Binary());
  ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());

  bool resolved = false;
  resolver.ResolveDependencies(task, [&resolved]() { resolved = true; });
  ASSERT_TRUE(resolved);
  ASSERT_TRUE(task.ArgByRef(0));
  // Checks that the object id was promoted to a plasma type id.
  ASSERT_FALSE(task.ArgId(0, 0).IsDirectCallType());
  ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
|
||||
|
||||
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
auto data = GenerateRandomObject();
|
||||
// Ensure the data is already present in the local store.
|
||||
ASSERT_TRUE(store.Put(*data, obj1).ok());
|
||||
ASSERT_TRUE(store.Put(*data, obj2).ok());
|
||||
ASSERT_TRUE(store->Put(*data, obj1).ok());
|
||||
ASSERT_TRUE(store->Put(*data, obj2).ok());
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());
|
||||
|
@ -88,7 +134,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
|
|||
|
||||
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
|
@ -100,8 +146,8 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
|||
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
||||
ASSERT_EQ(resolver.NumPendingTasks(), 1);
|
||||
ASSERT_TRUE(!ok);
|
||||
ASSERT_TRUE(store.Put(*data, obj1).ok());
|
||||
ASSERT_TRUE(store.Put(*data, obj2).ok());
|
||||
ASSERT_TRUE(store->Put(*data, obj1).ok());
|
||||
ASSERT_TRUE(store->Put(*data, obj2).ok());
|
||||
// Tests that the task proto was rewritten to have inline argument values after
|
||||
// resolution completes.
|
||||
ASSERT_TRUE(ok);
|
||||
|
@ -112,10 +158,11 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
|||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestSubmitOneTask) {
|
||||
TEST(DirectTaskTransportTest, TestSubmitOneTask) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task;
|
||||
|
@ -133,10 +180,11 @@ TEST(DirectTaskTranportTest, TestSubmitOneTask) {
|
|||
ASSERT_EQ(raylet_client.num_workers_returned, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestHandleTaskFailure) {
|
||||
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task;
|
||||
|
@ -150,10 +198,11 @@ TEST(DirectTaskTranportTest, TestHandleTaskFailure) {
|
|||
ASSERT_EQ(raylet_client.num_workers_returned, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) {
|
||||
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task1;
|
||||
|
@ -190,10 +239,11 @@ TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) {
|
|||
ASSERT_EQ(raylet_client.num_workers_returned, 3);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestReuseWorkerLease) {
|
||||
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task1;
|
||||
|
@ -232,10 +282,11 @@ TEST(DirectTaskTranportTest, TestReuseWorkerLease) {
|
|||
ASSERT_EQ(raylet_client.num_workers_returned, 2);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestWorkerNotReusedOnError) {
|
||||
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task1;
|
||||
|
|
|
@ -5,9 +5,61 @@ using ray::rpc::ActorTableData;
|
|||
|
||||
namespace ray {
|
||||
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store) {
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
}
|
||||
}
|
||||
|
||||
void WriteObjectsToMemoryStore(
|
||||
const rpc::PushTaskReply &reply,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store) {
|
||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||
const auto &return_object = reply.return_objects(i);
|
||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||
|
||||
if (return_object.in_plasma()) {
|
||||
// Mark it as in plasma with a dummy object.
|
||||
std::string meta =
|
||||
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
||||
auto metadata =
|
||||
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
} else {
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||
return_object.data().size());
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||
if (return_object.metadata().size() > 0) {
|
||||
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
|
||||
rpc::ClientCallManager &client_call_manager,
|
||||
CoreWorkerMemoryStoreProvider store_provider)
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: client_call_manager_(client_call_manager), in_memory_store_(store_provider) {}
|
||||
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
|
@ -49,7 +101,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
|
|||
} else {
|
||||
// Actor is dead, treat the task as failure.
|
||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
|
||||
}
|
||||
|
||||
// If the task submission subsequently fails, then the client will receive
|
||||
|
@ -84,7 +136,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
|||
for (const auto &entry : iter->second) {
|
||||
const auto &task_id = entry.first;
|
||||
const auto num_returns = entry.second;
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
}
|
||||
waiting_reply_tasks_.erase(actor_id);
|
||||
}
|
||||
|
@ -94,7 +147,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
|||
if (pending_it != pending_requests_.end()) {
|
||||
for (const auto &request : pending_it->second) {
|
||||
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
|
||||
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED);
|
||||
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
}
|
||||
pending_requests_.erase(pending_it);
|
||||
}
|
||||
|
@ -136,59 +190,14 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
|||
// Note that this might be the __ray_terminate__ task, so we don't log
|
||||
// loudly with ERROR here.
|
||||
RAY_LOG(INFO) << "Task failed with error: " << status;
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
|
||||
in_memory_store_);
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||
const auto &return_object = reply.return_objects(i);
|
||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||
|
||||
if (return_object.in_plasma()) {
|
||||
// Mark it as in plasma with a dummy object.
|
||||
std::string meta =
|
||||
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
|
||||
auto metadata =
|
||||
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
} else {
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||
return_object.data().size());
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||
if (return_object.metadata().size() > 0) {
|
||||
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
}
|
||||
WriteObjectsToMemoryStore(reply, in_memory_store_);
|
||||
});
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed(
|
||||
const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) {
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -301,8 +310,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
|
|||
send_reply_callback(status, nullptr, nullptr);
|
||||
},
|
||||
[send_reply_callback]() {
|
||||
send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr,
|
||||
nullptr);
|
||||
send_reply_callback(Status::Invalid("client cancelled stale rpc"),
|
||||
nullptr, nullptr);
|
||||
},
|
||||
dependencies);
|
||||
}
|
||||
|
|
|
@ -22,6 +22,26 @@ namespace ray {
|
|||
/// The max time to wait for out-of-order tasks.
|
||||
const int kMaxReorderWaitSeconds = 30;
|
||||
|
||||
/// Treat a task as failed by storing an error marker for each of its return objects.
|
||||
///
|
||||
/// \param[in] task_id The ID of the failed task.
|
||||
/// \param[in] num_returns Number of return objects of the task.
|
||||
/// \param[in] error_type The type of the specific error to record.
|
||||
/// \param[in] in_memory_store The memory store the error markers are written to.
|
||||
/// \return Void.
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store);
|
||||
|
||||
/// Write the return objects carried by a task reply to the memory store.
|
||||
///
|
||||
/// \param[in] reply Proto response to a direct actor or task call.
|
||||
/// \param[in] in_memory_store The memory store the return objects are written to.
|
||||
/// \return Void.
|
||||
void WriteObjectsToMemoryStore(
|
||||
const rpc::PushTaskReply &reply,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store);
|
||||
|
||||
/// In direct actor call task submitter and receiver, a task is directly submitted
|
||||
/// to the actor that will execute it.
|
||||
|
||||
|
@ -40,8 +60,9 @@ struct ActorStateData {
|
|||
// This class is thread-safe.
|
||||
class CoreWorkerDirectActorTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectActorTaskSubmitter(rpc::ClientCallManager &client_call_manager,
|
||||
CoreWorkerMemoryStoreProvider store_provider);
|
||||
CoreWorkerDirectActorTaskSubmitter(
|
||||
rpc::ClientCallManager &client_call_manager,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider);
|
||||
|
||||
/// Submit a task to an actor for execution.
|
||||
///
|
||||
|
@ -70,15 +91,6 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns);
|
||||
|
||||
/// Treat a task as failed.
|
||||
///
|
||||
/// \param[in] task_id The ID of a task.
|
||||
/// \param[in] num_returns Number of return objects.
|
||||
/// \param[in] error_type The type of the specific error.
|
||||
/// \return Void.
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type);
|
||||
|
||||
/// Create connection to actor and send all pending tasks.
|
||||
/// Note that this function doesn't take lock, the caller is expected to hold
|
||||
/// `mutex_` before calling this function.
|
||||
|
@ -120,7 +132,7 @@ class CoreWorkerDirectActorTaskSubmitter {
|
|||
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
|
||||
|
||||
/// The store provider.
|
||||
CoreWorkerMemoryStoreProvider in_memory_store_;
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
|
||||
friend class CoreWorkerTest;
|
||||
};
|
||||
|
@ -230,7 +242,8 @@ class SchedulingQueue {
|
|||
std::function<void()> accept_request, std::function<void()> reject_request,
|
||||
const std::vector<ObjectID> &dependencies = {}) {
|
||||
if (seq_no == -1) {
|
||||
seq_no = next_seq_no_; // A value of -1 means no ordering constraint.
|
||||
accept_request(); // A seq_no of -1 means no ordering constraint.
|
||||
return;
|
||||
}
|
||||
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
|
||||
if (client_processed_up_to >= next_seq_no_) {
|
||||
|
@ -259,6 +272,8 @@ class SchedulingQueue {
|
|||
// Cancel any stale requests that the client doesn't need any longer.
|
||||
while (!pending_tasks_.empty() && pending_tasks_.begin()->first < next_seq_no_) {
|
||||
auto head = pending_tasks_.begin();
|
||||
RAY_LOG(ERROR) << "Cancelling stale RPC with seqno "
|
||||
<< pending_tasks_.begin()->first << " < " << next_seq_no_;
|
||||
head->second.Cancel();
|
||||
pending_tasks_.erase(head);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
|
@ -13,6 +14,12 @@ void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> valu
|
|||
if (id == obj_id) {
|
||||
auto *mutable_arg = msg.mutable_args(i);
|
||||
mutable_arg->clear_object_ids();
|
||||
if (value->IsInPlasmaError()) {
|
||||
// Promote the object id to plasma.
|
||||
mutable_arg->add_object_ids(
|
||||
obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
|
||||
} else {
|
||||
// Inline the object value.
|
||||
if (value->HasData()) {
|
||||
const auto &data = value->GetData();
|
||||
mutable_arg->set_data(data->Data(), data->Size());
|
||||
|
@ -21,6 +28,7 @@ void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> valu
|
|||
const auto &metadata = value->GetMetadata();
|
||||
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
||||
}
|
||||
}
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +60,7 @@ void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
|
|||
num_pending_ += 1;
|
||||
|
||||
for (const auto &obj_id : state->local_dependencies) {
|
||||
in_memory_store_.GetAsync(
|
||||
in_memory_store_->GetAsync(
|
||||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||
RAY_CHECK(obj != nullptr);
|
||||
bool complete = false;
|
||||
|
@ -128,23 +136,6 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
|
|||
worker_request_pending_ = true;
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::TreatTaskAsFailed(const TaskID &task_id,
|
||||
int num_returns,
|
||||
const rpc::ErrorType &error_type) {
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(ekl) consider reconsolidating with DirectActorTransport.
|
||||
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
|
||||
rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec) {
|
||||
|
@ -157,30 +148,15 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
|
|||
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
|
||||
OnWorkerIdle(addr, /*error=*/!status.ok());
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED);
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
|
||||
in_memory_store_);
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||
const auto &return_object = reply.return_objects(i);
|
||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||
return_object.data().size());
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||
if (return_object.metadata().size() > 0) {
|
||||
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
WriteObjectsToMemoryStore(reply, in_memory_store_);
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
|
||||
in_memory_store_);
|
||||
}
|
||||
}
|
||||
}; // namespace ray
|
||||
|
|
|
@ -23,7 +23,7 @@ struct TaskState {
|
|||
// This class is thread-safe.
|
||||
class LocalDependencyResolver {
|
||||
public:
|
||||
LocalDependencyResolver(CoreWorkerMemoryStoreProvider &store_provider)
|
||||
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: in_memory_store_(store_provider), num_pending_(0) {}
|
||||
|
||||
/// Resolve all local and remote dependencies for the task, calling the specified
|
||||
|
@ -42,7 +42,7 @@ class LocalDependencyResolver {
|
|||
|
||||
private:
|
||||
/// The store provider.
|
||||
CoreWorkerMemoryStoreProvider in_memory_store_;
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
|
||||
/// Number of tasks pending dependency resolution.
|
||||
std::atomic<int> num_pending_;
|
||||
|
@ -58,9 +58,9 @@ typedef std::function<std::shared_ptr<rpc::CoreWorkerClientInterface>(WorkerAddr
|
|||
// This class is thread-safe.
|
||||
class CoreWorkerDirectTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectTaskSubmitter(WorkerLeaseInterface &lease_client,
|
||||
ClientFactoryFn client_factory,
|
||||
CoreWorkerMemoryStoreProvider store_provider)
|
||||
CoreWorkerDirectTaskSubmitter(
|
||||
WorkerLeaseInterface &lease_client, ClientFactoryFn client_factory,
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: lease_client_(lease_client),
|
||||
client_factory_(client_factory),
|
||||
in_memory_store_(store_provider),
|
||||
|
@ -93,10 +93,6 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec);
|
||||
|
||||
/// Mark a direct call as failed by storing errors for its return objects.
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type);
|
||||
|
||||
// Client that can be used to lease and return workers.
|
||||
WorkerLeaseInterface &lease_client_;
|
||||
|
||||
|
@ -104,7 +100,7 @@ class CoreWorkerDirectTaskSubmitter {
|
|||
ClientFactoryFn client_factory_;
|
||||
|
||||
/// The store provider.
|
||||
CoreWorkerMemoryStoreProvider in_memory_store_;
|
||||
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
|
||||
/// Resolves local and remote dependencies.
|
||||
LocalDependencyResolver resolver_;
|
||||
|
|
Loading…
Add table
Reference in a new issue