Revert "Revert "Removing Pyarrow dependency (#7146)" (#7209) (#7214)

Simon Mo 2020-02-19 10:08:52 -08:00 committed by GitHub
parent f76ce836b2
commit e8941b1b79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 149 additions and 378 deletions

View file

@@ -117,14 +117,6 @@ fi
pushd "$BUILD_DIR"
# The following line installs pyarrow from S3; these wheels have been
# generated from https://github.com/ray-project/arrow-build at
# the commit listed in the command.
if [ -z "$SKIP_THIRDPARTY_INSTALL" ]; then
"$PYTHON_EXECUTABLE" -m pip install -q \
--target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.14.0.RAY \
--find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/3a11193d9530fe8ec7fdb98057f853b708f6f6ae/index.html
fi
WORK_DIR=`mktemp -d`
pushd $WORK_DIR

View file

@@ -79,7 +79,6 @@ YAPF_FLAGS=(
YAPF_EXCLUDES=(
'--exclude' 'python/ray/cloudpickle/*'
'--exclude' 'python/build/*'
'--exclude' 'python/ray/pyarrow_files/*'
'--exclude' 'python/ray/core/src/ray/gcs/*'
'--exclude' 'python/ray/thirdparty_files/*'
)
@@ -145,6 +144,7 @@ fi
# Ensure import ordering
# Make sure that for every "import psutil" and "import setproctitle"
# there's an "import ray" above it.
python ci/travis/check_import_order.py . -s ci -s python/ray/pyarrow_files -s python/ray/thirdparty_files -s python/build
if ! git diff --quiet &>/dev/null; then
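For context, the sketch below shows in broad strokes how an import-order check like this can work. It is a hypothetical stand-in, not the actual ci/travis/check_import_order.py:

import re
import sys

def check_file(path):
    # Flag "import psutil" / "import setproctitle" lines that appear
    # before the first "import ray" in the file.
    ray_seen = False
    ok = True
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            if re.match(r"\s*import ray\b", line):
                ray_seen = True
            elif (re.match(r"\s*import (psutil|setproctitle)\b", line)
                  and not ray_seen):
                print("{}:{}: import ray must come first".format(path, lineno))
                ok = False
    return ok

if __name__ == "__main__":
    results = [check_file(p) for p in sys.argv[1:]]
    sys.exit(0 if all(results) else 1)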

View file

@@ -10,7 +10,6 @@ mock
numpy
opencv-python-headless
pandas
pyarrow
pygments
psutil
pyyaml

View file

@@ -38,24 +38,8 @@ if os.path.exists(so_path):
from ctypes import CDLL
CDLL(so_path, ctypes.RTLD_GLOBAL)
# MUST import ray._raylet before pyarrow to initialize some global variables.
# It seems that pyarrow's memory-allocation library breaks the
# initialization of grpc if pyarrow is imported first.
# NOTE(JoeyJiang): See https://github.com/ray-project/ray/issues/5219 for more
# details.
import ray._raylet # noqa: E402
if "pyarrow" in sys.modules:
raise ImportError("Ray must be imported before pyarrow because Ray "
"requires a specific version of pyarrow (which is "
"packaged along with Ray).")
# Add the directory containing pyarrow to the Python path so that we find the
# pyarrow version packaged with ray and not a pre-existing pyarrow.
pyarrow_path = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "pyarrow_files")
sys.path.insert(0, pyarrow_path)
# See https://github.com/ray-project/ray/issues/131.
helpful_message = """
@@ -64,37 +48,6 @@ If you are using Anaconda, try fixing this problem by running:
conda install libgcc
"""
try:
import pyarrow # noqa: F401
# pyarrow is not imported inside of _raylet because of the issue described
# above. In order for Cython to compile _raylet, pyarrow is set to None
# in _raylet instead, so we give _raylet a real reference to it here.
# We do the attribute checks first here so that building the documentation
# succeeds without fully installing Ray.
# TODO(edoakes): Fix this.
if hasattr(ray, "_raylet") and hasattr(ray._raylet, "pyarrow"):
ray._raylet.pyarrow = pyarrow
except ImportError as e:
if ((hasattr(e, "msg") and isinstance(e.msg, str)
and ("libstdc++" in e.msg or "CXX" in e.msg))):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif (hasattr(e, "message") and isinstance(e.message, str)
and ("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
condition = (hasattr(e, "args") and isinstance(e.args, tuple)
and len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message, )
else:
if not hasattr(e, "args"):
e.args = ()
elif not isinstance(e.args, tuple):
e.args = (e.args, )
e.args += (helpful_message, )
raise
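The removed block above follows a general pattern: catch an ImportError caused by a missing native dependency and append guidance before re-raising. A condensed, Python-3-only sketch of that pattern (the imported module name is hypothetical):

helpful_message = """
If you are using Anaconda, try fixing this problem by running:

    conda install libgcc
"""

try:
    import some_native_extension  # hypothetical C++-backed module
except ImportError as e:
    if isinstance(getattr(e, "msg", None), str) and (
            "libstdc++" in e.msg or "CXX" in e.msg):
        e.msg += helpful_message  # the hint shows up in the traceback
    raise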
from ray._raylet import (
ActorCheckpointID,
ActorClassID,

View file

@@ -75,6 +75,7 @@ cdef class CoreWorker:
unique_ptr[CCoreWorker] core_worker
object async_thread
object async_event_loop
object plasma_event_handler
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
size_t data_size, ObjectID object_id,

View file

@@ -98,14 +98,6 @@ from ray.ray_constants import (
DEFAULT_PUT_OBJECT_RETRIES,
)
# pyarrow cannot be imported until after _raylet finishes initializing
# (see ray/__init__.py for details).
# Unfortunately, Cython won't compile if 'pyarrow' is undefined, so we
# "forward declare" it here and then replace it with a reference to the
# imported package from ray/__init__.py.
# TODO(edoakes): Fix this.
pyarrow = None
cimport cpython
include "includes/unique_ids.pxi"
@@ -552,6 +544,14 @@ cdef CRayStatus task_execution_handler(
return CRayStatus.OK()
cdef void async_plasma_callback(CObjectID object_id,
int64_t data_size,
int64_t metadata_size) with gil:
message = [tuple([ObjectID(object_id.Binary()), data_size, metadata_size])]
core_worker = ray.worker.global_worker.core_worker
event_handler = core_worker.get_plasma_event_handler()
if event_handler is not None:
event_handler.process_notifications(message)
cdef CRayStatus check_signals() nogil:
with gil:
@@ -574,17 +574,20 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str):
cdef write_serialized_object(
serialized_object, const shared_ptr[CBuffer]& buf):
# avoid initializing pyarrow before raylet
from ray.serialization import Pickle5SerializedObject, RawSerializedObject
if isinstance(serialized_object, RawSerializedObject):
if buf.get() != NULL and buf.get().Size() > 0:
buffer = Buffer.make(buf)
# `Buffer` has a null underlying buffer if its size is 0,
# which would cause `pyarrow.py_buffer` to crash.
stream = pyarrow.FixedSizeBufferWriter(pyarrow.py_buffer(buffer))
stream.set_memcopy_threads(MEMCOPY_THREADS)
stream.write(pyarrow.py_buffer(serialized_object.value))
size = serialized_object.total_bytes
if MEMCOPY_THREADS > 1 and size > kMemcopyDefaultThreshold:
parallel_memcopy(buf.get().Data(),
<const uint8_t*> serialized_object.value,
size, kMemcopyDefaultBlocksize,
MEMCOPY_THREADS)
else:
memcpy(buf.get().Data(),
<const uint8_t*>serialized_object.value, size)
elif isinstance(serialized_object, Pickle5SerializedObject):
(<Pickle5Writer>serialized_object.writer).write_to(
serialized_object.inband, buf, MEMCOPY_THREADS)
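The raw path above replaces pyarrow's FixedSizeBufferWriter with a direct copy that only goes parallel above a size threshold. A pure-Python sketch of that decision follows; the constants and the thread-pool strategy are assumptions for illustration, while the real implementation calls the C-level parallel_memcopy:

import concurrent.futures

MEMCOPY_THREADS = 6
K_MEMCOPY_DEFAULT_THRESHOLD = 1024 * 1024  # assumed cutoff, not the real value

def write_raw(dest, payload):
    """Copy `payload` (bytes) into `dest` (a writable memoryview)."""
    n = len(payload)
    if MEMCOPY_THREADS > 1 and n > K_MEMCOPY_DEFAULT_THRESHOLD:
        # Split the copy into roughly equal blocks and run them in parallel.
        block = (n + MEMCOPY_THREADS - 1) // MEMCOPY_THREADS
        with concurrent.futures.ThreadPoolExecutor(MEMCOPY_THREADS) as pool:
            for start in range(0, n, block):
                end = min(start + block, n)
                pool.submit(dest.__setitem__, slice(start, end),
                            payload[start:end])
    else:
        dest[:n] = payload  # small objects: plain single-threaded copy

buf = memoryview(bytearray(16))
write_raw(buf, b"hello")
print(bytes(buf[:5]))  # b'hello'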
@@ -597,9 +600,6 @@ cdef class CoreWorker:
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
node_ip_address, node_manager_port):
assert pyarrow is not None, ("Expected pyarrow to be imported from "
"outside _raylet. See __init__.py for "
"details.")
self.core_worker.reset(new CCoreWorker(
WORKER_TYPE_DRIVER if is_driver else WORKER_TYPE_WORKER,
@@ -628,6 +628,13 @@ cdef class CoreWorker:
def set_actor_title(self, title):
self.core_worker.get().SetActorTitle(title)
def subscribe_to_plasma(self, plasma_event_handler):
self.plasma_event_handler = plasma_event_handler
self.core_worker.get().SubscribeToAsyncPlasma(async_plasma_callback)
def get_plasma_event_handler(self):
return self.plasma_event_handler
def get_objects(self, object_ids, TaskID current_task_id,
int64_t timeout_ms=-1):
cdef:

View file

@@ -162,7 +162,13 @@ class RayTimeoutError(RayError):
pass
class PlasmaObjectNotAvailable(RayError):
"""Called when an object was not available within the given timeout."""
pass
RAY_EXCEPTION_TYPES = [
PlasmaObjectNotAvailable,
RayError,
RayTaskError,
RayWorkerError,

View file

@@ -1,84 +1,22 @@
# Note: asyncio is only compatible with Python 3
import asyncio
import functools
import threading
import pyarrow.plasma as plasma
import ray
from ray.experimental.async_plasma import PlasmaProtocol, PlasmaEventHandler
from ray.experimental.async_plasma import PlasmaEventHandler
from ray.services import logger
handler = None
transport = None
protocol = None
class _ThreadSafeProxy:
"""This class is used to create a thread-safe proxy for a given object.
Every method call will be guarded with a lock.
Attributes:
orig_obj (object): the original object.
lock (threading.Lock): the lock object.
_wrapper_cache (dict): a cache from original object's methods to
the proxy methods.
"""
def __init__(self, orig_obj, lock):
self.orig_obj = orig_obj
self.lock = lock
self._wrapper_cache = {}
def __getattr__(self, attr):
orig_attr = getattr(self.orig_obj, attr)
if not callable(orig_attr):
# If the original attr is a field, just return it.
return orig_attr
else:
# If the original attr is a method,
# return a wrapper that guards the original method with a lock.
wrapper = self._wrapper_cache.get(attr)
if wrapper is None:
@functools.wraps(orig_attr)
def _wrapper(*args, **kwargs):
with self.lock:
return orig_attr(*args, **kwargs)
self._wrapper_cache[attr] = _wrapper
wrapper = _wrapper
return wrapper
def thread_safe_client(client, lock=None):
"""Create a thread-safe proxy which locks every method call
for the given client.
Args:
client: the client object to be guarded.
lock: the lock object that will be used to lock client's methods.
If None, a new lock will be used.
Returns:
A thread-safe proxy for the given client.
"""
if lock is None:
lock = threading.Lock()
return _ThreadSafeProxy(client, lock)
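A brief usage sketch for the (now removed) thread_safe_client helper above; SomeClient is a hypothetical, non-thread-safe stand-in:

import threading

class SomeClient:  # hypothetical client that is not thread-safe
    def get(self, key):
        return key.upper()

client = thread_safe_client(SomeClient())
# Every method call on the proxy now runs under one lock, so the client
# can safely be shared across threads.
threads = [threading.Thread(target=client.get, args=("x",))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()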
async def _async_init():
global handler, transport, protocol
global handler
if handler is None:
worker = ray.worker.global_worker
plasma_client = thread_safe_client(
plasma.connect(worker.node.plasma_store_socket_name, 300))
loop = asyncio.get_event_loop()
plasma_client.subscribe()
rsock = plasma_client.get_notification_socket()
handler = PlasmaEventHandler(loop, worker)
transport, protocol = await loop.create_connection(
lambda: PlasmaProtocol(plasma_client, handler), sock=rsock)
worker.core_worker.subscribe_to_plasma(handler)
logger.debug("AsyncPlasma Connection Created!")
@@ -126,10 +64,7 @@ def shutdown():
Cancels all related tasks and closes all socket transports.
"""
global handler, transport, protocol
global handler
if handler is not None:
handler.close()
transport.close()
handler = None
transport = None
protocol = None

View file

@@ -1,179 +1,13 @@
import asyncio
import ctypes
import sys
import pyarrow.plasma as plasma
import ray
from ray.services import logger
INT64_SIZE = ctypes.sizeof(ctypes.c_int64)
def _release_waiter(waiter, *_):
if not waiter.done():
waiter.set_result(None)
class PlasmaProtocol(asyncio.Protocol):
"""Protocol control for the asyncio connection."""
def __init__(self, plasma_client, plasma_event_handler):
self.plasma_client = plasma_client
self.plasma_event_handler = plasma_event_handler
self.transport = None
self._buffer = b""
def connection_made(self, transport):
self.transport = transport
def data_received(self, data):
self._buffer += data
messages = []
i = 0
while i + INT64_SIZE <= len(self._buffer):
msg_len = int.from_bytes(self._buffer[i:i + INT64_SIZE],
sys.byteorder)
if i + INT64_SIZE + msg_len > len(self._buffer):
break
i += INT64_SIZE
segment = self._buffer[i:i + msg_len]
i += msg_len
(object_ids, object_sizes,
metadata_sizes) = self.plasma_client.decode_notifications(segment)
assert len(object_ids) == len(object_sizes) == len(metadata_sizes)
for j in range(len(object_ids)):
messages.append((object_ids[j], object_sizes[j],
metadata_sizes[j]))
self._buffer = self._buffer[i:]
self.plasma_event_handler.process_notifications(messages)
def connection_lost(self, exc):
# The socket has been closed
logger.debug("PlasmaProtocol - connection lost.")
def eof_received(self):
logger.debug("PlasmaProtocol - EOF received.")
self.transport.close()
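PlasmaProtocol.data_received above implements classic length-prefixed framing: an int64 length, then the payload, repeated, with any trailing partial frame kept in the buffer for the next read. A standalone sketch of the same parsing:

import sys
from ctypes import sizeof, c_int64

INT64_SIZE = sizeof(c_int64)

def split_frames(buffer):
    """Split a byte stream into length-prefixed frames, as data_received
    does above. Returns (complete_frames, leftover_bytes)."""
    frames, i = [], 0
    while i + INT64_SIZE <= len(buffer):
        msg_len = int.from_bytes(buffer[i:i + INT64_SIZE], sys.byteorder)
        if i + INT64_SIZE + msg_len > len(buffer):
            break  # incomplete frame; keep the bytes for the next read
        i += INT64_SIZE
        frames.append(buffer[i:i + msg_len])
        i += msg_len
    return frames, buffer[i:]

# One complete frame plus a partial length prefix left over.
data = (5).to_bytes(INT64_SIZE, sys.byteorder) + b"hello" + b"\x01"
print(split_frames(data))  # ([b'hello'], b'\x01')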
from collections import defaultdict
class PlasmaObjectFuture(asyncio.Future):
"""This class manages the lifecycle of a Future contains an object_id.
Note:
This Future is an item in an linked list.
Attributes:
object_id: The object_id this Future contains.
"""
def __init__(self, loop, object_id):
super().__init__(loop=loop)
self.object_id = object_id
self.prev = None
self.next = None
@property
def ray_object_id(self):
return ray.ObjectID(self.object_id.binary())
def __repr__(self):
return super().__repr__() + "{object_id=%s}" % self.object_id
class PlasmaObjectLinkedList(asyncio.Future):
"""This class is a doubly-linked list.
It holds an ObjectID and maintains futures assigned to the ObjectID.
Args:
loop: an event loop.
plain_object_id (plasma.ObjectID):
The plasma ObjectID this class holds.
"""
def __init__(self, loop, plain_object_id):
super().__init__(loop=loop)
assert isinstance(plain_object_id, plasma.ObjectID)
self.object_id = plain_object_id
self.head = None
self.tail = None
def append(self, future):
"""Append an object to the linked list.
Args:
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
"""
future.prev = self.tail
if self.tail is None:
assert self.head is None
self.head = future
else:
self.tail.next = future
self.tail = future
# Once done, it will be removed from the list.
future.add_done_callback(self.remove)
def remove(self, future):
"""Remove an object from the linked list.
Args:
future (PlasmaObjectFuture): A PlasmaObjectFuture instance.
"""
if self._loop.get_debug():
logger.debug("Removing %s from the linked list.", future)
if future.prev is None:
assert future is self.head
self.head = future.next
if self.head is None:
self.tail = None
if not self.cancelled():
self.set_result(None)
else:
self.head.prev = None
elif future.next is None:
assert future is self.tail
self.tail = future.prev
if self.tail is None:
self.head = None
if not self.cancelled():
self.set_result(None)
else:
self.tail.prev = None
def cancel(self, *args, **kwargs):
"""Manually cancel all tasks assigned to this event loop."""
# Because removing all futures will trigger `set_result`,
# we cancel this future itself first.
super().cancel()
for future in self.traverse():
# All cancelled futures should have callbacks that remove them
# from this linked list. However, these callbacks are scheduled in
# an event loop, so we could still find them in our list.
if not future.cancelled():
future.cancel()
def set_result(self, result):
"""Complete all tasks. """
for future in self.traverse():
# All cancelled futures should have callbacks that remove them
# from this linked list. However, these callbacks are scheduled in
# an event loop, so we could still find them in our list.
future.set_result(result)
if not self.done():
super().set_result(result)
def traverse(self):
"""Traverse this linked list.
Yields:
PlasmaObjectFuture: PlasmaObjectFuture instances.
"""
current = self.head
while current is not None:
yield current
current = current.next
"""This class is a wrapper for a Future on Plasma."""
pass
class PlasmaEventHandler:
@@ -183,30 +17,46 @@ class PlasmaEventHandler:
super().__init__()
self._loop = loop
self._worker = worker
self._waiting_dict = {}
self._waiting_dict = defaultdict(list)
def process_notifications(self, messages):
"""Process notifications."""
for object_id, object_size, metadata_size in messages:
if object_size > 0 and object_id in self._waiting_dict:
linked_list = self._waiting_dict[object_id]
self._complete_future(linked_list)
self._complete_future(object_id)
def close(self):
"""Clean up this handler."""
for linked_list in self._waiting_dict.values():
linked_list.cancel()
# All cancelled linked lists should have callbacks that remove them
# from the waiting dict. However, these callbacks are scheduled in
# an event loop, so we don't check them now.
for futures in self._waiting_dict.values():
for fut in futures:
fut.cancel()
def _unregister_callback(self, fut):
del self._waiting_dict[fut.object_id]
def _complete_future(self, ray_object_id):
# TODO(ilr): Consider race condition between popping from the
# waiting_dict and as_future appending to the waiting_dict's list.
logger.debug(
"Completing plasma futures for object id {}".format(ray_object_id))
def _complete_future(self, fut):
obj = self._worker.get_objects([ray.ObjectID(
fut.object_id.binary())])[0]
fut.set_result(obj)
obj = self._worker.get_objects([ray_object_id])[0]
futures = self._waiting_dict.pop(ray_object_id)
for fut in futures:
loop = fut._loop
def complete_closure():
try:
fut.set_result(obj)
except asyncio.InvalidStateError:
# Avoid issues where process_notifications
# and check_ready both get executed
logger.debug("Failed to set result for future {}."
"Most likely already set.".format(fut))
loop.call_soon_threadsafe(complete_closure)
def check_immediately(self, object_id):
ready, _ = ray.wait([object_id], timeout=0)
if ready:
self._complete_future(object_id)
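Note that async_plasma_callback fires on a non-asyncio thread, which is why _complete_future schedules fut.set_result through loop.call_soon_threadsafe rather than calling it directly. A minimal illustration of that pattern:

import asyncio
import threading

async def main():
    loop = asyncio.get_event_loop()
    fut = loop.create_future()

    def from_other_thread():
        # Calling fut.set_result directly from another thread is unsafe;
        # schedule it on the future's own loop, as _complete_future does.
        loop.call_soon_threadsafe(fut.set_result, 42)

    threading.Thread(target=from_other_thread).start()
    print(await fut)  # 42

asyncio.get_event_loop().run_until_complete(main())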
def as_future(self, object_id, check_ready=True):
"""Turn an object_id into a Future object.
@@ -219,25 +69,10 @@ class PlasmaEventHandler:
PlasmaObjectFuture: A future object that waits on the object_id.
"""
if not isinstance(object_id, ray.ObjectID):
raise TypeError("Input should be an ObjectID.")
raise TypeError("Input should be a Ray ObjectID.")
plain_object_id = plasma.ObjectID(object_id.binary())
fut = PlasmaObjectFuture(loop=self._loop, object_id=plain_object_id)
future = PlasmaObjectFuture(loop=self._loop)
self._waiting_dict[object_id].append(future)
self.check_immediately(object_id)
if check_ready:
ready, _ = ray.wait([object_id], timeout=0)
if ready:
if self._loop.get_debug():
logger.debug("%s has been ready.", plain_object_id)
self._complete_future(fut)
return fut
if plain_object_id not in self._waiting_dict:
linked_list = PlasmaObjectLinkedList(self._loop, plain_object_id)
linked_list.add_done_callback(self._unregister_callback)
self._waiting_dict[plain_object_id] = linked_list
self._waiting_dict[plain_object_id].append(fut)
if self._loop.get_debug():
logger.debug("%s added to the waiting list.", fut)
return fut
return future
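The bookkeeping here: _waiting_dict maps each ObjectID to a list of futures, so several callers can await the same object, and completion pops the whole list at once. A toy sketch of that structure (ids and values are made up):

import asyncio
from collections import defaultdict

loop = asyncio.new_event_loop()
waiting = defaultdict(list)

# Two callers wait on the same (hypothetical) object id.
f1, f2 = loop.create_future(), loop.create_future()
waiting["oid"].extend([f1, f2])

# When the notification arrives, every waiter completes at once,
# mirroring what _complete_future does above.
for fut in waiting.pop("oid"):
    fut.set_result("value")

print(f1.result(), f2.result())  # value value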

View file

@@ -1,5 +1,6 @@
import asyncio
import time
import os
import pytest
@@ -9,6 +10,7 @@ from ray.experimental import async_api
@pytest.fixture
def init():
os.environ["RAY_FORCE_DIRECT"] = "0"
ray.init(num_cpus=4)
async_api.init()
asyncio.get_event_loop().set_debug(False)

View file

@@ -10,8 +10,8 @@ cdef class Buffer:
"""Cython wrapper class of C++ `ray::Buffer`.
This class implements the Python 'buffer protocol', which allows
us to use it for calls into pyarrow (and other Python libraries
down the line) without having to copy the data.
us to use it for calls into Python libraries without having to
copy the data.
See https://docs.python.org/3/c-api/buffer.html for details.
"""

View file

@@ -45,6 +45,9 @@ ctypedef void (*ray_callback_function) \
(shared_ptr[CRayObject] result_object,
CObjectID object_id, void* user_data)
ctypedef void (*plasma_callback_function) \
(CObjectID object_id, int64_t data_size, int64_t metadata_size)
cdef extern from "ray/core_worker/profiling.h" nogil:
cdef cppclass CProfiler "ray::worker::Profiler":
void Start()
@@ -194,3 +197,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CRayStatus SetResource(const c_string &resource_name,
const double capacity,
const CClientID &client_Id)
void SubscribeToAsyncPlasma(plasma_callback_function callback)

View file

@@ -3,14 +3,13 @@ import logging
import time
import threading
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
from ray import ray_constants, JobID
import ray.utils
from ray.utils import _random_string
from ray.gcs_utils import ErrorType
from ray.exceptions import (
PlasmaObjectNotAvailable,
RayActorError,
RayWorkerError,
UnreconstructableError,
@@ -233,7 +232,7 @@ class SerializationContext:
if not self.use_pickle:
raise ValueError("Receiving pickle5 serialized objects "
"while the serialization context is "
"using pyarrow as the backend.")
"using a custom raw backend.")
try:
in_band, buffers = unpack_pickle5_buffers(data)
if len(buffers) > 0:
@@ -270,7 +269,7 @@ class SerializationContext:
# to the user. We should only reach this line if this object was
# deserialized as part of a list, and another object in the list
# throws an exception.
return plasma.ObjectNotAvailable
return PlasmaObjectNotAvailable
def deserialize_objects(self,
data_metadata_pairs,

View file

@@ -13,7 +13,6 @@ import time
import redis
import colorama
import pyarrow
# Ray modules
import ray
import ray.ray_constants as ray_constants
@@ -507,22 +506,21 @@ def wait_for_redis_to_start(redis_ip_address,
def _compute_version_info():
"""Compute the versions of Python, pyarrow, and Ray.
"""Compute the versions of Python, and Ray.
Returns:
A tuple containing the version information.
"""
ray_version = ray.__version__
python_version = ".".join(map(str, sys.version_info[:3]))
pyarrow_version = pyarrow.__version__
return ray_version, python_version, pyarrow_version
return ray_version, python_version
def _put_version_info_in_redis(redis_client):
"""Store version information in Redis.
This will be used to detect if workers or drivers are started using
different versions of Python, pyarrow, or Ray.
different versions of Python or Ray.
Args:
redis_client: A client for the primary Redis shard.
@@ -534,7 +532,7 @@ def check_version_info(redis_client):
"""Check if various version info of this process is correct.
This will be used to detect if workers or drivers are started using
different versions of Python, pyarrow, or Ray. If the version
different versions of Python or Ray. If the version
information is not present in Redis, then no check is done.
Args:
@@ -557,12 +555,10 @@ def check_version_info(redis_client):
error_message = ("Version mismatch: The cluster was started with:\n"
" Ray: " + true_version_info[0] + "\n"
" Python: " + true_version_info[1] + "\n"
" Pyarrow: " + str(true_version_info[2]) + "\n"
"This process on node " + node_ip_address +
" was started with:" + "\n"
" Ray: " + version_info[0] + "\n"
" Python: " + version_info[1] + "\n"
" Pyarrow: " + str(version_info[2]))
" Python: " + version_info[1] + "\n")
if version_info[:2] != true_version_info[:2]:
raise Exception(error_message)
else:
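A condensed sketch of the comparison this check performs after the change; the version strings below are made up for illustration:

import sys

def compute_version_info(ray_version="0.9.0.dev0"):  # value illustrative
    python_version = ".".join(map(str, sys.version_info[:3]))
    return ray_version, python_version

true_version_info = ("0.9.0.dev0", "3.6.9")  # as stored in Redis
version_info = compute_version_info()

# Mirrors check_version_info: any mismatch in Ray or Python raises.
if version_info[:2] != true_version_info[:2]:
    raise Exception("Version mismatch: cluster {} vs this process {}".format(
        true_version_info, version_info))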

View file

@@ -92,7 +92,7 @@ extras["all"] = list(set(chain.from_iterable(extras.values())))
class build_ext(_build_ext.build_ext):
def run(self):
# Note: We are passing in sys.executable so that we use the same
# version of Python to build pyarrow inside the build.sh script. Note
# version of Python to build packages inside the build.sh script. Note
# that certain flags will not be passed along such as --user or sudo.
# TODO(rkn): Fix this.
command = ["../build.sh", "-p", sys.executable]
@@ -101,18 +101,13 @@ class build_ext(_build_ext.build_ext):
command += ["-l", "python,java"]
subprocess.check_call(command)
# We also need to install pyarrow along with Ray, so make sure that the
# relevant non-Python pyarrow files get copied.
pyarrow_files = self.walk_directory("./ray/pyarrow_files/pyarrow")
# We also need to install pickle5 along with Ray, so make sure that the
# relevant non-Python pickle5 files get copied.
pickle5_files = self.walk_directory("./ray/pickle5_files/pickle5")
thirdparty_files = self.walk_directory("./ray/thirdparty_files")
files_to_include = ray_files + pyarrow_files + pickle5_files + \
thirdparty_files
files_to_include = ray_files + pickle5_files + thirdparty_files
# Copy over the autogenerated protobuf Python bindings.
for directory in generated_python_directories:

View file

@@ -4,7 +4,7 @@ import logging
import time
import base64
import numpy as np
import pyarrow
from ray import cloudpickle as pickle
from six import string_types
logger = logging.getLogger(__name__)
@@ -27,7 +27,7 @@ def compression_supported():
@DeveloperAPI
def pack(data):
if LZ4_ENABLED:
data = pyarrow.serialize(data).to_buffer().to_pybytes()
data = pickle.dumps(data)
data = lz4.frame.compress(data)
# TODO(ekl) we shouldn't need to base64 encode this data, but this
# seems to not survive a transfer through the object store if we don't.
@@ -47,7 +47,7 @@ def unpack(data):
if LZ4_ENABLED:
data = base64.b64decode(data)
data = lz4.frame.decompress(data)
data = pyarrow.deserialize(data)
data = pickle.loads(data)
return data
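A self-contained round trip of the new pack/unpack path, using the standard pickle module as a stand-in for ray.cloudpickle:

import base64
import pickle  # stand-in for ray.cloudpickle

try:
    import lz4.frame
    LZ4_ENABLED = True
except ImportError:
    LZ4_ENABLED = False

def pack(data):
    if LZ4_ENABLED:
        data = pickle.dumps(data)
        data = lz4.frame.compress(data)
        data = base64.b64encode(data)
    return data

def unpack(data):
    if LZ4_ENABLED:
        data = base64.b64decode(data)
        data = lz4.frame.decompress(data)
        data = pickle.loads(data)
    return data

obj = {"weights": [1.0, 2.0, 3.0]}
assert unpack(pack(obj)) == obj  # lossless round trip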

View file

@@ -103,7 +103,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
RayLog::InstallFailureSignalHandler();
}
RAY_LOG(INFO) << "Initializing worker " << worker_context_.GetWorkerID();
// Initialize gcs client.
gcs_client_ = std::make_shared<gcs::RedisGcsClient>(gcs_options);
RAY_CHECK_OK(gcs_client_->Connect(io_service_));
@@ -247,9 +246,15 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
if (direct_task_receiver_ != nullptr) {
direct_task_receiver_->Init(client_factory, rpc_address_);
}
plasma_notifier_.reset(new ObjectStoreNotificationManager(io_service_, store_socket,
/*exit_on_error*/ false));
}
CoreWorker::~CoreWorker() {
// ObjectStoreNotificationManager depends on io_service_ so we need to shut it down
// first.
plasma_notifier_->Shutdown();
io_service_.stop();
io_thread_.join();
if (log_dir_ != "") {
@@ -1342,6 +1347,14 @@ void CoreWorker::GetAsync(const ObjectID &object_id, SetResultCallback success_c
});
}
void CoreWorker::SubscribeToAsyncPlasma(PlasmaSubscriptionCallback subscribe_callback) {
plasma_notifier_->SubscribeObjAdded(
[subscribe_callback](const object_manager::protocol::ObjectInfoT &info) {
subscribe_callback(ObjectID::FromPlasmaIdBinary(info.object_id), info.data_size,
info.metadata_size);
});
}
void CoreWorker::SetActorId(const ActorID &actor_id) {
absl::MutexLock lock(&mutex_);
RAY_CHECK(actor_id_.IsNil());

View file

@@ -18,6 +18,7 @@
#include "ray/core_worker/transport/raylet_transport.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/gcs/subscription_executor.h"
#include "ray/object_manager/object_store_notification_manager.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/node_manager/node_manager_client.h"
#include "ray/rpc/worker/core_worker_client.h"
@@ -507,6 +508,15 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
void GetAsync(const ObjectID &object_id, SetResultCallback success_callback,
SetResultCallback fallback_callback, void *python_future);
/// Connect to plasma store for async futures
using PlasmaSubscriptionCallback = std::function<void(ray::ObjectID, int64_t, int64_t)>;
/// Subscribe to plasma store
///
/// \param[in] subscribe_callback The callback when an item is added to plasma.
/// \return void
void SubscribeToAsyncPlasma(PlasmaSubscriptionCallback subscribe_callback);
private:
/// Run the io_service_ event loop. This should be called in a background thread.
void RunIOService();
@@ -766,6 +776,9 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
// Queue of tasks to resubmit when the specified time passes.
std::deque<std::pair<int64_t, TaskSpecification>> to_resubmit_ GUARDED_BY(mutex_);
// Plasma notification manager
std::unique_ptr<ObjectStoreNotificationManager> plasma_notifier_;
friend class CoreWorkerTest;
};

View file

@@ -14,12 +14,14 @@
namespace ray {
ObjectStoreNotificationManager::ObjectStoreNotificationManager(
boost::asio::io_service &io_service, const std::string &store_socket_name)
boost::asio::io_service &io_service, const std::string &store_socket_name,
bool exit_on_error)
: store_client_(),
length_(0),
num_adds_processed_(0),
num_removes_processed_(0),
socket_(io_service) {
socket_(io_service),
exit_on_error_(exit_on_error) {
RAY_ARROW_CHECK_OK(store_client_.Connect(store_socket_name.c_str(), "", 0, 300));
int c_socket; // TODO(mehrdadn): This should be type SOCKET for Windows
@@ -57,6 +59,10 @@ ObjectStoreNotificationManager::~ObjectStoreNotificationManager() {
RAY_ARROW_CHECK_OK(store_client_.Disconnect());
}
void ObjectStoreNotificationManager::Shutdown() {
RAY_ARROW_CHECK_OK(store_client_.Disconnect());
}
void ObjectStoreNotificationManager::NotificationWait() {
boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)),
boost::bind(&ObjectStoreNotificationManager::ProcessStoreLength,
@@ -66,7 +72,7 @@ void ObjectStoreNotificationManager::NotificationWait() {
void ObjectStoreNotificationManager::ProcessStoreLength(
const boost::system::error_code &error) {
notification_.resize(length_);
if (error) {
if (error && exit_on_error_) {
// When shutting down a cluster, it's possible that the plasma store is killed
// earlier than raylet, in this case we don't want raylet to crash, we instead
// log an error message and exit.
@@ -75,7 +81,10 @@ void ObjectStoreNotificationManager::ProcessStoreLength(
<< ", most likely plasma store is down, raylet will exit";
// Exit raylet process.
_exit(kRayletStoreErrorExitCode);
} else {
return;
}
boost::asio::async_read(
socket_, boost::asio::buffer(notification_),
boost::bind(&ObjectStoreNotificationManager::ProcessStoreNotification, this,
@@ -84,10 +93,13 @@ void ObjectStoreNotificationManager::ProcessStoreLength(
void ObjectStoreNotificationManager::ProcessStoreNotification(
const boost::system::error_code &error) {
if (error) {
if (error && exit_on_error_) {
RAY_LOG(FATAL)
<< "Problem communicating with the object store from raylet, check logs or "
<< "dmesg for previous errors: " << boost_to_ray_status(error).ToString();
} else {
return;
}
const auto &object_notification =

View file

@@ -28,8 +28,11 @@ class ObjectStoreNotificationManager {
///
/// \param io_service The asio service to be used.
/// \param store_socket_name The store socket to connect to.
/// \param exit_on_error The manager will exit with an error when it fails
/// to process messages from the socket.
ObjectStoreNotificationManager(boost::asio::io_service &io_service,
const std::string &store_socket_name);
const std::string &store_socket_name,
bool exit_on_error = true);
~ObjectStoreNotificationManager();
@@ -46,6 +49,9 @@ class ObjectStoreNotificationManager {
/// \param callback A callback expecting an ObjectID.
void SubscribeObjDeleted(std::function<void(const ray::ObjectID &)> callback);
/// Explicitly shutdown the manager.
void Shutdown();
/// Returns debug string for class.
///
/// \return string.
@@ -71,6 +77,8 @@ class ObjectStoreNotificationManager {
int64_t num_removes_processed_;
std::vector<uint8_t> notification_;
local_stream_protocol::socket socket_;
bool exit_on_error_;
};
} // namespace ray