diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 9ea19f0d1..979921cd9 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1015,6 +1015,12 @@ cdef class CoreWorker: # Note: faster to not release GIL for short-running op. self.core_worker.get().RemoveObjectIDReference(c_object_id) + def promote_object_to_plasma(self, ObjectID object_id): + cdef: + CObjectID c_object_id = object_id.native() + self.core_worker.get().PromoteObjectToPlasma(c_object_id) + return object_id.with_plasma_transport_type() + # TODO: handle noreturn better cdef store_task_outputs( self, worker, outputs, const c_vector[CObjectID] return_ids, diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index abcc587cd..3d55fd98a 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -104,6 +104,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: *bytes) void AddObjectIDReference(const CObjectID &object_id) void RemoveObjectIDReference(const CObjectID &object_id) + void PromoteObjectToPlasma(const CObjectID &object_id) CRayStatus SetClientOptions(c_string client_name, int64_t limit) CRayStatus Put(const CRayObject &object, CObjectID *object_id) diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index d6bfd3e31..093e4db2b 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -95,8 +95,10 @@ cdef class TaskSpec: :self.task_spec.get().ArgMetadataSize(i)] if metadata == RAW_BUFFER_METADATA: obj = data - else: + elif metadata == PICKLE_BUFFER_METADATA: obj = pickle.loads(data) + else: + obj = data arg_list.append(obj) elif lang == LANGUAGE_JAVA: arg_list = num_args * [""] diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index c317e622b..012222a2a 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -150,6 +150,8 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: c_bool IsDirectCallType() + CObjectID WithPlasmaTransportType() + int64_t ObjectIndex() const CTaskID TaskId() const diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 6ed41f57a..7170745e5 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -179,6 +179,9 @@ cdef class ObjectID(BaseID): def is_direct_call_type(self): return self.data.IsDirectCallType() + def with_plasma_transport_type(self): + return ObjectID(self.data.WithPlasmaTransportType().Binary()) + def is_nil(self): return self.data.IsNil() diff --git a/python/ray/serialization.py b/python/ray/serialization.py index f6a6af8ac..c5d2f7293 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -158,9 +158,7 @@ class SerializationContext(object): def id_serializer(obj): if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type(): - raise NotImplementedError( - "Objects produced by direct actor calls cannot be " - "passed to other tasks as arguments.") + obj = self.worker.core_worker.promote_object_to_plasma(obj) return pickle.dumps(obj) def id_deserializer(serialized_obj): @@ -191,9 +189,7 @@ class SerializationContext(object): def id_serializer(obj): if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type(): - raise NotImplementedError( - "Objects produced by direct actor calls cannot be " - "passed to other tasks as arguments.") + obj = self.worker.core_worker.promote_object_to_plasma(obj) return obj.__reduce__() def id_deserializer(serialized_obj): diff --git 
a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 029dcc2a5..2766aaf66 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1218,6 +1218,71 @@ def test_direct_call_simple(ray_start_regular): range(1, 101)) +def test_direct_call_matrix(shutdown_only): + ray.init(object_store_memory=1000 * 1024 * 1024) + + @ray.remote + class Actor(object): + def small_value(self): + return 0 + + def large_value(self): + return np.zeros(10 * 1024 * 1024) + + def echo(self, x): + if isinstance(x, list): + x = ray.get(x[0]) + return x + + @ray.remote + def small_value(): + return 0 + + @ray.remote + def large_value(): + return np.zeros(10 * 1024 * 1024) + + @ray.remote + def echo(x): + if isinstance(x, list): + x = ray.get(x[0]) + return x + + def check(source_actor, dest_actor, is_large, out_of_band): + print("CHECKING", "actor" if source_actor else "task", "to", "actor" + if dest_actor else "task", "large_object" + if is_large else "small_object", "out_of_band" + if out_of_band else "in_band") + if source_actor: + a = Actor.options(is_direct_call=True).remote() + if is_large: + x_id = a.large_value.remote() + else: + x_id = a.small_value.remote() + else: + if is_large: + x_id = large_value.options(is_direct_call=True).remote() + else: + x_id = small_value.options(is_direct_call=True).remote() + if out_of_band: + x_id = [x_id] + if dest_actor: + b = Actor.options(is_direct_call=True).remote() + x = ray.get(b.echo.remote(x_id)) + else: + x = ray.get(echo.options(is_direct_call=True).remote(x_id)) + if is_large: + assert isinstance(x, np.ndarray) + else: + assert isinstance(x, int) + + for is_large in [False, True]: + for source_actor in [False, True]: + for dest_actor in [False, True]: + for out_of_band in [False, True]: + check(source_actor, dest_actor, is_large, out_of_band) + + def test_direct_call_chain(ray_start_regular): @ray.remote def g(x): @@ -1265,26 +1330,6 @@ def test_direct_actor_large_objects(ray_start_regular): assert isinstance(ray.get(obj_id), np.ndarray) -def test_direct_actor_errors(ray_start_regular): - @ray.remote - class Actor(object): - def __init__(self): - pass - - def f(self, x): - return x * 2 - - @ray.remote - def f(x): - return 1 - - a = Actor._remote(is_direct_call=True) - - # cannot pass returns to other methods even in a list - with pytest.raises(Exception): - ray.get(f.remote([a.f.remote(2)])) - - def test_direct_actor_pass_by_ref(ray_start_regular): @ray.remote class Actor(object): diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 9be9cb138..da819e191 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -158,6 +158,14 @@ ObjectID ObjectID::WithTransportType(TaskTransportType transport_type) const { return copy; } +ObjectID ObjectID::WithPlasmaTransportType() const { + return WithTransportType(TaskTransportType::RAYLET); +} + +ObjectID ObjectID::WithDirectTransportType() const { + return WithTransportType(TaskTransportType::DIRECT); +} + uint8_t ObjectID::GetTransportType() const { return ::ray::GetTransportType(this->GetFlags()); } @@ -290,10 +298,6 @@ TaskID TaskID::ComputeDriverTaskId(const WorkerID &driver_id) { } TaskID ObjectID::TaskId() const { - if (!CreatedByTask()) { - // TODO(qwang): Should be RAY_CHECK here. 
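The user-visible effect of the serialization change and the new test_direct_call_matrix above: a direct-call ObjectID can now be passed to another task or actor, even wrapped inside a data structure, because pickling it promotes the underlying object to plasma instead of raising the old NotImplementedError. A minimal sketch of that flow, modeled on the new test (Producer/consume are illustrative names; this assumes a Ray build that includes this patch):

```python
import numpy as np
import ray

ray.init(object_store_memory=1000 * 1024 * 1024)

@ray.remote
class Producer(object):
    def large_value(self):
        return np.zeros(10 * 1024 * 1024)

@ray.remote
def consume(wrapped):
    # The inner ID was produced by a direct actor call. With this patch the
    # caller promotes it to plasma while pickling the argument, so ray.get
    # works here instead of hitting "cannot be passed to other tasks".
    return ray.get(wrapped[0]).shape

producer = Producer.options(is_direct_call=True).remote()
x_id = producer.large_value.remote()
print(ray.get(consume.options(is_direct_call=True).remote([x_id])))  # (10485760,)
```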
- RAY_LOG(WARNING) << "Shouldn't call this on a non-task object id: " << this->Hex(); - } return TaskID::FromBinary( std::string(reinterpret_cast(id_), TaskID::Size())); } diff --git a/src/ray/common/id.h b/src/ray/common/id.h index a847ea4fc..b95b88fd5 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -299,6 +299,16 @@ class ObjectID : public BaseID { /// \return Copy of this object id with the specified transport type. ObjectID WithTransportType(TaskTransportType transport_type) const; + /// Return this object id with the plasma transport type. + /// + /// \return Copy of this object id with the plasma transport type. + ObjectID WithPlasmaTransportType() const; + + /// Return this object id with the direct call transport type. + /// + /// \return Copy of this object id with the direct call transport type. + ObjectID WithDirectTransportType() const; + /// Get the transport type of this object. /// /// \return The type of the transport which is used to transfer this object. diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 634fd87a4..59dda5c11 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -95,7 +95,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { SetCurrentJobId(task_spec.JobId()); RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); - current_task_is_direct_call_ = task_spec.IsDirectCall(); + current_actor_is_direct_call_ = task_spec.IsDirectCall(); current_actor_max_concurrency_ = task_spec.MaxActorConcurrency(); } else if (task_spec.IsActorTask()) { RAY_CHECK(current_job_id_ == task_spec.JobId()); @@ -118,8 +118,12 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; } +bool WorkerContext::CurrentActorIsDirectCall() const { + return current_actor_is_direct_call_; +} + bool WorkerContext::CurrentTaskIsDirectCall() const { - return current_task_is_direct_call_; + return current_task_is_direct_call_ || current_actor_is_direct_call_; } int WorkerContext::CurrentActorMaxConcurrency() const { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 9a30e0123..3ced2ced1 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -34,6 +34,11 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; + /// Returns whether we are in a direct call actor. + bool CurrentActorIsDirectCall() const; + + /// Returns whether we are in a direct call task. This encompasses both direct + /// actor and normal tasks. bool CurrentTaskIsDirectCall() const; int CurrentActorMaxConcurrency() const; @@ -47,6 +52,7 @@ class WorkerContext { const WorkerID worker_id_; JobID current_job_id_; ActorID current_actor_id_; + bool current_actor_is_direct_call_ = false; bool current_task_is_direct_call_ = false; int current_actor_max_concurrency_ = 1; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d2c4581a8..07f18b461 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -25,9 +25,10 @@ void BuildCommonTaskSpec( // Set task arguments. 
for (const auto &arg : args) { if (arg.IsPassedByReference()) { + // TODO(ekl) remove this check once we deprecate TaskTransportType::RAYLET if (transport_type == ray::TaskTransportType::RAYLET) { RAY_CHECK(!arg.GetReference().IsDirectCallType()) - << "NotImplemented: passing direct call objects to other tasks"; + << "Passing direct call objects to non-direct tasks is not allowed."; } builder.AddByRefArg(arg.GetReference()); } else { @@ -61,6 +62,37 @@ void GroupObjectIdsByStoreProvider(const std::vector &object_ids, namespace ray { +// Prepare direct call args for sending to a direct call *actor*. Direct call actors +// always resolve their dependencies remotely, so we need some client-side preprocessing +// to ensure they don't try to resolve a direct call object ID remotely (which is +// impossible). +// - Direct call args that are local and small will be inlined. +// - Direct call args that are non-local or large will be promoted to plasma. +// Note that args for direct call *tasks* are handled by LocalDependencyResolver. +std::vector PrepareDirectActorCallArgs( + const std::vector &args, + std::shared_ptr memory_store) { + std::vector out; + for (const auto &arg : args) { + if (arg.IsPassedByReference() && arg.GetReference().IsDirectCallType()) { + const ObjectID &obj_id = arg.GetReference(); + // TODO(ekl) we should consider resolving these dependencies on the client side + // for actor calls. It is a little tricky since we have to also preserve the + // task ordering so we can't simply use LocalDependencyResolver. + std::shared_ptr obj = memory_store->GetOrPromoteToPlasma(obj_id); + if (obj != nullptr) { + out.push_back(TaskArg::PassByValue(obj)); + } else { + out.push_back(TaskArg::PassByReference( + obj_id.WithTransportType(TaskTransportType::RAYLET))); + } + } else { + out.push_back(arg); + } + } + return out; +} + CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, const std::string &store_socket, const std::string &raylet_socket, const JobID &job_id, const gcs::GcsClientOptions &gcs_options, @@ -78,8 +110,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, client_call_manager_(new rpc::ClientCallManager(io_service_)), heartbeat_timer_(io_service_), core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), - memory_store_(std::make_shared()), - memory_store_provider_(memory_store_), task_execution_service_work_(task_execution_service_), task_execution_callback_(task_execution_callback), grpc_service_(io_service_, *this) { @@ -159,6 +189,11 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, plasma_store_provider_.reset( new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_)); + memory_store_.reset( + new CoreWorkerMemoryStore([this](const RayObject &obj, const ObjectID &obj_id) { + RAY_CHECK_OK(plasma_store_provider_->Put(obj, obj_id)); + })); + memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_)); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. 
This allows us to push errors @@ -267,6 +302,15 @@ void CoreWorker::ReportActiveObjectIDs() { heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this)); } +void CoreWorker::PromoteObjectToPlasma(const ObjectID &object_id) { + RAY_CHECK(object_id.IsDirectCallType()); + auto value = memory_store_->GetOrPromoteToPlasma(object_id); + if (value != nullptr) { + RAY_CHECK_OK( + plasma_store_provider_->Put(*value, object_id.WithPlasmaTransportType())); + } +} + Status CoreWorker::SetClientOptions(std::string name, int64_t limit_bytes) { // Currently only the Plasma store supports client options. return plasma_store_provider_->SetClientOptions(name, limit_bytes); @@ -324,9 +368,9 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m local_timeout_ms = std::max(static_cast(0), timeout_ms - (current_time_ms() - start_time)); } - RAY_RETURN_NOT_OK(memory_store_provider_.Get(memory_object_ids, local_timeout_ms, - worker_context_.GetCurrentTaskID(), - &result_map, &got_exception)); + RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms, + worker_context_.GetCurrentTaskID(), + &result_map, &got_exception)); } // If any of the objects have been promoted to plasma, then we retry their @@ -379,7 +423,7 @@ Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) { if (object_id.IsDirectCallType()) { // Note that the memory store returns false if the object value is // ErrorType::OBJECT_IN_PLASMA. - RAY_RETURN_NOT_OK(memory_store_provider_.Contains(object_id, &found)); + RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found)); } if (!found) { // We check plasma as a fallback in all cases, since a direct call object @@ -430,7 +474,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { // TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should // consider waiting on them in plasma as well to ensure they are local. 
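A compact Python model of the Get path touched above (the dict-backed stores, the IN_PLASMA sentinel, and the helper name are illustrative stand-ins, not Ray's API): direct-call IDs are looked up in the in-memory store first, and anything that comes back as an in-plasma marker is re-fetched from plasma under the RAYLET-transport variant of the same ID, which is exactly the postcondition PromoteObjectToPlasma guarantees.

```python
IN_PLASMA = object()  # stand-in for an ErrorType::OBJECT_IN_PLASMA marker

memory_store = {"obj": IN_PLASMA}                 # direct-call (in-process) store
plasma_store = {("plasma", "obj"): b"payload"}    # shared-memory store

def with_plasma_transport_type(object_id):
    # Stand-in for ObjectID::WithPlasmaTransportType(): same ID, RAYLET transport.
    return ("plasma", object_id)

def get(object_id):
    value = memory_store.get(object_id)
    if value is None or value is IN_PLASMA:
        # Promoted (or spilled) object: retry against plasma under the
        # plasma-transport ID.
        return plasma_store[with_plasma_transport_type(object_id)]
    return value

assert get("obj") == b"payload"
```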
- RAY_RETURN_NOT_OK(memory_store_provider_.Wait( + RAY_RETURN_NOT_OK(memory_store_provider_->Wait( memory_object_ids, num_objects - static_cast(ready.size()), /*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready)); } @@ -453,7 +497,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, std::max(0, static_cast(timeout_ms - (current_time_ms() - start_time))); } if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(memory_store_provider_.Wait( + RAY_RETURN_NOT_OK(memory_store_provider_->Wait( memory_object_ids, num_objects - static_cast(ready.size()), timeout_ms, worker_context_.GetCurrentTaskID(), &ready)); } @@ -477,7 +521,7 @@ Status CoreWorker::Delete(const std::vector &object_ids, bool local_on RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only, delete_creating_tasks)); - RAY_RETURN_NOT_OK(memory_store_provider_.Delete(memory_object_ids)); + RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids)); return Status::OK(); } @@ -602,10 +646,11 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f const TaskID actor_task_id = TaskID::ForActorTask( worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index, actor_handle->GetActorID()); - BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), - rpc_address_, function, args, num_returns, task_options.resources, - {}, transport_type, return_ids); + BuildCommonTaskSpec( + builder, actor_handle->CreationJobID(), actor_task_id, + worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, + function, is_direct_call ? PrepareDirectActorCallArgs(args, memory_store_) : args, + num_returns, task_options.resources, {}, transport_type, return_ids); const ObjectID new_cursor = return_ids->back(); actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor); @@ -848,7 +893,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { - if (worker_context_.CurrentTaskIsDirectCall()) { + if (worker_context_.CurrentActorIsDirectCall()) { send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, nullptr); return; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index c79ad3022..8dc8270d9 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -113,6 +113,15 @@ class CoreWorker { reference_counter_.RemoveReference(object_id); } + /// Promote an object to plasma. If it already exists locally, it will be + /// put into the plasma store. If it doesn't yet exist, it will be spilled to + /// plasma once available. + /// + /// Postcondition: Get(object_id.WithPlasmaTransportType()) is valid. + /// + /// \param[in] object_id The object ID to promote to plasma. + void PromoteObjectToPlasma(const ObjectID &object_id); + /// /// Public methods related to storing and retrieving objects. /// @@ -482,10 +491,10 @@ class CoreWorker { std::shared_ptr memory_store_; /// Plasma store interface. - std::unique_ptr plasma_store_provider_; + std::shared_ptr plasma_store_provider_; /// In-memory store interface. - CoreWorkerMemoryStoreProvider memory_store_provider_; + std::shared_ptr memory_store_provider_; /// /// Fields related to task submission. 
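The argument preparation added earlier in core_worker.cc (PrepareDirectActorCallArgs) can be summarized with a short Python sketch. The tuple-based TaskArg representation and the helper names are illustrative assumptions, not Ray's API: each by-reference direct-call argument is inlined if its value is already local, and otherwise the reference is rewritten to the plasma (RAYLET) transport so a direct-call actor never has to resolve a direct-call ID remotely.

```python
def is_direct_call_id(object_id):
    return object_id.startswith("direct:")   # stand-in for IsDirectCallType()

def with_plasma_transport_type(object_id):
    return object_id.replace("direct:", "plasma:", 1)

def prepare_direct_actor_call_args(args, get_or_promote_to_plasma):
    """Toy version of PrepareDirectActorCallArgs (names are illustrative)."""
    out = []
    for arg in args:
        kind, payload = arg
        if kind == "by_ref" and is_direct_call_id(payload):
            value = get_or_promote_to_plasma(payload)
            if value is not None:
                # Local and available: pass the value inline.
                out.append(("by_value", value))
            else:
                # Not local (or already in plasma): pass a plasma-transport
                # reference; the memory store will spill the value when ready.
                out.append(("by_ref", with_plasma_transport_type(payload)))
        else:
            out.append(arg)
    return out

args = [("by_ref", "direct:a"), ("by_ref", "plasma:b"), ("by_value", b"x")]
print(prepare_direct_actor_call_args(args, lambda oid: None))
# [('by_ref', 'plasma:a'), ('by_ref', 'plasma:b'), ('by_value', b'x')]
```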
diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 4b59451c4..368b8dee2 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -107,7 +107,9 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { return nullptr; } -CoreWorkerMemoryStore::CoreWorkerMemoryStore() {} +CoreWorkerMemoryStore::CoreWorkerMemoryStore( + std::function store_in_plasma) + : store_in_plasma_(store_in_plasma) {} void CoreWorkerMemoryStore::GetAsync( const ObjectID &object_id, std::function)> callback) { @@ -127,7 +129,25 @@ void CoreWorkerMemoryStore::GetAsync( } } +std::shared_ptr CoreWorkerMemoryStore::GetOrPromoteToPlasma( + const ObjectID &object_id) { + absl::MutexLock lock(&mu_); + auto iter = objects_.find(object_id); + if (iter != objects_.end()) { + auto obj = iter->second; + if (obj->IsInPlasmaError()) { + return nullptr; + } + return obj; + } + RAY_CHECK(store_in_plasma_ != nullptr) + << "Cannot promote object without plasma provider callback."; + promoted_to_plasma_.insert(object_id); + return nullptr; +} + Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) { + RAY_CHECK(object_id.IsDirectCallType()); std::vector)>> async_callbacks; auto object_entry = std::make_shared(object.GetData(), object.GetMetadata(), true); @@ -146,6 +166,13 @@ Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &ob object_async_get_requests_.erase(async_callback_it); } + auto promoted_it = promoted_to_plasma_.find(object_id); + if (promoted_it != promoted_to_plasma_.end()) { + RAY_CHECK(store_in_plasma_ != nullptr); + store_in_plasma_(object, object_id.WithTransportType(TaskTransportType::RAYLET)); + promoted_to_plasma_.erase(promoted_it); + } + bool should_add_entry = true; auto object_request_iter = object_get_requests_.find(object_id); if (object_request_iter != object_get_requests_.end()) { diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 415da008e..431e1e9ca 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -18,7 +18,8 @@ class CoreWorkerMemoryStore; /// actor call (see direct_actor_transport.cc). class CoreWorkerMemoryStore { public: - CoreWorkerMemoryStore(); + CoreWorkerMemoryStore( + std::function store_in_plasma = nullptr); ~CoreWorkerMemoryStore(){}; /// Put an object with specified ID into object store. @@ -49,6 +50,14 @@ class CoreWorkerMemoryStore { void GetAsync(const ObjectID &object_id, std::function)> callback); + /// Get a single object if available. If the object is not local yet, or if the object + /// is local but is ErrorType::OBJECT_IN_PLASMA, then nullptr will be returned, and + /// the store will ensure the object is promoted to plasma once available. + /// + /// \param[in] object_id The object id to get. + /// \return pointer to the local object, or nullptr if promoted to plasma. + std::shared_ptr GetOrPromoteToPlasma(const ObjectID &object_id); + /// Delete a list of objects from the object store. /// /// \param[in] object_ids IDs of the objects to delete. @@ -62,6 +71,15 @@ class CoreWorkerMemoryStore { bool Contains(const ObjectID &object_id); private: + /// Optional callback for putting objects into the plasma store. 
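Before the memory_store.cc hunk below, here is a self-contained Python sketch of the new promotion bookkeeping (the class, method, and callback names are illustrative, not Ray's API): GetOrPromoteToPlasma returns the value only if it is already local; otherwise it records the ID in a promoted set, and a later Put invokes the store_in_plasma callback before storing the value. In the real code the spill happens under the RAYLET-transport variant of the ID; the toy below skips that detail. The usage mirrors the TestPromoteToPlasma unit test further down.

```python
class MemoryStoreModel(object):
    """Toy model of CoreWorkerMemoryStore's plasma-promotion bookkeeping."""

    def __init__(self, store_in_plasma):
        self._objects = {}
        self._promoted = set()
        self._store_in_plasma = store_in_plasma  # callback(obj, object_id)

    def get_or_promote_to_plasma(self, object_id):
        if object_id in self._objects:
            return self._objects[object_id]
        # Not local yet: ask put() to spill it to plasma when it shows up.
        self._promoted.add(object_id)
        return None

    def put(self, object_id, obj):
        if object_id in self._promoted:
            self._store_in_plasma(obj, object_id)
            self._promoted.discard(object_id)
        self._objects[object_id] = obj


plasma_puts = []
store = MemoryStoreModel(lambda obj, oid: plasma_puts.append(oid))
store.put("obj1", b"a")
assert store.get_or_promote_to_plasma("obj1") == b"a" and not plasma_puts
assert store.get_or_promote_to_plasma("obj2") is None      # records the promotion
store.put("obj2", b"b")                                     # triggers the spill
assert plasma_puts == ["obj2"]
```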
+ std::function store_in_plasma_; + + /// Protects the data structures below. + absl::Mutex mu_; + + /// Set of objects that should be promoted to plasma once available. + absl::flat_hash_set promoted_to_plasma_ GUARDED_BY(mu_); + /// Map from object ID to `RayObject`. absl::flat_hash_map> objects_ GUARDED_BY(mu_); @@ -73,9 +91,6 @@ class CoreWorkerMemoryStore { absl::flat_hash_map)>>> object_async_get_requests_ GUARDED_BY(mu_); - - /// Protect the two maps above. - absl::Mutex mu_; }; } // namespace ray diff --git a/src/ray/core_worker/store_provider/memory_store_provider.cc b/src/ray/core_worker/store_provider/memory_store_provider.cc index 883c2949b..3568fa923 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.cc +++ b/src/ray/core_worker/store_provider/memory_store_provider.cc @@ -14,6 +14,7 @@ CoreWorkerMemoryStoreProvider::CoreWorkerMemoryStoreProvider( Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object, const ObjectID &object_id) { + RAY_CHECK(object_id.IsDirectCallType()); Status status = store_->Put(object_id, object); if (status.IsObjectExists()) { // Object already exists in store, treat it as ok. diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 011c37282..dcceb2322 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -27,6 +27,7 @@ Status CoreWorkerPlasmaStoreProvider::SetClientOptions(std::string name, Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object, const ObjectID &object_id) { + RAY_CHECK(!object_id.IsDirectCallType()); std::shared_ptr data; RAY_RETURN_NOT_OK(Create(object.GetMetadata(), object.HasData() ? 
object.GetData()->Size() : 0, object_id, diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 174ed6ab2..34450bb4a 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -637,14 +637,14 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { std::vector ids(buffers.size()); for (size_t i = 0; i < ids.size(); i++) { - ids[i] = ObjectID::FromRandom(); + ids[i] = ObjectID::FromRandom().WithDirectTransportType(); RAY_CHECK_OK(provider.Put(buffers[i], ids[i])); } absl::flat_hash_set wait_ids(ids.begin(), ids.end()); absl::flat_hash_set wait_results; - ObjectID nonexistent_id = ObjectID::FromRandom(); + ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType(); wait_ids.insert(nonexistent_id); RAY_CHECK_OK( provider.Wait(wait_ids, ids.size() + 1, 100, RandomTaskId(), &wait_results)); @@ -693,9 +693,9 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { std::vector ready_ids(buffers.size()); std::vector unready_ids(buffers.size()); for (size_t i = 0; i < unready_ids.size(); i++) { - ready_ids[i] = ObjectID::FromRandom(); + ready_ids[i] = ObjectID::FromRandom().WithDirectTransportType(); RAY_CHECK_OK(provider.Put(buffers[i], ready_ids[i])); - unready_ids[i] = ObjectID::FromRandom(); + unready_ids[i] = ObjectID::FromRandom().WithDirectTransportType(); } auto thread_func = [&unready_ids, &provider, &buffers]() { diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index bf2025b11..ff03163a8 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -38,9 +38,33 @@ class MockRayletClient : public WorkerLeaseInterface { int num_workers_returned = 0; }; +TEST(TestMemoryStore, TestPromoteToPlasma) { + int num_plasma_puts = 0; + auto mem = std::shared_ptr(new CoreWorkerMemoryStore( + [&](const RayObject &obj, const ObjectID &obj_id) { num_plasma_puts += 1; })); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + auto data = GenerateRandomObject(); + ASSERT_TRUE(mem->Put(obj1, *data).ok()); + + // Test getting an already existing object. + ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj1) != nullptr); + ASSERT_TRUE(num_plasma_puts == 0); + + // Test that getting an object that doesn't exist yet causes promotion. + ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) == nullptr); + ASSERT_TRUE(num_plasma_puts == 0); + ASSERT_TRUE(mem->Put(obj2, *data).ok()); + ASSERT_TRUE(num_plasma_puts == 1); + + // The next time you get it, it's already there so no need to promote. 
+ ASSERT_TRUE(mem->GetOrPromoteToPlasma(obj2) != nullptr); + ASSERT_TRUE(num_plasma_puts == 1); +} + TEST(LocalDependencyResolverTest, TestNoDependencies) { auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - CoreWorkerMemoryStoreProvider store(ptr); + auto store = std::make_shared(ptr); LocalDependencyResolver resolver(store); TaskSpecification task; bool ok = false; @@ -50,7 +74,7 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) { TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - CoreWorkerMemoryStoreProvider store(ptr); + auto store = std::make_shared(ptr); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET); TaskSpecification task; @@ -62,16 +86,38 @@ TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { ASSERT_EQ(resolver.NumPendingTasks(), 0); } +TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); + LocalDependencyResolver resolver(store); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer); + ASSERT_TRUE(store->Put(data, obj1).ok()); + TaskSpecification task; + task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + ASSERT_TRUE(task.ArgId(0, 0).IsDirectCallType()); + bool ok = false; + resolver.ResolveDependencies(task, [&ok]() { ok = true; }); + ASSERT_TRUE(ok); + ASSERT_TRUE(task.ArgByRef(0)); + // Checks that the object id was promoted to a plasma type id. + ASSERT_FALSE(task.ArgId(0, 0).IsDirectCallType()); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - CoreWorkerMemoryStoreProvider store(ptr); + auto store = std::make_shared(ptr); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); auto data = GenerateRandomObject(); // Ensure the data is already present in the local store. 
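TestHandlePlasmaPromotion above exercises the new branch in DoInlineObjectValue (shown later in direct_task_transport.cc): when a resolved dependency turns out to be an OBJECT_IN_PLASMA marker, the argument keeps its by-reference form but the ID is rewritten to the RAYLET transport instead of being inlined. A rough Python rendering of that per-argument decision (the dict shape and ID prefixes are illustrative assumptions):

```python
IN_PLASMA = object()  # stand-in for the OBJECT_IN_PLASMA error marker

def inline_or_promote(arg, value):
    """Toy version of DoInlineObjectValue's per-argument rewrite."""
    if value is IN_PLASMA:
        # Keep it by reference, but under the plasma (RAYLET) transport type.
        return {"object_id": arg["object_id"].replace("direct:", "plasma:", 1)}
    # Otherwise inline the data so the worker needs no further lookup.
    return {"data": value}

assert inline_or_promote({"object_id": "direct:a"}, IN_PLASMA) == {"object_id": "plasma:a"}
assert inline_or_promote({"object_id": "direct:b"}, b"payload") == {"data": b"payload"}
```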
- ASSERT_TRUE(store.Put(*data, obj1).ok()); - ASSERT_TRUE(store.Put(*data, obj2).ok()); + ASSERT_TRUE(store->Put(*data, obj1).ok()); + ASSERT_TRUE(store->Put(*data, obj2).ok()); TaskSpecification task; task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); @@ -88,7 +134,7 @@ TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); - CoreWorkerMemoryStoreProvider store(ptr); + auto store = std::make_shared(ptr); LocalDependencyResolver resolver(store); ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); @@ -100,8 +146,8 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { resolver.ResolveDependencies(task, [&ok]() { ok = true; }); ASSERT_EQ(resolver.NumPendingTasks(), 1); ASSERT_TRUE(!ok); - ASSERT_TRUE(store.Put(*data, obj1).ok()); - ASSERT_TRUE(store.Put(*data, obj2).ok()); + ASSERT_TRUE(store->Put(*data, obj1).ok()); + ASSERT_TRUE(store->Put(*data, obj2).ok()); // Tests that the task proto was rewritten to have inline argument values after // resolution completes. ASSERT_TRUE(ok); @@ -112,10 +158,11 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { ASSERT_EQ(resolver.NumPendingTasks(), 0); } -TEST(DirectTaskTranportTest, TestSubmitOneTask) { +TEST(DirectTaskTransportTest, TestSubmitOneTask) { MockRayletClient raylet_client; auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); TaskSpecification task; @@ -133,10 +180,11 @@ TEST(DirectTaskTranportTest, TestSubmitOneTask) { ASSERT_EQ(raylet_client.num_workers_returned, 1); } -TEST(DirectTaskTranportTest, TestHandleTaskFailure) { +TEST(DirectTaskTransportTest, TestHandleTaskFailure) { MockRayletClient raylet_client; auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); TaskSpecification task; @@ -150,10 +198,11 @@ TEST(DirectTaskTranportTest, TestHandleTaskFailure) { ASSERT_EQ(raylet_client.num_workers_returned, 1); } -TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) { +TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { MockRayletClient raylet_client; auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); TaskSpecification task1; @@ -190,10 +239,11 @@ TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) { ASSERT_EQ(raylet_client.num_workers_returned, 3); } -TEST(DirectTaskTranportTest, TestReuseWorkerLease) { +TEST(DirectTaskTransportTest, TestReuseWorkerLease) { 
MockRayletClient raylet_client; auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); TaskSpecification task1; @@ -232,10 +282,11 @@ TEST(DirectTaskTranportTest, TestReuseWorkerLease) { ASSERT_EQ(raylet_client.num_workers_returned, 2); } -TEST(DirectTaskTranportTest, TestWorkerNotReusedOnError) { +TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { MockRayletClient raylet_client; auto worker_client = std::shared_ptr(new MockWorkerClient()); - auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); TaskSpecification task1; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 0c702d963..d6115da68 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -5,9 +5,61 @@ using ray::rpc::ActorTableData; namespace ray { +void TreatTaskAsFailed(const TaskID &task_id, int num_returns, + const rpc::ErrorType &error_type, + std::shared_ptr &in_memory_store) { + RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id + << ", error_type: " << ErrorType_Name(error_type); + for (int i = 0; i < num_returns; i++) { + const auto object_id = ObjectID::ForTaskReturn( + task_id, /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::DIRECT)); + std::string meta = std::to_string(static_cast(error_type)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id)); + } +} + +void WriteObjectsToMemoryStore( + const rpc::PushTaskReply &reply, + std::shared_ptr &in_memory_store) { + for (int i = 0; i < reply.return_objects_size(); i++) { + const auto &return_object = reply.return_objects(i); + ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); + + if (return_object.in_plasma()) { + // Mark it as in plasma with a dummy object. 
+ std::string meta = + std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = + const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + RAY_CHECK_OK(in_memory_store->Put(RayObject(nullptr, meta_buffer), object_id)); + } else { + std::shared_ptr data_buffer; + if (return_object.data().size() > 0) { + data_buffer = std::make_shared( + const_cast( + reinterpret_cast(return_object.data().data())), + return_object.data().size()); + } + std::shared_ptr metadata_buffer; + if (return_object.metadata().size() > 0) { + metadata_buffer = std::make_shared( + const_cast( + reinterpret_cast(return_object.metadata().data())), + return_object.metadata().size()); + } + RAY_CHECK_OK( + in_memory_store->Put(RayObject(data_buffer, metadata_buffer), object_id)); + } + } +} + CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( rpc::ClientCallManager &client_call_manager, - CoreWorkerMemoryStoreProvider store_provider) + std::shared_ptr store_provider) : client_call_manager_(client_call_manager), in_memory_store_(store_provider) {} Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) { @@ -49,7 +101,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_); } // If the task submission subsequently fails, then the client will receive @@ -84,7 +136,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( for (const auto &entry : iter->second) { const auto &task_id = entry.first; const auto num_returns = entry.second; - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, + in_memory_store_); } waiting_reply_tasks_.erase(actor_id); } @@ -94,7 +147,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( if (pending_it != pending_requests_.end()) { for (const auto &request : pending_it->second) { TreatTaskAsFailed(TaskID::FromBinary(request->task_spec().task_id()), - request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED); + request->task_spec().num_returns(), rpc::ErrorType::ACTOR_DIED, + in_memory_store_); } pending_requests_.erase(pending_it); } @@ -136,59 +190,14 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask( // Note that this might be the __ray_terminate__ task, so we don't log // loudly with ERROR here. RAY_LOG(INFO) << "Task failed with error: " << status; - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, + in_memory_store_); return; } - for (int i = 0; i < reply.return_objects_size(); i++) { - const auto &return_object = reply.return_objects(i); - ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); - - if (return_object.in_plasma()) { - // Mark it as in plasma with a dummy object. 
- std::string meta = - std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); - auto metadata = - const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - RAY_CHECK_OK( - in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id)); - } else { - std::shared_ptr data_buffer; - if (return_object.data().size() > 0) { - data_buffer = std::make_shared( - const_cast( - reinterpret_cast(return_object.data().data())), - return_object.data().size()); - } - std::shared_ptr metadata_buffer; - if (return_object.metadata().size() > 0) { - metadata_buffer = std::make_shared( - const_cast( - reinterpret_cast(return_object.metadata().data())), - return_object.metadata().size()); - } - RAY_CHECK_OK( - in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id)); - } - } + WriteObjectsToMemoryStore(reply, in_memory_store_); }); if (!status.ok()) { - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED); - } -} - -void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( - const TaskID &task_id, int num_returns, const rpc::ErrorType &error_type) { - RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id - << ", error_type: " << ErrorType_Name(error_type); - for (int i = 0; i < num_returns; i++) { - const auto object_id = ObjectID::ForTaskReturn( - task_id, /*index=*/i + 1, - /*transport_type=*/static_cast(TaskTransportType::DIRECT)); - std::string meta = std::to_string(static_cast(error_type)); - auto metadata = const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id)); + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::ACTOR_DIED, in_memory_store_); } } @@ -301,8 +310,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( send_reply_callback(status, nullptr, nullptr); }, [send_reply_callback]() { - send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr, - nullptr); + send_reply_callback(Status::Invalid("client cancelled stale rpc"), + nullptr, nullptr); }, dependencies); } diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 076e63a6e..275b9016b 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -22,6 +22,26 @@ namespace ray { /// The max time to wait for out-of-order tasks. const int kMaxReorderWaitSeconds = 30; +/// Treat a task as failed. +/// +/// \param[in] task_id The ID of a task. +/// \param[in] num_returns Number of return objects. +/// \param[in] error_type The type of the specific error. +/// \param[in] in_memory_store The memory store to write to. +/// \return Void. +void TreatTaskAsFailed(const TaskID &task_id, int num_returns, + const rpc::ErrorType &error_type, + std::shared_ptr &in_memory_store); + +/// Write return objects to the memory store. +/// +/// \param[in] reply Proto response to a direct actor or task call. +/// \param[in] in_memory_store The memory store to write to. +/// \return Void. +void WriteObjectsToMemoryStore( + const rpc::PushTaskReply &reply, + std::shared_ptr &in_memory_store); + /// In direct actor call task submitter and receiver, a task is directly submitted /// to the actor that will execute it. @@ -40,8 +60,9 @@ struct ActorStateData { // This class is thread-safe. 
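TreatTaskAsFailed, now a free function shared by the actor and task transports, writes a small metadata-only error object into the in-memory store for each return ID of the failed task, so callers blocked on those IDs see the error instead of hanging. A minimal Python sketch of the idea (the dict store and ID scheme are stand-ins for the real ObjectID::ForTaskReturn and RayObject types):

```python
def for_task_return(task_id, index):
    # Stand-in for ObjectID::ForTaskReturn with the DIRECT transport type.
    return "direct:{}:{}".format(task_id, index)

def treat_task_as_failed(task_id, num_returns, error_type, in_memory_store):
    for i in range(num_returns):
        object_id = for_task_return(task_id, i + 1)
        # Store an error marker (metadata only, no data payload).
        in_memory_store[object_id] = {"data": None, "metadata": str(error_type)}

store = {}
treat_task_as_failed("t1", 2, "ACTOR_DIED", store)
assert set(store) == {"direct:t1:1", "direct:t1:2"}
```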
class CoreWorkerDirectActorTaskSubmitter { public: - CoreWorkerDirectActorTaskSubmitter(rpc::ClientCallManager &client_call_manager, - CoreWorkerMemoryStoreProvider store_provider); + CoreWorkerDirectActorTaskSubmitter( + rpc::ClientCallManager &client_call_manager, + std::shared_ptr store_provider); /// Submit a task to an actor for execution. /// @@ -70,15 +91,6 @@ class CoreWorkerDirectActorTaskSubmitter { std::unique_ptr request, const ActorID &actor_id, const TaskID &task_id, int num_returns); - /// Treat a task as failed. - /// - /// \param[in] task_id The ID of a task. - /// \param[in] num_returns Number of return objects. - /// \param[in] error_type The type of the specific error. - /// \return Void. - void TreatTaskAsFailed(const TaskID &task_id, int num_returns, - const rpc::ErrorType &error_type); - /// Create connection to actor and send all pending tasks. /// Note that this function doesn't take lock, the caller is expected to hold /// `mutex_` before calling this function. @@ -120,7 +132,7 @@ class CoreWorkerDirectActorTaskSubmitter { std::unordered_map> waiting_reply_tasks_; /// The store provider. - CoreWorkerMemoryStoreProvider in_memory_store_; + std::shared_ptr in_memory_store_; friend class CoreWorkerTest; }; @@ -230,7 +242,8 @@ class SchedulingQueue { std::function accept_request, std::function reject_request, const std::vector &dependencies = {}) { if (seq_no == -1) { - seq_no = next_seq_no_; // A value of -1 means no ordering constraint. + accept_request(); // A seq_no of -1 means no ordering constraint. + return; } RAY_CHECK(boost::this_thread::get_id() == main_thread_id_); if (client_processed_up_to >= next_seq_no_) { @@ -259,6 +272,8 @@ class SchedulingQueue { // Cancel any stale requests that the client doesn't need any longer. while (!pending_tasks_.empty() && pending_tasks_.begin()->first < next_seq_no_) { auto head = pending_tasks_.begin(); + RAY_LOG(ERROR) << "Cancelling stale RPC with seqno " + << pending_tasks_.begin()->first << " < " << next_seq_no_; head->second.Cancel(); pending_tasks_.erase(head); } diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index c05818e18..e94d7dc8a 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -1,4 +1,5 @@ #include "ray/core_worker/transport/direct_task_transport.h" +#include "ray/core_worker/transport/direct_actor_transport.h" namespace ray { @@ -13,13 +14,20 @@ void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr valu if (id == obj_id) { auto *mutable_arg = msg.mutable_args(i); mutable_arg->clear_object_ids(); - if (value->HasData()) { - const auto &data = value->GetData(); - mutable_arg->set_data(data->Data(), data->Size()); - } - if (value->HasMetadata()) { - const auto &metadata = value->GetMetadata(); - mutable_arg->set_metadata(metadata->Data(), metadata->Size()); + if (value->IsInPlasmaError()) { + // Promote the object id to plasma. + mutable_arg->add_object_ids( + obj_id.WithTransportType(TaskTransportType::RAYLET).Binary()); + } else { + // Inline the object value. 
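The SchedulingQueue change above makes a seq_no of -1 mean "no ordering constraint": such requests are accepted immediately rather than being assigned the next sequence number. A toy Python queue capturing the new Add logic (heavily simplified: no dependency waiting or reorder timeout; names are illustrative):

```python
class SchedulingQueueModel(object):
    """Toy model of SchedulingQueue::Add after this patch (simplified)."""

    def __init__(self):
        self.next_seq_no = 0
        self.pending = {}  # seq_no -> (accept, reject)

    def add(self, seq_no, client_processed_up_to, accept, reject):
        if seq_no == -1:
            accept()  # no ordering constraint, run right away
            return
        if client_processed_up_to >= self.next_seq_no:
            # The client has moved on; fast-forward past the gap.
            self.next_seq_no = client_processed_up_to + 1
        self.pending[seq_no] = (accept, reject)
        # Cancel stale requests the client no longer needs.
        for stale in sorted(s for s in self.pending if s < self.next_seq_no):
            self.pending.pop(stale)[1]()
        # Run every request that is now in order.
        while self.next_seq_no in self.pending:
            self.pending.pop(self.next_seq_no)[0]()
            self.next_seq_no += 1

q = SchedulingQueueModel()
log = []
q.add(1, -1, lambda: log.append("seq1"), lambda: log.append("rej1"))
q.add(-1, -1, lambda: log.append("unordered"), lambda: None)  # runs immediately
q.add(0, -1, lambda: log.append("seq0"), lambda: None)
assert log == ["unordered", "seq0", "seq1"]
```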
+ if (value->HasData()) { + const auto &data = value->GetData(); + mutable_arg->set_data(data->Data(), data->Size()); + } + if (value->HasMetadata()) { + const auto &metadata = value->GetMetadata(); + mutable_arg->set_metadata(metadata->Data(), metadata->Size()); + } } found = true; } @@ -52,7 +60,7 @@ void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task, num_pending_ += 1; for (const auto &obj_id : state->local_dependencies) { - in_memory_store_.GetAsync( + in_memory_store_->GetAsync( obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { RAY_CHECK(obj != nullptr); bool complete = false; @@ -128,23 +136,6 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( worker_request_pending_ = true; } -void CoreWorkerDirectTaskSubmitter::TreatTaskAsFailed(const TaskID &task_id, - int num_returns, - const rpc::ErrorType &error_type) { - RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id - << ", error_type: " << ErrorType_Name(error_type); - for (int i = 0; i < num_returns; i++) { - const auto object_id = ObjectID::ForTaskReturn( - task_id, /*index=*/i + 1, - /*transport_type=*/static_cast(TaskTransportType::DIRECT)); - std::string meta = std::to_string(static_cast(error_type)); - auto metadata = const_cast(reinterpret_cast(meta.data())); - auto meta_buffer = std::make_shared(metadata, meta.size()); - RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id)); - } -} - -// TODO(ekl) consider reconsolidating with DirectActorTransport. void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, TaskSpecification &task_spec) { @@ -157,30 +148,15 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr, [this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) { OnWorkerIdle(addr, /*error=*/!status.ok()); if (!status.ok()) { - TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED); + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED, + in_memory_store_); return; } - for (int i = 0; i < reply.return_objects_size(); i++) { - const auto &return_object = reply.return_objects(i); - ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); - std::shared_ptr data_buffer; - if (return_object.data().size() > 0) { - data_buffer = std::make_shared( - const_cast( - reinterpret_cast(return_object.data().data())), - return_object.data().size()); - } - std::shared_ptr metadata_buffer; - if (return_object.metadata().size() > 0) { - metadata_buffer = std::make_shared( - const_cast( - reinterpret_cast(return_object.metadata().data())), - return_object.metadata().size()); - } - RAY_CHECK_OK( - in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id)); - } + WriteObjectsToMemoryStore(reply, in_memory_store_); }); - RAY_CHECK_OK(status); + if (!status.ok()) { + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED, + in_memory_store_); + } } }; // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 81547492d..0bb708a23 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -23,7 +23,7 @@ struct TaskState { // This class is thread-safe. 
class LocalDependencyResolver { public: - LocalDependencyResolver(CoreWorkerMemoryStoreProvider &store_provider) + LocalDependencyResolver(std::shared_ptr store_provider) : in_memory_store_(store_provider), num_pending_(0) {} /// Resolve all local and remote dependencies for the task, calling the specified @@ -42,7 +42,7 @@ class LocalDependencyResolver { private: /// The store provider. - CoreWorkerMemoryStoreProvider in_memory_store_; + std::shared_ptr in_memory_store_; /// Number of tasks pending dependency resolution. std::atomic num_pending_; @@ -58,9 +58,9 @@ typedef std::function(WorkerAddr // This class is thread-safe. class CoreWorkerDirectTaskSubmitter { public: - CoreWorkerDirectTaskSubmitter(WorkerLeaseInterface &lease_client, - ClientFactoryFn client_factory, - CoreWorkerMemoryStoreProvider store_provider) + CoreWorkerDirectTaskSubmitter( + WorkerLeaseInterface &lease_client, ClientFactoryFn client_factory, + std::shared_ptr store_provider) : lease_client_(lease_client), client_factory_(client_factory), in_memory_store_(store_provider), @@ -93,10 +93,6 @@ class CoreWorkerDirectTaskSubmitter { void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, TaskSpecification &task_spec); - /// Mark a direct call as failed by storing errors for its return objects. - void TreatTaskAsFailed(const TaskID &task_id, int num_returns, - const rpc::ErrorType &error_type); - // Client that can be used to lease and return workers. WorkerLeaseInterface &lease_client_; @@ -104,7 +100,7 @@ class CoreWorkerDirectTaskSubmitter { ClientFactoryFn client_factory_; /// The store provider. - CoreWorkerMemoryStoreProvider in_memory_store_; + std::shared_ptr in_memory_store_; /// Resolve local and remote dependencies; LocalDependencyResolver resolver_;
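The shared_ptr plumbing above lets LocalDependencyResolver, the direct task submitter, and the direct actor submitter share one memory store. The resolver's asynchronous flow is easy to state in Python (the toy async store and function names below are illustrative assumptions): count the by-reference direct-call dependencies, register a callback for each, rewrite the argument in place when its value arrives, and fire the task's on_complete once the last one resolves.

```python
class ToyAsyncStore(object):
    def __init__(self):
        self._values, self._waiters = {}, {}

    def get_async(self, object_id, callback):
        if object_id in self._values:
            callback(self._values[object_id])
        else:
            self._waiters.setdefault(object_id, []).append(callback)

    def put(self, object_id, value):
        self._values[object_id] = value
        for cb in self._waiters.pop(object_id, []):
            cb(value)


def resolve_dependencies(task_args, store, on_complete):
    """Toy LocalDependencyResolver: inline direct-call refs, then run the task."""
    pending = [i for i, (kind, ref) in enumerate(task_args)
               if kind == "by_ref" and ref.startswith("direct:")]
    if not pending:
        on_complete()
        return
    remaining = [len(pending)]
    for i in pending:
        def _resolved(value, i=i):
            task_args[i] = ("by_value", value)   # rewrite the arg in place
            remaining[0] -= 1
            if remaining[0] == 0:
                on_complete()
        store.get_async(task_args[i][1], _resolved)


store = ToyAsyncStore()
args = [("by_ref", "direct:a"), ("by_value", b"x")]
done = []
resolve_dependencies(args, store, lambda: done.append(True))
assert not done                      # dependency not available yet
store.put("direct:a", b"payload")    # arrival completes resolution
assert done and args[0] == ("by_value", b"payload")
```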