Handle exchange of direct call objects between tasks and actors (#6147)

Eric Liang 2019-11-14 17:32:04 -08:00 committed by GitHub
parent 385783fcec
commit 8ff393a7bd
23 changed files with 426 additions and 202 deletions
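
In user-facing terms, this change lets ObjectIDs returned by direct calls be passed to other tasks and actors, either directly as arguments or nested inside other objects, instead of raising NotImplementedError during serialization. A minimal end-to-end sketch of the new behavior, mirroring the test_direct_call_matrix test added below (assumes a Ray build at this revision, where tasks and actors accept the is_direct_call option):

import numpy as np
import ray

ray.init(object_store_memory=1000 * 1024 * 1024)

@ray.remote
def produce():
    # Large enough that the direct-call return gets promoted to plasma
    # when another task or actor needs it.
    return np.zeros(10 * 1024 * 1024)

@ray.remote
class Consumer(object):
    def echo(self, x):
        if isinstance(x, list):
            # Out-of-band case: the ObjectID was forwarded inside a list.
            x = ray.get(x[0])
        return x

x_id = produce.options(is_direct_call=True).remote()
consumer = Consumer.options(is_direct_call=True).remote()

in_band = ray.get(consumer.echo.remote(x_id))        # passed as a plain argument
out_of_band = ray.get(consumer.echo.remote([x_id]))  # passed inside another object
assert isinstance(in_band, np.ndarray) and isinstance(out_of_band, np.ndarray)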


@ -1015,6 +1015,12 @@ cdef class CoreWorker:
# Note: faster to not release GIL for short-running op.
self.core_worker.get().RemoveObjectIDReference(c_object_id)
def promote_object_to_plasma(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
self.core_worker.get().PromoteObjectToPlasma(c_object_id)
return object_id.with_plasma_transport_type()
# TODO: handle noreturn better
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
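
promote_object_to_plasma is the Python-side entry point for the new promotion path: it calls the core worker's PromoteObjectToPlasma and returns the plasma-typed copy of the ID, which per the postcondition documented later in core_worker.h becomes gettable through the regular object store. A worker-level sketch of that contract; the global_worker and core_worker handles are internal and used here only for illustration:

import ray

ray.init()

@ray.remote
def f():
    return "hello"

direct_id = f.options(is_direct_call=True).remote()

worker = ray.worker.global_worker              # internal handle, not a public API
plasma_id = worker.core_worker.promote_object_to_plasma(direct_id)

assert direct_id.is_direct_call_type()
assert not plasma_id.is_direct_call_type()     # transport type is now RAYLET (plasma)
print(ray.get(plasma_id))                      # valid once the task result is available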


@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
*bytes)
void AddObjectIDReference(const CObjectID &object_id)
void RemoveObjectIDReference(const CObjectID &object_id)
void PromoteObjectToPlasma(const CObjectID &object_id)
CRayStatus SetClientOptions(c_string client_name, int64_t limit)
CRayStatus Put(const CRayObject &object, CObjectID *object_id)


@ -95,8 +95,10 @@ cdef class TaskSpec:
:self.task_spec.get().ArgMetadataSize(i)]
if metadata == RAW_BUFFER_METADATA:
obj = data
else:
elif metadata == PICKLE_BUFFER_METADATA:
obj = pickle.loads(data)
else:
obj = data
arg_list.append(obj)
elif lang == <int32_t>LANGUAGE_JAVA:
arg_list = num_args * ["<java-argument>"]


@ -150,6 +150,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
c_bool IsDirectCallType()
CObjectID WithPlasmaTransportType()
int64_t ObjectIndex() const
CTaskID TaskId() const


@ -179,6 +179,9 @@ cdef class ObjectID(BaseID):
def is_direct_call_type(self):
return self.data.IsDirectCallType()
def with_plasma_transport_type(self):
return ObjectID(self.data.WithPlasmaTransportType().Binary())
def is_nil(self):
return self.data.IsNil()


@ -158,9 +158,7 @@ class SerializationContext(object):
def id_serializer(obj):
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
obj = self.worker.core_worker.promote_object_to_plasma(obj)
return pickle.dumps(obj)
def id_deserializer(serialized_obj):
@ -191,9 +189,7 @@ class SerializationContext(object):
def id_serializer(obj):
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
obj = self.worker.core_worker.promote_object_to_plasma(obj)
return obj.__reduce__()
def id_deserializer(serialized_obj):
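
With this change the custom ObjectID serializers promote direct-call IDs to plasma instead of rejecting them, so an ID can cross task boundaries inside pickled data. A sketch of the resulting serializer, written as a free function that takes the worker's core_worker handle explicitly (in SerializationContext it is self.worker.core_worker):

import pickle

import ray

def id_serializer(obj, core_worker):
    if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
        # Previously this raised NotImplementedError; now the ID is promoted
        # so the receiving task can fetch the value from plasma.
        obj = core_worker.promote_object_to_plasma(obj)
    return pickle.dumps(obj)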


@ -1218,6 +1218,71 @@ def test_direct_call_simple(ray_start_regular):
range(1, 101))
def test_direct_call_matrix(shutdown_only):
ray.init(object_store_memory=1000 * 1024 * 1024)
@ray.remote
class Actor(object):
def small_value(self):
return 0
def large_value(self):
return np.zeros(10 * 1024 * 1024)
def echo(self, x):
if isinstance(x, list):
x = ray.get(x[0])
return x
@ray.remote
def small_value():
return 0
@ray.remote
def large_value():
return np.zeros(10 * 1024 * 1024)
@ray.remote
def echo(x):
if isinstance(x, list):
x = ray.get(x[0])
return x
def check(source_actor, dest_actor, is_large, out_of_band):
print("CHECKING", "actor" if source_actor else "task", "to", "actor"
if dest_actor else "task", "large_object"
if is_large else "small_object", "out_of_band"
if out_of_band else "in_band")
if source_actor:
a = Actor.options(is_direct_call=True).remote()
if is_large:
x_id = a.large_value.remote()
else:
x_id = a.small_value.remote()
else:
if is_large:
x_id = large_value.options(is_direct_call=True).remote()
else:
x_id = small_value.options(is_direct_call=True).remote()
if out_of_band:
x_id = [x_id]
if dest_actor:
b = Actor.options(is_direct_call=True).remote()
x = ray.get(b.echo.remote(x_id))
else:
x = ray.get(echo.options(is_direct_call=True).remote(x_id))
if is_large:
assert isinstance(x, np.ndarray)
else:
assert isinstance(x, int)
for is_large in [False, True]:
for source_actor in [False, True]:
for dest_actor in [False, True]:
for out_of_band in [False, True]:
check(source_actor, dest_actor, is_large, out_of_band)
def test_direct_call_chain(ray_start_regular):
@ray.remote
def g(x):
@ -1265,26 +1330,6 @@ def test_direct_actor_large_objects(ray_start_regular):
assert isinstance(ray.get(obj_id), np.ndarray)
def test_direct_actor_errors(ray_start_regular):
@ray.remote
class Actor(object):
def __init__(self):
pass
def f(self, x):
return x * 2
@ray.remote
def f(x):
return 1
a = Actor._remote(is_direct_call=True)
# cannot pass returns to other methods even in a list
with pytest.raises(Exception):
ray.get(f.remote([a.f.remote(2)]))
def test_direct_actor_pass_by_ref(ray_start_regular):
@ray.remote
class Actor(object):


@ -158,6 +158,14 @@ ObjectID ObjectID::WithTransportType(TaskTransportType transport_type) const {
return copy;
}
ObjectID ObjectID::WithPlasmaTransportType() const {
return WithTransportType(TaskTransportType::RAYLET);
}
ObjectID ObjectID::WithDirectTransportType() const {
return WithTransportType(TaskTransportType::DIRECT);
}
uint8_t ObjectID::GetTransportType() const {
return ::ray::GetTransportType(this->GetFlags());
}
@ -290,10 +298,6 @@ TaskID TaskID::ComputeDriverTaskId(const WorkerID &driver_id) {
}
TaskID ObjectID::TaskId() const {
if (!CreatedByTask()) {
// TODO(qwang): Should be RAY_CHECK here.
RAY_LOG(WARNING) << "Shouldn't call this on a non-task object id: " << this->Hex();
}
return TaskID::FromBinary(
std::string(reinterpret_cast<const char *>(id_), TaskID::Size()));
}


@ -299,6 +299,16 @@ class ObjectID : public BaseID<ObjectID> {
/// \return Copy of this object id with the specified transport type.
ObjectID WithTransportType(TaskTransportType transport_type) const;
/// Return this object id with the plasma transport type.
///
/// \return Copy of this object id with the plasma transport type.
ObjectID WithPlasmaTransportType() const;
/// Return this object id with the direct call transport type.
///
/// \return Copy of this object id with the direct call transport type.
ObjectID WithDirectTransportType() const;
/// Get the transport type of this object.
///
/// \return The type of the transport which is used to transfer this object.


@ -95,7 +95,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
SetCurrentJobId(task_spec.JobId());
RAY_CHECK(current_actor_id_.IsNil());
current_actor_id_ = task_spec.ActorCreationId();
current_task_is_direct_call_ = task_spec.IsDirectCall();
current_actor_is_direct_call_ = task_spec.IsDirectCall();
current_actor_max_concurrency_ = task_spec.MaxActorConcurrency();
} else if (task_spec.IsActorTask()) {
RAY_CHECK(current_job_id_ == task_spec.JobId());
@ -118,8 +118,12 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
bool WorkerContext::CurrentActorIsDirectCall() const {
return current_actor_is_direct_call_;
}
bool WorkerContext::CurrentTaskIsDirectCall() const {
return current_task_is_direct_call_;
return current_task_is_direct_call_ || current_actor_is_direct_call_;
}
int WorkerContext::CurrentActorMaxConcurrency() const {


@ -34,6 +34,11 @@ class WorkerContext {
const ActorID &GetCurrentActorID() const;
/// Returns whether we are in a direct call actor.
bool CurrentActorIsDirectCall() const;
/// Returns whether we are in a direct call task. This encompasses both direct
/// actor and normal tasks.
bool CurrentTaskIsDirectCall() const;
int CurrentActorMaxConcurrency() const;
@ -47,6 +52,7 @@ class WorkerContext {
const WorkerID worker_id_;
JobID current_job_id_;
ActorID current_actor_id_;
bool current_actor_is_direct_call_ = false;
bool current_task_is_direct_call_ = false;
int current_actor_max_concurrency_ = 1;


@ -25,9 +25,10 @@ void BuildCommonTaskSpec(
// Set task arguments.
for (const auto &arg : args) {
if (arg.IsPassedByReference()) {
// TODO(ekl) remove this check once we deprecate TaskTransportType::RAYLET
if (transport_type == ray::TaskTransportType::RAYLET) {
RAY_CHECK(!arg.GetReference().IsDirectCallType())
<< "NotImplemented: passing direct call objects to other tasks";
<< "Passing direct call objects to non-direct tasks is not allowed.";
}
builder.AddByRefArg(arg.GetReference());
} else {
@ -61,6 +62,37 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
namespace ray {
// Prepare direct call args for sending to a direct call *actor*. Direct call actors
// always resolve their dependencies remotely, so we need some client-side preprocessing
// to ensure they don't try to resolve a direct call object ID remotely (which is
// impossible).
// - Direct call args that are local and small will be inlined.
// - Direct call args that are non-local or large will be promoted to plasma.
// Note that args for direct call *tasks* are handled by LocalDependencyResolver.
std::vector<TaskArg> PrepareDirectActorCallArgs(
const std::vector<TaskArg> &args,
std::shared_ptr<CoreWorkerMemoryStore> memory_store) {
std::vector<TaskArg> out;
for (const auto &arg : args) {
if (arg.IsPassedByReference() && arg.GetReference().IsDirectCallType()) {
const ObjectID &obj_id = arg.GetReference();
// TODO(ekl) we should consider resolving these dependencies on the client side
// for actor calls. It is a little tricky since we have to also preserve the
// task ordering so we can't simply use LocalDependencyResolver.
std::shared_ptr<RayObject> obj = memory_store->GetOrPromoteToPlasma(obj_id);
if (obj != nullptr) {
out.push_back(TaskArg::PassByValue(obj));
} else {
out.push_back(TaskArg::PassByReference(
obj_id.WithTransportType(TaskTransportType::RAYLET)));
}
} else {
out.push_back(arg);
}
}
return out;
}
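
The comment above PrepareDirectActorCallArgs describes the client-side split for direct-call arguments sent to direct actors: values that are already local (and small) are inlined by value, everything else is promoted to plasma and passed as a RAYLET-typed reference. A compact Python model of that decision; the dict-shaped args and local_values map are illustrative stand-ins, not Ray classes:

# Illustrative model of PrepareDirectActorCallArgs, not Ray's implementation.
def prepare_direct_actor_call_args(args, local_values):
    # local_values: object_id -> value, standing in for GetOrPromoteToPlasma's
    # "already local and small" case; missing IDs are promoted to plasma.
    out = []
    for arg in args:
        if arg.get("by_ref") and arg.get("direct_call"):
            value = local_values.get(arg["object_id"])
            if value is not None:
                out.append({"by_value": value})      # inline the local value
            else:
                out.append({"by_ref": True,          # promoted: pass a plasma reference
                            "object_id": arg["object_id"],
                            "transport": "RAYLET"})
        else:
            out.append(arg)
    return out

args = [{"by_ref": True, "direct_call": True, "object_id": "obj1"},
        {"by_ref": True, "direct_call": True, "object_id": "obj2"}]
print(prepare_direct_actor_call_args(args, {"obj1": 42}))
# obj1 is inlined by value; obj2 is passed as a RAYLET-typed reference.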
CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id, const gcs::GcsClientOptions &gcs_options,
@ -78,8 +110,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
client_call_manager_(new rpc::ClientCallManager(io_service_)),
heartbeat_timer_(io_service_),
core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
memory_store_provider_(memory_store_),
task_execution_service_work_(task_execution_service_),
task_execution_callback_(task_execution_callback),
grpc_service_(io_service_, *this) {
@ -159,6 +189,11 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
plasma_store_provider_.reset(
new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_));
memory_store_.reset(
new CoreWorkerMemoryStore([this](const RayObject &obj, const ObjectID &obj_id) {
RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id));
}));
memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_));
// Create an entry for the driver task in the task table. This task is
// added immediately with status RUNNING. This allows us to push errors
@ -267,6 +302,15 @@ void CoreWorker::ReportActiveObjectIDs() {
heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this));
}
void CoreWorker::PromoteObjectToPlasma(const ObjectID &object_id) {
RAY_CHECK(object_id.IsDirectCallType());
auto value = memory_store_->GetOrPromoteToPlasma(object_id);
if (value != nullptr) {
RAY_CHECK_OK(
plasma_store_provider_->Put(*value, object_id.WithPlasmaTransportType()));
}
}
Status CoreWorker::SetClientOptions(std::string name, int64_t limit_bytes) {
// Currently only the Plasma store supports client options.
return plasma_store_provider_->SetClientOptions(name, limit_bytes);
@ -324,9 +368,9 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
local_timeout_ms = std::max(static_cast<int64_t>(0),
timeout_ms - (current_time_ms() - start_time));
}
RAY_RETURN_NOT_OK(memory_store_provider_.Get(memory_object_ids, local_timeout_ms,
worker_context_.GetCurrentTaskID(),
&result_map, &got_exception));
RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms,
worker_context_.GetCurrentTaskID(),
&result_map, &got_exception));
}
// If any of the objects have been promoted to plasma, then we retry their
@ -379,7 +423,7 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) {
if (object_id.IsDirectCallType()) {
// Note that the memory store returns false if the object value is
// ErrorType::OBJECT_IN_PLASMA.
RAY_RETURN_NOT_OK(memory_store_provider_.Contains(object_id, &found));
RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found));
}
if (!found) {
// We check plasma as a fallback in all cases, since a direct call object
@ -430,7 +474,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
// TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should
// consider waiting on them in plasma as well to ensure they are local.
RAY_RETURN_NOT_OK(memory_store_provider_.Wait(
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
memory_object_ids, num_objects - static_cast<int>(ready.size()),
/*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready));
}
@ -453,7 +497,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
std::max(0, static_cast<int>(timeout_ms - (current_time_ms() - start_time)));
}
if (static_cast<int>(ready.size()) < num_objects && memory_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(memory_store_provider_.Wait(
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
memory_object_ids, num_objects - static_cast<int>(ready.size()), timeout_ms,
worker_context_.GetCurrentTaskID(), &ready));
}
@ -477,7 +521,7 @@ Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_on
RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only,
delete_creating_tasks));
RAY_RETURN_NOT_OK(memory_store_provider_.Delete(memory_object_ids));
RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids));
return Status::OK();
}
@ -602,10 +646,11 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
const TaskID actor_task_id = TaskID::ForActorTask(
worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
next_task_index, actor_handle->GetActorID());
BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
rpc_address_, function, args, num_returns, task_options.resources,
{}, transport_type, return_ids);
BuildCommonTaskSpec(
builder, actor_handle->CreationJobID(), actor_task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_,
function, is_direct_call ? PrepareDirectActorCallArgs(args, memory_store_) : args,
num_returns, task_options.resources, {}, transport_type, return_ids);
const ObjectID new_cursor = return_ids->back();
actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor);
@ -848,7 +893,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (worker_context_.CurrentTaskIsDirectCall()) {
if (worker_context_.CurrentActorIsDirectCall()) {
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
nullptr);
return;


@ -113,6 +113,15 @@ class CoreWorker {
reference_counter_.RemoveReference(object_id);
}
/// Promote an object to plasma. If it already exists locally, it will be
/// put into the plasma store. If it doesn't yet exist, it will be spilled to
/// plasma once available.
///
/// Postcondition: Get(object_id.WithPlasmaTransportType()) is valid.
///
/// \param[in] object_id The object ID to promote to plasma.
void PromoteObjectToPlasma(const ObjectID &object_id);
///
/// Public methods related to storing and retrieving objects.
///
@ -482,10 +491,10 @@ class CoreWorker {
std::shared_ptr<CoreWorkerMemoryStore> memory_store_;
/// Plasma store interface.
std::unique_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider_;
std::shared_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider_;
/// In-memory store interface.
CoreWorkerMemoryStoreProvider memory_store_provider_;
std::shared_ptr<CoreWorkerMemoryStoreProvider> memory_store_provider_;
///
/// Fields related to task submission.


@ -107,7 +107,9 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
return nullptr;
}
CoreWorkerMemoryStore::CoreWorkerMemoryStore() {}
CoreWorkerMemoryStore::CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma)
: store_in_plasma_(store_in_plasma) {}
void CoreWorkerMemoryStore::GetAsync(
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
@ -127,7 +129,25 @@ void CoreWorkerMemoryStore::GetAsync(
}
}
std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetOrPromoteToPlasma(
const ObjectID &object_id) {
absl::MutexLock lock(&mu_);
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
auto obj = iter->second;
if (obj->IsInPlasmaError()) {
return nullptr;
}
return obj;
}
RAY_CHECK(store_in_plasma_ != nullptr)
<< "Cannot promote object without plasma provider callback.";
promoted_to_plasma_.insert(object_id);
return nullptr;
}
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
RAY_CHECK(object_id.IsDirectCallType());
std::vector<std::function<void(std::shared_ptr<RayObject>)>> async_callbacks;
auto object_entry =
std::make_shared<RayObject>(object.GetData(), object.GetMetadata(), true);
@ -146,6 +166,13 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob
object_async_get_requests_.erase(async_callback_it);
}
auto promoted_it = promoted_to_plasma_.find(object_id);
if (promoted_it != promoted_to_plasma_.end()) {
RAY_CHECK(store_in_plasma_ != nullptr);
store_in_plasma_(object, object_id.WithTransportType(TaskTransportType::RAYLET));
promoted_to_plasma_.erase(promoted_it);
}
bool should_add_entry = true;
auto object_request_iter = object_get_requests_.find(object_id);
if (object_request_iter != object_get_requests_.end()) {
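
GetOrPromoteToPlasma and the matching hook in Put cooperate: if the object is already local (and not an in-plasma marker) it is returned directly; otherwise its ID is recorded in promoted_to_plasma_ and the store_in_plasma_ callback fires when the value eventually arrives. A condensed Python model of that hand-off, mirroring the TestPromoteToPlasma unit test below (names are stand-ins for the C++ members):

# Illustrative model of the memory store's promotion hand-off.
class MemoryStoreModel(object):
    def __init__(self, store_in_plasma):
        self.store_in_plasma = store_in_plasma   # callback: (value, plasma_id) -> None
        self.objects = {}                        # object_id -> value
        self.promoted = set()                    # IDs to spill once their value arrives

    def get_or_promote_to_plasma(self, object_id):
        if object_id in self.objects:
            return self.objects[object_id]       # already local: no promotion needed
        self.promoted.add(object_id)             # spill later, when put() sees the value
        return None

    def put(self, object_id, value):
        if object_id in self.promoted:
            self.store_in_plasma(value, ("plasma", object_id))   # promote on arrival
            self.promoted.discard(object_id)
        self.objects[object_id] = value

plasma_puts = []
store = MemoryStoreModel(lambda value, plasma_id: plasma_puts.append(plasma_id))
assert store.get_or_promote_to_plasma("obj1") is None   # not local yet: promotion queued
store.put("obj1", b"value")
assert plasma_puts == [("plasma", "obj1")]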


@ -18,7 +18,8 @@ class CoreWorkerMemoryStore;
/// actor call (see direct_actor_transport.cc).
class CoreWorkerMemoryStore {
public:
CoreWorkerMemoryStore();
CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr);
~CoreWorkerMemoryStore(){};
/// Put an object with specified ID into object store.
@ -49,6 +50,14 @@ class CoreWorkerMemoryStore {
void GetAsync(const ObjectID &object_id,
std::function<void(std::shared_ptr<RayObject>)> callback);
/// Get a single object if available. If the object is not local yet, or if the object
/// is local but is ErrorType::OBJECT_IN_PLASMA, then nullptr will be returned, and
/// the store will ensure the object is promoted to plasma once available.
///
/// \param[in] object_id The object id to get.
/// \return pointer to the local object, or nullptr if promoted to plasma.
std::shared_ptr<RayObject> GetOrPromoteToPlasma(const ObjectID &object_id);
/// Delete a list of objects from the object store.
///
/// \param[in] object_ids IDs of the objects to delete.
@ -62,6 +71,15 @@ class CoreWorkerMemoryStore {
bool Contains(const ObjectID &object_id);
private:
/// Optional callback for putting objects into the plasma store.
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
/// Protects the data structures below.
absl::Mutex mu_;
/// Set of objects that should be promoted to plasma once available.
absl::flat_hash_set<ObjectID> promoted_to_plasma_ GUARDED_BY(mu_);
/// Map from object ID to `RayObject`.
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_ GUARDED_BY(mu_);
@ -73,9 +91,6 @@ class CoreWorkerMemoryStore {
absl::flat_hash_map<ObjectID,
std::vector<std::function<void(std::shared_ptr<RayObject>)>>>
object_async_get_requests_ GUARDED_BY(mu_);
/// Protect the two maps above.
absl::Mutex mu_;
};
} // namespace ray


@ -14,6 +14,7 @@ CoreWorkerMemoryStoreProvider::CoreWorkerMemoryStoreProvider(
Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object,
const ObjectID &object_id) {
RAY_CHECK(object_id.IsDirectCallType());
Status status = store_->Put(object_id, object);
if (status.IsObjectExists()) {
// Object already exists in store, treat it as ok.


@ -27,6 +27,7 @@ Status CoreWorkerPlasmaStoreProvider::SetClientOptions(std::string name,
Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object,
const ObjectID &object_id) {
RAY_CHECK(!object_id.IsDirectCallType());
std::shared_ptr<Buffer> data;
RAY_RETURN_NOT_OK(Create(object.GetMetadata(),
object.HasData() ? object.GetData()->Size() : 0, object_id,


@ -637,14 +637,14 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
std::vector<ObjectID> ids(buffers.size());
for (size_t i = 0; i < ids.size(); i++) {
ids[i] = ObjectID::FromRandom();
ids[i] = ObjectID::FromRandom().WithDirectTransportType();
RAY_CHECK_OK(provider.Put(buffers[i], ids[i]));
}
absl::flat_hash_set<ObjectID> wait_ids(ids.begin(), ids.end());
absl::flat_hash_set<ObjectID> wait_results;
ObjectID nonexistent_id = ObjectID::FromRandom();
ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType();
wait_ids.insert(nonexistent_id);
RAY_CHECK_OK(
provider.Wait(wait_ids, ids.size() + 1, 100, RandomTaskId(), &wait_results));
@ -693,9 +693,9 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
std::vector<ObjectID> ready_ids(buffers.size());
std::vector<ObjectID> unready_ids(buffers.size());
for (size_t i = 0; i < unready_ids.size(); i++) {
ready_ids[i] = ObjectID::FromRandom();
ready_ids[i] = ObjectID::FromRandom().WithDirectTransportType();
RAY_CHECK_OK(provider.Put(buffers[i], ready_ids[i]));
unready_ids[i] = ObjectID::FromRandom();
unready_ids[i] = ObjectID::FromRandom().WithDirectTransportType();
}
auto thread_func = [&unready_ids, &provider, &buffers]() {


@ -38,9 +38,33 @@ class MockRayletClient : public WorkerLeaseInterface {
int num_workers_returned = 0;
};
TEST(TestMemoryStore, TestPromoteToPlasma) {
int num_plasma_puts = 0;
auto mem = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore(
[&](const RayObject &obj, const ObjectID &obj_id) { num_plasma_puts += 1; }));
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject();
ASSERT_TRUE(mem->Put(obj1, *data).ok());
// Test getting an already existing object.
ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr);
ASSERT_TRUE(num_plasma_puts == 0);
// Test that getting an object that doesn't exist yet causes promotion.
ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr);
ASSERT_TRUE(num_plasma_puts == 0);
ASSERT_TRUE(mem->Put(obj2, *data).ok());
ASSERT_TRUE(num_plasma_puts == 1);
// The next time you get it, it's already there so no need to promote.
ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) != nullptr);
ASSERT_TRUE(num_plasma_puts == 1);
}
TEST(LocalDependencyResolverTest, TestNoDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
LocalDependencyResolver resolver(store);
TaskSpecification task;
bool ok = false;
@ -50,7 +74,7 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) {
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
TaskSpecification task;
@ -62,16 +86,38 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
std::string meta = std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
auto data = RayObject(nullptr, meta_buffer);
ASSERT_TRUE(store->Put(data, obj1).ok());
TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType());
bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_TRUE(ok);
ASSERT_TRUE(task.ArgByRef(0));
// Checks that the object id was promoted to a plasma type id.
ASSERT_FALSE(task.ArgId(0, 0).IsDirectCallType());
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject();
// Ensure the data is already present in the local store.
ASSERT_TRUE(store.Put(*data, obj1).ok());
ASSERT_TRUE(store.Put(*data, obj2).ok());
ASSERT_TRUE(store->Put(*data, obj1).ok());
ASSERT_TRUE(store->Put(*data, obj2).ok());
TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());
@ -88,7 +134,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
@ -100,8 +146,8 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_EQ(resolver.NumPendingTasks(), 1);
ASSERT_TRUE(!ok);
ASSERT_TRUE(store.Put(*data, obj1).ok());
ASSERT_TRUE(store.Put(*data, obj2).ok());
ASSERT_TRUE(store->Put(*data, obj1).ok());
ASSERT_TRUE(store->Put(*data, obj2).ok());
// Tests that the task proto was rewritten to have inline argument values after
// resolution completes.
ASSERT_TRUE(ok);
@ -112,10 +158,11 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
TEST(DirectTaskTranportTest, TestSubmitOneTask) {
TEST(DirectTaskTransportTest, TestSubmitOneTask) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task;
@ -133,10 +180,11 @@ TEST(DirectTaskTranportTest, TestSubmitOneTask) {
ASSERT_EQ(raylet_client.num_workers_returned, 1);
}
TEST(DirectTaskTranportTest, TestHandleTaskFailure) {
TEST(DirectTaskTransportTest, TestHandleTaskFailure) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task;
@ -150,10 +198,11 @@ TEST(DirectTaskTranportTest, TestHandleTaskFailure) {
ASSERT_EQ(raylet_client.num_workers_returned, 1);
}
TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) {
TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task1;
@ -190,10 +239,11 @@ TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) {
ASSERT_EQ(raylet_client.num_workers_returned, 3);
}
TEST(DirectTaskTranportTest, TestReuseWorkerLease) {
TEST(DirectTaskTransportTest, TestReuseWorkerLease) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task1;
@ -232,10 +282,11 @@ TEST(DirectTaskTranportTest, TestReuseWorkerLease) {
ASSERT_EQ(raylet_client.num_workers_returned, 2);
}
TEST(DirectTaskTranportTest, TestWorkerNotReusedOnError) {
TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto store = std::make_shared<CoreWorkerMemoryStoreProvider>(ptr);
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task1;


@ -5,9 +5,61 @@ using ray::rpc::ActorTableData;
namespace ray {
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type,
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store) {
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
<< ", error_type: " << ErrorType_Name(error_type);
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
}
}
void WriteObjectsToMemoryStore(
const rpc::PushTaskReply &reply,
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store) {
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
if (return_object.in_plasma()) {
// Mark it as in plasma with a dummy object.
std::string meta =
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata =
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id));
} else {
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
data_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.data().data())),
return_object.data().size());
}
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
if (return_object.metadata().size() > 0) {
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(
in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id));
}
}
}
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
rpc::ClientCallManager &client_call_manager,
CoreWorkerMemoryStoreProvider store_provider)
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: client_call_manager_(client_call_manager), in_memory_store_(store_provider) {}
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
@ -49,7 +101,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe
} else {
// Actor is dead, treat the task as failure.
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
}
// If the task submission subsequently fails, then the client will receive
@ -84,7 +136,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
for (const auto &entry : iter->second) {
const auto &task_id = entry.first;
const auto num_returns = entry.second;
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
}
waiting_reply_tasks_.erase(actor_id);
}
@ -94,7 +147,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
if (pending_it != pending_requests_.end()) {
for (const auto &request : pending_it->second) {
TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()),
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED);
request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
}
pending_requests_.erase(pending_it);
}
@ -136,59 +190,14 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
// Note that this might be the __ray_terminate__ task, so we don't log
// loudly with ERROR here.
RAY_LOG(INFO) << "Task failed with error: " << status;
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED,
in_memory_store_);
return;
}
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
if (return_object.in_plasma()) {
// Mark it as in plasma with a dummy object.
std::string meta =
std::to_string(static_cast<int>(rpc::ErrorType::OBJECT_IN_PLASMA));
auto metadata =
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(
in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
} else {
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
data_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.data().data())),
return_object.data().size());
}
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
if (return_object.metadata().size() > 0) {
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
}
}
WriteObjectsToMemoryStore(reply, in_memory_store_);
});
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED);
}
}
void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed(
const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) {
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
<< ", error_type: " << ErrorType_Name(error_type);
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_);
}
}
@ -301,8 +310,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
send_reply_callback(status, nullptr, nullptr);
},
[send_reply_callback]() {
send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr,
nullptr);
send_reply_callback(Status::Invalid("client cancelled stale rpc"),
nullptr, nullptr);
},
dependencies);
}


@ -22,6 +22,26 @@ namespace ray {
/// The max time to wait for out-of-order tasks.
const int kMaxReorderWaitSeconds = 30;
/// Treat a task as failed.
///
/// \param[in] task_id The ID of a task.
/// \param[in] num_returns Number of return objects.
/// \param[in] error_type The type of the specific error.
/// \param[in] in_memory_store The memory store to write to.
/// \return Void.
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type,
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store);
/// Write return objects to the memory store.
///
/// \param[in] reply Proto response to a direct actor or task call.
/// \param[in] in_memory_store The memory store to write to.
/// \return Void.
void WriteObjectsToMemoryStore(
const rpc::PushTaskReply &reply,
std::shared_ptr<CoreWorkerMemoryStoreProvider> &in_memory_store);
/// In direct actor call task submitter and receiver, a task is directly submitted
/// to the actor that will execute it.
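
TreatTaskAsFailed and WriteObjectsToMemoryStore become free functions here so both the direct actor and the direct task transports can share them (each previously kept a private copy). A small Python model of the failure path: every return slot of the failed task receives a marker object whose metadata encodes the error type. The tuple IDs and dict store are stand-ins:

# Illustrative model of TreatTaskAsFailed, not Ray's implementation.
def treat_task_as_failed(task_id, num_returns, error_type, in_memory_store):
    for i in range(num_returns):
        return_id = (task_id, i + 1)   # stand-in for ObjectID::ForTaskReturn(task_id, i + 1)
        # The marker has no data; its metadata encodes the error type, so a
        # later get on this return ID surfaces the error to the caller.
        in_memory_store[return_id] = {"data": None, "metadata": error_type}

store = {}
treat_task_as_failed("task_a", 2, "ACTOR_DIED", store)
assert set(store) == {("task_a", 1), ("task_a", 2)}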
@ -40,8 +60,9 @@ struct ActorStateData {
// This class is thread-safe.
class CoreWorkerDirectActorTaskSubmitter {
public:
CoreWorkerDirectActorTaskSubmitter(rpc::ClientCallManager &client_call_manager,
CoreWorkerMemoryStoreProvider store_provider);
CoreWorkerDirectActorTaskSubmitter(
rpc::ClientCallManager &client_call_manager,
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider);
/// Submit a task to an actor for execution.
///
@ -70,15 +91,6 @@ class CoreWorkerDirectActorTaskSubmitter {
std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns);
/// Treat a task as failed.
///
/// \param[in] task_id The ID of a task.
/// \param[in] num_returns Number of return objects.
/// \param[in] error_type The type of the specific error.
/// \return Void.
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type);
/// Create connection to actor and send all pending tasks.
/// Note that this function doesn't take lock, the caller is expected to hold
/// `mutex_` before calling this function.
@ -120,7 +132,7 @@ class CoreWorkerDirectActorTaskSubmitter {
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
/// The store provider.
CoreWorkerMemoryStoreProvider in_memory_store_;
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
friend class CoreWorkerTest;
};
@ -230,7 +242,8 @@ class SchedulingQueue {
std::function<void()> accept_request, std::function<void()> reject_request,
const std::vector<ObjectID> &dependencies = {}) {
if (seq_no == -1) {
seq_no = next_seq_no_; // A value of -1 means no ordering constraint.
accept_request(); // A seq_no of -1 means no ordering constraint.
return;
}
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
if (client_processed_up_to >= next_seq_no_) {
@ -259,6 +272,8 @@ class SchedulingQueue {
// Cancel any stale requests that the client doesn't need any longer.
while (!pending_tasks_.empty() && pending_tasks_.begin()->first < next_seq_no_) {
auto head = pending_tasks_.begin();
RAY_LOG(ERROR) << "Cancelling stale RPC with seqno "
<< pending_tasks_.begin()->first << " < " << next_seq_no_;
head->second.Cancel();
pending_tasks_.erase(head);
}
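
The SchedulingQueue change turns seq_no == -1 into an explicit fast path: such requests are accepted immediately and never enter the ordered pending map, while numbered requests still execute strictly in sequence. A simplified Python model of that ordering rule (it omits the stale-request cancellation and reorder-wait timeout shown above):

# Illustrative model of the direct-call SchedulingQueue ordering rules.
class SchedulingQueueModel(object):
    def __init__(self):
        self.next_seq_no = 0
        self.pending = {}                 # seq_no -> accept callback

    def add(self, seq_no, accept):
        if seq_no == -1:                  # -1 means no ordering constraint
            accept()                      # run immediately, bypassing the queue
            return
        self.pending[seq_no] = accept
        while self.next_seq_no in self.pending:   # drain requests that are now in order
            self.pending.pop(self.next_seq_no)()
            self.next_seq_no += 1

order = []
q = SchedulingQueueModel()
q.add(1, lambda: order.append(1))         # waits for seq_no 0
q.add(-1, lambda: order.append("any"))    # unordered: runs right away
q.add(0, lambda: order.append(0))         # unblocks 0, then 1
assert order == ["any", 0, 1]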


@ -1,4 +1,5 @@
#include "ray/core_worker/transport/direct_task_transport.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
namespace ray {
@ -13,13 +14,20 @@ void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> valu
if (id == obj_id) {
auto *mutable_arg = msg.mutable_args(i);
mutable_arg->clear_object_ids();
if (value->HasData()) {
const auto &data = value->GetData();
mutable_arg->set_data(data->Data(), data->Size());
}
if (value->HasMetadata()) {
const auto &metadata = value->GetMetadata();
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
if (value->IsInPlasmaError()) {
// Promote the object id to plasma.
mutable_arg->add_object_ids(
obj_id.WithTransportType(TaskTransportType::RAYLET).Binary());
} else {
// Inline the object value.
if (value->HasData()) {
const auto &data = value->GetData();
mutable_arg->set_data(data->Data(), data->Size());
}
if (value->HasMetadata()) {
const auto &metadata = value->GetMetadata();
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
}
}
found = true;
}
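
DoInlineObjectValue now branches on the resolved value: an in-plasma error marker means the argument is rewritten back to a by-reference arg with the plasma (RAYLET) transport type, while a real value is inlined into the task spec. A short Python model of that rewrite; the dict stands in for the protobuf TaskArg message:

# Illustrative model of DoInlineObjectValue's rewrite, not Ray's implementation.
def inline_object_value(arg, obj_id, value):
    arg["object_ids"] = []                       # clear the by-reference entry
    if value.get("in_plasma_error"):
        # The value only lives in plasma: keep the arg by reference, but with
        # the plasma (RAYLET) transport type so the receiver fetches it there.
        arg["object_ids"] = [("plasma", obj_id)]
    else:
        # Small local value: inline data and metadata directly into the spec.
        if value.get("data") is not None:
            arg["data"] = value["data"]
        if value.get("metadata") is not None:
            arg["metadata"] = value["metadata"]
    return arg

arg = inline_object_value({"object_ids": [("direct", "obj1")]}, "obj1",
                          {"in_plasma_error": True})
assert arg["object_ids"] == [("plasma", "obj1")]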
@ -52,7 +60,7 @@ void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
num_pending_ += 1;
for (const auto &obj_id : state->local_dependencies) {
in_memory_store_.GetAsync(
in_memory_store_->GetAsync(
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
RAY_CHECK(obj != nullptr);
bool complete = false;
@ -128,23 +136,6 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
worker_request_pending_ = true;
}
void CoreWorkerDirectTaskSubmitter::TreatTaskAsFailed(const TaskID &task_id,
int num_returns,
const rpc::ErrorType &error_type) {
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
<< ", error_type: " << ErrorType_Name(error_type);
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
}
}
// TODO(ekl) consider reconsolidating with DirectActorTransport.
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec) {
@ -157,30 +148,15 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
OnWorkerIdle(addr, /*error=*/!status.ok());
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED);
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
in_memory_store_);
return;
}
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
data_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.data().data())),
return_object.data().size());
}
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
if (return_object.metadata().size() > 0) {
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
}
WriteObjectsToMemoryStore(reply, in_memory_store_);
});
RAY_CHECK_OK(status);
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED,
in_memory_store_);
}
}
}; // namespace ray


@ -23,7 +23,7 @@ struct TaskState {
// This class is thread-safe.
class LocalDependencyResolver {
public:
LocalDependencyResolver(CoreWorkerMemoryStoreProvider &store_provider)
LocalDependencyResolver(std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: in_memory_store_(store_provider), num_pending_(0) {}
/// Resolve all local and remote dependencies for the task, calling the specified
@ -42,7 +42,7 @@ class LocalDependencyResolver {
private:
/// The store provider.
CoreWorkerMemoryStoreProvider in_memory_store_;
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
/// Number of tasks pending dependency resolution.
std::atomic<int> num_pending_;
@ -58,9 +58,9 @@ typedef std::function<std::shared_ptr<rpc::CoreWorkerClientInterface>(WorkerAddr
// This class is thread-safe.
class CoreWorkerDirectTaskSubmitter {
public:
CoreWorkerDirectTaskSubmitter(WorkerLeaseInterface &lease_client,
ClientFactoryFn client_factory,
CoreWorkerMemoryStoreProvider store_provider)
CoreWorkerDirectTaskSubmitter(
WorkerLeaseInterface &lease_client, ClientFactoryFn client_factory,
std::shared_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: lease_client_(lease_client),
client_factory_(client_factory),
in_memory_store_(store_provider),
@ -93,10 +93,6 @@ class CoreWorkerDirectTaskSubmitter {
void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec);
/// Mark a direct call as failed by storing errors for its return objects.
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type);
// Client that can be used to lease and return workers.
WorkerLeaseInterface &lease_client_;
@ -104,7 +100,7 @@ class CoreWorkerDirectTaskSubmitter {
ClientFactoryFn client_factory_;
/// The store provider.
CoreWorkerMemoryStoreProvider in_memory_store_;
std::shared_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
/// Resolve local and remote dependencies;
LocalDependencyResolver resolver_;