[minor] Perf optimizations for direct actor task submission (#6044)

* merge optimizations

* fix

* fix memory err

* optimize

* fix tests

* fix serialization of method handles

* document weakref

* fix check

* bazel format

* disable on 2
This commit is contained in:
Eric Liang 2019-11-01 14:41:14 -07:00 committed by GitHub
parent eef4ad3bba
commit fb34928a2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 157 additions and 118 deletions

View file

@ -371,6 +371,7 @@ cc_library(
]), ]),
copts = COPTS, copts = COPTS,
deps = [ deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
":core_worker_cc_proto", ":core_worker_cc_proto",
":ray_common", ":ray_common",
@ -413,6 +414,8 @@ cc_library(
deps = [ deps = [
":core_worker_lib", ":core_worker_lib",
":gcs", ":gcs",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],
) )

View file

@ -683,21 +683,38 @@ cdef void push_objects_into_return_vector(
c_vector[shared_ptr[CRayObject]] *returns): c_vector[shared_ptr[CRayObject]] *returns):
cdef: cdef:
c_string metadata_str = RAW_BUFFER_METADATA
c_string raw_data_str
shared_ptr[CBuffer] data shared_ptr[CBuffer] data
shared_ptr[CBuffer] metadata shared_ptr[CBuffer] metadata
shared_ptr[CRayObject] ray_object shared_ptr[CRayObject] ray_object
int64_t data_size int64_t data_size
for serialized_object in py_objects: for serialized_object in py_objects:
data_size = serialized_object.total_bytes if isinstance(serialized_object, bytes):
data = dynamic_pointer_cast[ data_size = len(serialized_object)
CBuffer, LocalMemoryBuffer]( raw_data_str = serialized_object
make_shared[LocalMemoryBuffer](data_size)) data = dynamic_pointer_cast[
stream = pyarrow.FixedSizeBufferWriter( CBuffer, LocalMemoryBuffer](
pyarrow.py_buffer(Buffer.make(data))) make_shared[LocalMemoryBuffer](
serialized_object.write_to(stream) <uint8_t*>(raw_data_str.data()), raw_data_str.size()))
ray_object = make_shared[CRayObject](data, metadata) metadata = dynamic_pointer_cast[
returns.push_back(ray_object) CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(metadata_str.data()), metadata_str.size()))
ray_object = make_shared[CRayObject](data, metadata, True)
returns.push_back(ray_object)
else:
data_size = serialized_object.total_bytes
data = dynamic_pointer_cast[
CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](data_size))
metadata.reset()
stream = pyarrow.FixedSizeBufferWriter(
pyarrow.py_buffer(Buffer.make(data)))
serialized_object.write_to(stream)
ray_object = make_shared[CRayObject](data, metadata)
returns.push_back(ray_object)
cdef class CoreWorker: cdef class CoreWorker:
@ -981,7 +998,7 @@ cdef class CoreWorker:
function_descriptor, function_descriptor,
args, args,
int num_return_vals, int num_return_vals,
resources): double num_method_cpus):
cdef: cdef:
CActorID c_actor_id = actor_id.native() CActorID c_actor_id = actor_id.native()
@ -992,7 +1009,8 @@ cdef class CoreWorker:
c_vector[CObjectID] return_ids c_vector[CObjectID] return_ids
with self.profile_event(b"submit_task"): with self.profile_event(b"submit_task"):
prepare_resources(resources, &c_resources) if num_method_cpus > 0:
c_resources[b"CPU"] = num_method_cpus
task_options = CTaskOptions(num_return_vals, c_resources) task_options = CTaskOptions(num_return_vals, c_resources)
ray_function = CRayFunction( ray_function = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))

View file

@ -7,6 +7,7 @@ import inspect
import logging import logging
import six import six
import sys import sys
import weakref
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
@ -57,9 +58,8 @@ def method(*args, **kwargs):
class ActorMethod(object): class ActorMethod(object):
"""A class used to invoke an actor method. """A class used to invoke an actor method.
Note: This class is instantiated only while the actor method is being Note: This class only keeps a weak ref to the actor, unless it has been
invoked (so that it doesn't keep a reference to the actor handle and passed to a remote function. This avoids delays in GC of the actor.
prevent it from going out of scope).
Attributes: Attributes:
_actor: A handle to the actor. _actor: A handle to the actor.
@ -75,8 +75,13 @@ class ActorMethod(object):
"test_decorated_method" in "python/ray/tests/test_actor.py". "test_decorated_method" in "python/ray/tests/test_actor.py".
""" """
def __init__(self, actor, method_name, num_return_vals, decorator=None): def __init__(self,
self._actor = actor actor,
method_name,
num_return_vals,
decorator=None,
hardref=False):
self._actor_ref = weakref.ref(actor)
self._method_name = method_name self._method_name = method_name
self._num_return_vals = num_return_vals self._num_return_vals = num_return_vals
# This is a decorator that is used to wrap the function invocation (as # This is a decorator that is used to wrap the function invocation (as
@ -86,6 +91,11 @@ class ActorMethod(object):
# and return the resulting ObjectIDs. # and return the resulting ObjectIDs.
self._decorator = decorator self._decorator = decorator
# Acquire a hard ref to the actor, this is useful mainly when passing
# actor method handles to remote functions.
if hardref:
self._actor_hard_ref = actor
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise Exception("Actor methods cannot be called directly. Instead " raise Exception("Actor methods cannot be called directly. Instead "
"of running 'object.{}()', try " "of running 'object.{}()', try "
@ -96,15 +106,14 @@ class ActorMethod(object):
return self._remote(args, kwargs) return self._remote(args, kwargs)
def _remote(self, args=None, kwargs=None, num_return_vals=None): def _remote(self, args=None, kwargs=None, num_return_vals=None):
if args is None:
args = []
if kwargs is None:
kwargs = {}
if num_return_vals is None: if num_return_vals is None:
num_return_vals = self._num_return_vals num_return_vals = self._num_return_vals
def invocation(args, kwargs): def invocation(args, kwargs):
return self._actor._actor_method_call( actor = self._actor_ref()
if actor is None:
raise RuntimeError("Lost reference to actor")
return actor._actor_method_call(
self._method_name, self._method_name,
args=args, args=args,
kwargs=kwargs, kwargs=kwargs,
@ -116,6 +125,22 @@ class ActorMethod(object):
return invocation(args, kwargs) return invocation(args, kwargs)
def __getstate__(self):
return {
"actor": self._actor_ref(),
"method_name": self._method_name,
"num_return_vals": self._num_return_vals,
"decorator": self._decorator,
}
def __setstate__(self, state):
self.__init__(
state["actor"],
state["method_name"],
state["num_return_vals"],
state["decorator"],
hardref=True)
class ActorClassMetadata(object): class ActorClassMetadata(object):
"""Metadata for an actor class. """Metadata for an actor class.
@ -502,6 +527,14 @@ class ActorHandle(object):
for method_name in self._ray_method_signatures.keys() for method_name in self._ray_method_signatures.keys()
} }
for method_name in actor_method_names:
method = ActorMethod(
self,
method_name,
self._ray_method_num_return_vals[method_name],
decorator=self._ray_method_decorators.get(method_name))
setattr(self, method_name, method)
def _actor_method_call(self, def _actor_method_call(self,
method_name, method_name,
args=None, args=None,
@ -526,13 +559,15 @@ class ActorHandle(object):
""" """
worker = ray.worker.get_global_worker() worker = ray.worker.get_global_worker()
worker.check_connected()
function_signature = self._ray_method_signatures[method_name]
args = args or [] args = args or []
kwargs = kwargs or {} kwargs = kwargs or {}
function_signature = self._ray_method_signatures[method_name]
list_args = signature.flatten_args(function_signature, args, kwargs) if not args and not kwargs and not function_signature:
list_args = []
else:
list_args = signature.flatten_args(function_signature, args,
kwargs)
if worker.mode == ray.LOCAL_MODE: if worker.mode == ray.LOCAL_MODE:
function = getattr(worker.actors[self._actor_id], method_name) function = getattr(worker.actors[self._actor_id], method_name)
object_ids = worker.local_mode_manager.execute( object_ids = worker.local_mode_manager.execute(
@ -541,7 +576,7 @@ class ActorHandle(object):
object_ids = worker.core_worker.submit_actor_task( object_ids = worker.core_worker.submit_actor_task(
self._ray_actor_id, self._ray_actor_id,
self._ray_function_descriptor_lists[method_name], list_args, self._ray_function_descriptor_lists[method_name], list_args,
num_return_vals, {"CPU": self._ray_actor_method_cpus}) num_return_vals, self._ray_actor_method_cpus)
if len(object_ids) == 1: if len(object_ids) == 1:
object_ids = object_ids[0] object_ids = object_ids[0]
@ -554,30 +589,6 @@ class ActorHandle(object):
def __dir__(self): def __dir__(self):
return self._ray_actor_method_names return self._ray_actor_method_names
def __getattribute__(self, attr):
try:
# Check whether this is an actor method.
actor_method_names = object.__getattribute__(
self, "_ray_actor_method_names")
if attr in actor_method_names:
# We create the ActorMethod on the fly here so that the
# ActorHandle doesn't need a reference to the ActorMethod.
# The ActorMethod has a reference to the ActorHandle and
# this was causing cyclic references which were preventing
# object deallocation from behaving in a predictable
# manner.
return ActorMethod(
self,
attr,
self._ray_method_num_return_vals[attr],
decorator=self._ray_method_decorators.get(attr))
except AttributeError:
pass
# If the requested attribute is not a registered method, fall back
# to default __getattribute__.
return object.__getattribute__(self, attr)
def __repr__(self): def __repr__(self):
return "Actor({}, {})".format(self._ray_class_name, return "Actor({}, {})".format(self._ray_class_name,
self._actor_id.hex()) self._actor_id.hex())

View file

@ -10,19 +10,19 @@ import ray
filter_pattern = os.environ.get("TESTS_TO_RUN", "") filter_pattern = os.environ.get("TESTS_TO_RUN", "")
@ray.remote @ray.remote(num_cpus=0)
class Actor(object): class Actor(object):
def small_value(self): def small_value(self):
return 0 return b"ok"
def small_value_arg(self, x): def small_value_arg(self, x):
return 0 return b"ok"
def small_value_batch(self, n): def small_value_batch(self, n):
ray.get([small_value.remote() for _ in range(n)]) ray.get([small_value.remote() for _ in range(n)])
@ray.remote @ray.remote(num_cpus=0)
class Client(object): class Client(object):
def __init__(self, servers): def __init__(self, servers):
if not isinstance(servers, list): if not isinstance(servers, list):
@ -45,7 +45,7 @@ class Client(object):
@ray.remote @ray.remote
def small_value(): def small_value():
return 0 return b"ok"
@ray.remote @ray.remote

View file

@ -494,11 +494,16 @@ def test_actor_deletion(ray_start_regular):
actors = None actors = None
[ray.tests.utils.wait_for_pid_to_exit(pid) for pid in pids] [ray.tests.utils.wait_for_pid_to_exit(pid) for pid in pids]
@pytest.mark.skipif(
sys.version_info < (3, 0), reason="This test requires Python 3.")
def test_actor_method_deletion(ray_start_regular):
@ray.remote @ray.remote
class Actor(object): class Actor(object):
def method(self): def method(self):
return 1 return 1
# TODO(ekl) this doesn't work in Python 2 after the weak ref method change.
# Make sure that if we create an actor and call a method on it # Make sure that if we create an actor and call a method on it
# immediately, the actor doesn't get killed before the method is # immediately, the actor doesn't get killed before the method is
# called. # called.

View file

@ -292,8 +292,8 @@ class Worker(object):
if isinstance(value, bytes): if isinstance(value, bytes):
if return_buffer is not None: if return_buffer is not None:
raise NotImplementedError( return_buffer.append(value)
"returning raw buffers from direct actor calls") return
# If the object is a byte array, skip serializing it and # If the object is a byte array, skip serializing it and
# use a special metadata to indicate it's raw binary. So # use a special metadata to indicate it's raw binary. So
# that this object can also be read by Java. # that this object can also be read by Java.

View file

@ -25,11 +25,6 @@ class RayObject {
RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata, RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata,
bool copy_data = false) bool copy_data = false)
: data_(data), metadata_(metadata), has_data_copy_(copy_data) { : data_(data), metadata_(metadata), has_data_copy_(copy_data) {
RAY_CHECK(!data || data_->Size())
<< "Zero-length buffers are not allowed when constructing a RayObject.";
RAY_CHECK(!metadata || metadata->Size())
<< "Zero-length buffers are not allowed when constructing a RayObject.";
if (has_data_copy_) { if (has_data_copy_) {
// If this object is required to hold a copy of the data, // If this object is required to hold a copy of the data,
// make a copy if the passed in buffers don't already have a copy. // make a copy if the passed in buffers don't already have a copy.

View file

@ -39,8 +39,8 @@ void BuildCommonTaskSpec(
// Group object ids according to the corresponding store providers. // Group object ids according to the corresponding store providers.
void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids, void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
std::unordered_set<ObjectID> *plasma_object_ids, absl::flat_hash_set<ObjectID> *plasma_object_ids,
std::unordered_set<ObjectID> *memory_object_ids) { absl::flat_hash_set<ObjectID> *memory_object_ids) {
// There are two cases: // There are two cases:
// - for task return objects from direct actor call, use memory store provider; // - for task return objects from direct actor call, use memory store provider;
// - all the others use plasma store provider. // - all the others use plasma store provider.
@ -312,12 +312,12 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, int64_t timeout_ms,
std::vector<std::shared_ptr<RayObject>> *results) { std::vector<std::shared_ptr<RayObject>> *results) {
results->resize(ids.size(), nullptr); results->resize(ids.size(), nullptr);
std::unordered_set<ObjectID> plasma_object_ids; absl::flat_hash_set<ObjectID> plasma_object_ids;
std::unordered_set<ObjectID> memory_object_ids; absl::flat_hash_set<ObjectID> memory_object_ids;
GroupObjectIdsByStoreProvider(ids, &plasma_object_ids, &memory_object_ids); GroupObjectIdsByStoreProvider(ids, &plasma_object_ids, &memory_object_ids);
bool got_exception = false; bool got_exception = false;
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> result_map; absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> result_map;
auto start_time = current_time_ms(); auto start_time = current_time_ms();
RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids, timeout_ms, RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids, timeout_ms,
worker_context_.GetCurrentTaskID(), worker_context_.GetCurrentTaskID(),
@ -360,8 +360,8 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
"Number of objects to wait for must be between 1 and the number of ids."); "Number of objects to wait for must be between 1 and the number of ids.");
} }
std::unordered_set<ObjectID> plasma_object_ids; absl::flat_hash_set<ObjectID> plasma_object_ids;
std::unordered_set<ObjectID> memory_object_ids; absl::flat_hash_set<ObjectID> memory_object_ids;
GroupObjectIdsByStoreProvider(ids, &plasma_object_ids, &memory_object_ids); GroupObjectIdsByStoreProvider(ids, &plasma_object_ids, &memory_object_ids);
if (plasma_object_ids.size() + memory_object_ids.size() != ids.size()) { if (plasma_object_ids.size() + memory_object_ids.size() != ids.size()) {
@ -377,7 +377,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
// a timeout of 0, but that does not address the situation where objects // a timeout of 0, but that does not address the situation where objects
// become available on the second store provider while waiting on the first. // become available on the second store provider while waiting on the first.
std::unordered_set<ObjectID> ready; absl::flat_hash_set<ObjectID> ready;
// Wait from both store providers with timeout set to 0. This is to avoid the case // Wait from both store providers with timeout set to 0. This is to avoid the case
// where we might use up the entire timeout on trying to get objects from one store // where we might use up the entire timeout on trying to get objects from one store
// provider before even trying another (which might have all of the objects available). // provider before even trying another (which might have all of the objects available).
@ -421,8 +421,8 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_only, Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_only,
bool delete_creating_tasks) { bool delete_creating_tasks) {
std::unordered_set<ObjectID> plasma_object_ids; absl::flat_hash_set<ObjectID> plasma_object_ids;
std::unordered_set<ObjectID> memory_object_ids; absl::flat_hash_set<ObjectID> memory_object_ids;
GroupObjectIdsByStoreProvider(object_ids, &plasma_object_ids, &memory_object_ids); GroupObjectIdsByStoreProvider(object_ids, &plasma_object_ids, &memory_object_ids);
RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only, RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only,

View file

@ -2,6 +2,7 @@
#define RAY_CORE_WORKER_CORE_WORKER_H #define RAY_CORE_WORKER_CORE_WORKER_H
#include "absl/base/thread_annotations.h" #include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
@ -421,7 +422,7 @@ class CoreWorker {
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_; std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
/// Map from actor ID to a handle to that actor. /// Map from actor ID to a handle to that actor.
std::unordered_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_; absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_;
/* Fields related to task execution. */ /* Fields related to task execution. */

View file

@ -9,10 +9,10 @@ namespace ray {
/// A class that represents a `Get` request. /// A class that represents a `Get` request.
class GetRequest { class GetRequest {
public: public:
GetRequest(std::unordered_set<ObjectID> object_ids, size_t num_objects, GetRequest(absl::flat_hash_set<ObjectID> object_ids, size_t num_objects,
bool remove_after_get); bool remove_after_get);
const std::unordered_set<ObjectID> &ObjectIds() const; const absl::flat_hash_set<ObjectID> &ObjectIds() const;
/// Wait until all requested objects are available, or timeout happens. /// Wait until all requested objects are available, or timeout happens.
/// ///
@ -31,9 +31,9 @@ class GetRequest {
void Wait(); void Wait();
/// The object IDs involved in this request. /// The object IDs involved in this request.
const std::unordered_set<ObjectID> object_ids_; const absl::flat_hash_set<ObjectID> object_ids_;
/// The object information for the objects in this request. /// The object information for the objects in this request.
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> objects_; absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_;
/// Number of objects required. /// Number of objects required.
const size_t num_objects_; const size_t num_objects_;
@ -46,7 +46,7 @@ class GetRequest {
std::condition_variable cv_; std::condition_variable cv_;
}; };
GetRequest::GetRequest(std::unordered_set<ObjectID> object_ids, size_t num_objects, GetRequest::GetRequest(absl::flat_hash_set<ObjectID> object_ids, size_t num_objects,
bool remove_after_get) bool remove_after_get)
: object_ids_(std::move(object_ids)), : object_ids_(std::move(object_ids)),
num_objects_(num_objects), num_objects_(num_objects),
@ -55,7 +55,7 @@ GetRequest::GetRequest(std::unordered_set<ObjectID> object_ids, size_t num_objec
RAY_CHECK(num_objects_ <= object_ids_.size()); RAY_CHECK(num_objects_ <= object_ids_.size());
} }
const std::unordered_set<ObjectID> &GetRequest::ObjectIds() const { return object_ids_; } const absl::flat_hash_set<ObjectID> &GetRequest::ObjectIds() const { return object_ids_; }
bool GetRequest::ShouldRemoveObjects() const { return remove_after_get_; } bool GetRequest::ShouldRemoveObjects() const { return remove_after_get_; }
@ -144,8 +144,8 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
std::shared_ptr<GetRequest> get_request; std::shared_ptr<GetRequest> get_request;
{ {
std::unordered_set<ObjectID> remaining_ids; absl::flat_hash_set<ObjectID> remaining_ids;
std::unordered_set<ObjectID> ids_to_remove; absl::flat_hash_set<ObjectID> ids_to_remove;
std::unique_lock<std::mutex> lock(lock_); std::unique_lock<std::mutex> lock(lock_);
// Check for existing objects and see if this get request can be fulfilled. // Check for existing objects and see if this get request can be fulfilled.

View file

@ -1,6 +1,8 @@
#ifndef RAY_CORE_WORKER_MEMORY_STORE_H #ifndef RAY_CORE_WORKER_MEMORY_STORE_H
#define RAY_CORE_WORKER_MEMORY_STORE_H #define RAY_CORE_WORKER_MEMORY_STORE_H
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "ray/common/id.h" #include "ray/common/id.h"
#include "ray/common/status.h" #include "ray/common/status.h"
#include "ray/core_worker/common.h" #include "ray/core_worker/common.h"
@ -45,10 +47,10 @@ class CoreWorkerMemoryStore {
private: private:
/// Map from object ID to `RayObject`. /// Map from object ID to `RayObject`.
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> objects_; absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_;
/// Map from object ID to its get requests. /// Map from object ID to its get requests.
std::unordered_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>> absl::flat_hash_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>>
object_get_requests_; object_get_requests_;
/// Protect the two maps above. /// Protect the two maps above.

View file

@ -23,9 +23,9 @@ Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object,
} }
Status CoreWorkerMemoryStoreProvider::Get( Status CoreWorkerMemoryStoreProvider::Get(
const std::unordered_set<ObjectID> &object_ids, int64_t timeout_ms, const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
const TaskID &task_id, const TaskID &task_id,
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> *results, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception) { bool *got_exception) {
const std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); const std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
std::vector<std::shared_ptr<RayObject>> result_objects; std::vector<std::shared_ptr<RayObject>> result_objects;
@ -43,10 +43,9 @@ Status CoreWorkerMemoryStoreProvider::Get(
return Status::OK(); return Status::OK();
} }
Status CoreWorkerMemoryStoreProvider::Wait(const std::unordered_set<ObjectID> &object_ids, Status CoreWorkerMemoryStoreProvider::Wait(
int num_objects, int64_t timeout_ms, const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
const TaskID &task_id, const TaskID &task_id, absl::flat_hash_set<ObjectID> *ready) {
std::unordered_set<ObjectID> *ready) {
std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
std::vector<std::shared_ptr<RayObject>> result_objects; std::vector<std::shared_ptr<RayObject>> result_objects;
RAY_CHECK(object_ids.size() == id_vector.size()); RAY_CHECK(object_ids.size() == id_vector.size());
@ -63,7 +62,7 @@ Status CoreWorkerMemoryStoreProvider::Wait(const std::unordered_set<ObjectID> &o
} }
Status CoreWorkerMemoryStoreProvider::Delete( Status CoreWorkerMemoryStoreProvider::Delete(
const std::unordered_set<ObjectID> &object_ids) { const absl::flat_hash_set<ObjectID> &object_ids) {
std::vector<ObjectID> object_id_vector(object_ids.begin(), object_ids.end()); std::vector<ObjectID> object_id_vector(object_ids.begin(), object_ids.end());
store_->Delete(object_id_vector); store_->Delete(object_id_vector);
return Status::OK(); return Status::OK();

View file

@ -1,6 +1,8 @@
#ifndef RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H #ifndef RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H
#define RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H #define RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "ray/common/buffer.h" #include "ray/common/buffer.h"
#include "ray/common/id.h" #include "ray/common/id.h"
#include "ray/common/status.h" #include "ray/common/status.h"
@ -21,18 +23,18 @@ class CoreWorkerMemoryStoreProvider {
Status Put(const RayObject &object, const ObjectID &object_id); Status Put(const RayObject &object, const ObjectID &object_id);
Status Get(const std::unordered_set<ObjectID> &object_ids, int64_t timeout_ms, Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
const TaskID &task_id, const TaskID &task_id,
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> *results, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception); bool *got_exception);
/// Note that `num_objects` must be equal to the number of items in `object_ids`. /// Note that `num_objects` must be equal to the number of items in `object_ids`.
Status Wait(const std::unordered_set<ObjectID> &object_ids, int num_objects, Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, int num_objects,
int64_t timeout_ms, const TaskID &task_id, int64_t timeout_ms, const TaskID &task_id,
std::unordered_set<ObjectID> *ready); absl::flat_hash_set<ObjectID> *ready);
/// Note that `local_only` must be true, and `delete_creating_tasks` must be false here. /// Note that `local_only` must be true, and `delete_creating_tasks` must be false here.
Status Delete(const std::unordered_set<ObjectID> &object_ids); Status Delete(const absl::flat_hash_set<ObjectID> &object_ids);
private: private:
/// Implementation. /// Implementation.

View file

@ -81,9 +81,9 @@ Status CoreWorkerPlasmaStoreProvider::Seal(const ObjectID &object_id) {
} }
Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
std::unordered_set<ObjectID> &remaining, const std::vector<ObjectID> &batch_ids, absl::flat_hash_set<ObjectID> &remaining, const std::vector<ObjectID> &batch_ids,
int64_t timeout_ms, bool fetch_only, const TaskID &task_id, int64_t timeout_ms, bool fetch_only, const TaskID &task_id,
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> *results, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception) { bool *got_exception) {
RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct(batch_ids, fetch_only, task_id)); RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct(batch_ids, fetch_only, task_id));
@ -125,13 +125,13 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore(
} }
Status CoreWorkerPlasmaStoreProvider::Get( Status CoreWorkerPlasmaStoreProvider::Get(
const std::unordered_set<ObjectID> &object_ids, int64_t timeout_ms, const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
const TaskID &task_id, const TaskID &task_id,
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> *results, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception) { bool *got_exception) {
int64_t batch_size = RayConfig::instance().worker_fetch_request_size(); int64_t batch_size = RayConfig::instance().worker_fetch_request_size();
std::vector<ObjectID> batch_ids; std::vector<ObjectID> batch_ids;
std::unordered_set<ObjectID> remaining(object_ids.begin(), object_ids.end()); absl::flat_hash_set<ObjectID> remaining(object_ids.begin(), object_ids.end());
// First, attempt to fetch all of the required objects once without reconstructing. // First, attempt to fetch all of the required objects once without reconstructing.
std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
@ -206,10 +206,9 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id,
return Status::OK(); return Status::OK();
} }
Status CoreWorkerPlasmaStoreProvider::Wait(const std::unordered_set<ObjectID> &object_ids, Status CoreWorkerPlasmaStoreProvider::Wait(
int num_objects, int64_t timeout_ms, const absl::flat_hash_set<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
const TaskID &task_id, const TaskID &task_id, absl::flat_hash_set<ObjectID> *ready) {
std::unordered_set<ObjectID> *ready) {
std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end()); std::vector<ObjectID> id_vector(object_ids.begin(), object_ids.end());
bool should_break = false; bool should_break = false;
@ -240,7 +239,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait(const std::unordered_set<ObjectID> &o
} }
Status CoreWorkerPlasmaStoreProvider::Delete( Status CoreWorkerPlasmaStoreProvider::Delete(
const std::unordered_set<ObjectID> &object_ids, bool local_only, const absl::flat_hash_set<ObjectID> &object_ids, bool local_only,
bool delete_creating_tasks) { bool delete_creating_tasks) {
std::vector<ObjectID> object_id_vector(object_ids.begin(), object_ids.end()); std::vector<ObjectID> object_id_vector(object_ids.begin(), object_ids.end());
return raylet_client_->FreeObjects(object_id_vector, local_only, delete_creating_tasks); return raylet_client_->FreeObjects(object_id_vector, local_only, delete_creating_tasks);
@ -252,7 +251,7 @@ std::string CoreWorkerPlasmaStoreProvider::MemoryUsageString() {
} }
void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes( void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes(
int num_attempts, const std::unordered_set<ObjectID> &remaining) { int num_attempts, const absl::flat_hash_set<ObjectID> &remaining) {
if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() == if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() ==
0) { 0) {
std::ostringstream oss; std::ostringstream oss;

View file

@ -1,6 +1,8 @@
#ifndef RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H #ifndef RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H
#define RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H #define RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "plasma/client.h" #include "plasma/client.h"
#include "ray/common/buffer.h" #include "ray/common/buffer.h"
#include "ray/common/id.h" #include "ray/common/id.h"
@ -33,18 +35,18 @@ class CoreWorkerPlasmaStoreProvider {
Status Seal(const ObjectID &object_id); Status Seal(const ObjectID &object_id);
Status Get(const std::unordered_set<ObjectID> &object_ids, int64_t timeout_ms, Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
const TaskID &task_id, const TaskID &task_id,
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> *results, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception); bool *got_exception);
Status Contains(const ObjectID &object_id, bool *has_object); Status Contains(const ObjectID &object_id, bool *has_object);
Status Wait(const std::unordered_set<ObjectID> &object_ids, int num_objects, Status Wait(const absl::flat_hash_set<ObjectID> &object_ids, int num_objects,
int64_t timeout_ms, const TaskID &task_id, int64_t timeout_ms, const TaskID &task_id,
std::unordered_set<ObjectID> *ready); absl::flat_hash_set<ObjectID> *ready);
Status Delete(const std::unordered_set<ObjectID> &object_ids, bool local_only, Status Delete(const absl::flat_hash_set<ObjectID> &object_ids, bool local_only,
bool delete_creating_tasks); bool delete_creating_tasks);
std::string MemoryUsageString(); std::string MemoryUsageString();
@ -67,9 +69,9 @@ class CoreWorkerPlasmaStoreProvider {
/// exception. /// exception.
/// \return Status. /// \return Status.
Status FetchAndGetFromPlasmaStore( Status FetchAndGetFromPlasmaStore(
std::unordered_set<ObjectID> &remaining, const std::vector<ObjectID> &batch_ids, absl::flat_hash_set<ObjectID> &remaining, const std::vector<ObjectID> &batch_ids,
int64_t timeout_ms, bool fetch_only, const TaskID &task_id, int64_t timeout_ms, bool fetch_only, const TaskID &task_id,
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> *results, absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> *results,
bool *got_exception); bool *got_exception);
/// Print a warning if we've attempted too many times, but some objects are still /// Print a warning if we've attempted too many times, but some objects are still
@ -78,7 +80,7 @@ class CoreWorkerPlasmaStoreProvider {
/// \param[in] num_attempts The number of attempted times. /// \param[in] num_attempts The number of attempted times.
/// \param[in] remaining The remaining objects. /// \param[in] remaining The remaining objects.
static void WarnIfAttemptedTooManyTimes(int num_attempts, static void WarnIfAttemptedTooManyTimes(int num_attempts,
const std::unordered_set<ObjectID> &remaining); const absl::flat_hash_set<ObjectID> &remaining);
const std::unique_ptr<RayletClient> &raylet_client_; const std::unique_ptr<RayletClient> &raylet_client_;
plasma::PlasmaClient store_client_; plasma::PlasmaClient store_client_;

View file

@ -2,6 +2,8 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "ray/common/buffer.h" #include "ray/common/buffer.h"
#include "ray/common/ray_object.h" #include "ray/common/ray_object.h"
#include "ray/core_worker/context.h" #include "ray/core_worker/context.h"
@ -644,8 +646,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
RAY_CHECK_OK(provider.Put(buffers[i], ids[i])); RAY_CHECK_OK(provider.Put(buffers[i], ids[i]));
} }
std::unordered_set<ObjectID> wait_ids(ids.begin(), ids.end()); absl::flat_hash_set<ObjectID> wait_ids(ids.begin(), ids.end());
std::unordered_set<ObjectID> wait_results; absl::flat_hash_set<ObjectID> wait_results;
ObjectID nonexistent_id = ObjectID::FromRandom(); ObjectID nonexistent_id = ObjectID::FromRandom();
wait_ids.insert(nonexistent_id); wait_ids.insert(nonexistent_id);
@ -662,8 +664,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
// Test Get(). // Test Get().
bool got_exception = false; bool got_exception = false;
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> results; absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> results;
std::unordered_set<ObjectID> ids_set(ids.begin(), ids.end()); absl::flat_hash_set<ObjectID> ids_set(ids.begin(), ids.end());
RAY_CHECK_OK(provider.Get(ids_set, -1, RandomTaskId(), &results, &got_exception)); RAY_CHECK_OK(provider.Get(ids_set, -1, RandomTaskId(), &results, &got_exception));
ASSERT_TRUE(!got_exception); ASSERT_TRUE(!got_exception);