[Streaming] Streaming data transfer and python integration (#6185)

This commit is contained in:
Chaokun Yang 2019-12-10 20:33:24 +08:00 committed by Hao Chen
parent c1d4ab8bb4
commit 6272907a57
93 changed files with 8434 additions and 1480 deletions

4
.gitignore vendored

@ -130,6 +130,7 @@ scripts/nodes.txt
# Pytest Cache
**/.pytest_cache
**/.cache
.benchmarks
# Vscode
@ -145,6 +146,9 @@ java/**/.classpath
java/**/.project
java/runtime/native_dependencies/
# streaming/python
streaming/python/generated/
# python virtual env
venv


@ -34,6 +34,21 @@ matrix:
- if [ $RAY_CI_JAVA_AFFECTED != "1" ]; then exit; fi
- ./java/test.sh
- os: linux
env: BAZEL_PYTHON_VERSION=PY3 PYTHON=3.5 PYTHONWARNINGS=ignore TESTSUITE=streaming
install:
- python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py
- eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py`
- if [ $RAY_CI_STREAMING_PYTHON_AFFECTED != "1" ]; then exit; fi
- ./ci/suppress_output ./ci/travis/install-bazel.sh
- ./ci/suppress_output ./ci/travis/install-dependencies.sh
- export PATH="$HOME/miniconda/bin:$PATH"
- ./ci/suppress_output ./ci/travis/install-ray.sh
script:
# Streaming cpp test.
- if [ $RAY_CI_STREAMING_CPP_AFFECTED == "1" ]; then ./ci/suppress_output bash streaming/src/test/run_streaming_queue_test.sh; fi
- if [ $RAY_CI_STREAMING_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=5 --timeout=300 python/ray/streaming/tests/; fi
- os: linux
env: LINT=1 PYTHONWARNINGS=ignore
before_install:
@ -51,7 +66,7 @@ matrix:
- sphinx-build -W -b html -d _build/doctrees source _build/html
- cd ..
# Run Python linting, ignore dict vs {} (C408), others are defaults
- flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
- flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
- ./ci/travis/format.sh --all
# Make sure that the README is formatted properly.
- cd python


@ -7,6 +7,28 @@ load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library")
load("@rules_proto_grpc//python:defs.bzl", "python_grpc_compile")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("//bazel:ray.bzl", "if_linux_x86_64")
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
config_setting(
name = "macos",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "linux_x86_64",
values = {"cpu": "k8"},
visibility = ["//visibility:public"],
)
# TODO(mehrdadn): (How to) support dynamic linking?
PROPAGATED_WINDOWS_DEFINES = ["RAY_STATIC"]
@ -219,6 +241,7 @@ cc_library(
includes = [
"@boost//:asio",
],
visibility = ["//visibility:public"],
deps = [
":common_cc_proto",
":gcs_cc_proto",
@ -327,6 +350,7 @@ cc_library(
"-lpthread",
],
}),
visibility = ["//streaming:__subpackages__"],
deps = [
":common_cc_proto",
":gcs",
@ -373,6 +397,7 @@ cc_library(
"src/ray/core_worker/transport/*.h",
]),
copts = COPTS,
visibility = ["//visibility:public"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -659,6 +684,7 @@ cc_library(
includes = [
"src",
],
visibility = ["//visibility:public"],
deps = [
":sha256",
"@com_github_google_glog//:glog",
@ -782,15 +808,51 @@ pyx_library(
name = "_raylet",
srcs = glob([
"python/ray/__init__.py",
"python/ray/_raylet.pxd",
"python/ray/_raylet.pyx",
"python/ray/includes/*.pxd",
"python/ray/includes/*.pxi",
]),
copts = COPTS,
# Export ray ABI symbols, which can then be used by _streaming.so.
# We need to dlopen this lib with RTLD_GLOBAL to use ABI in this
# shared lib, see python/ray/__init__.py.
cc_kwargs = {
"linkstatic": 1,
# see https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/lite/BUILD#L444
"linkopts": select({
"//:macos": [
"-Wl,-exported_symbols_list,$(location //:src/ray/ray_exported_symbols.lds)",
],
"//:windows": [],
"//conditions:default": [
"-Wl,--version-script,$(location //:src/ray/ray_version_script.lds)",
],
}),
},
copts = COPTS + if_linux_x86_64(["-fno-gnu-unique"]),
deps = [
"//:core_worker_lib",
"//:raylet_lib",
"//:serialization_cc_proto",
"//:src/ray/ray_exported_symbols.lds",
"//:src/ray/ray_version_script.lds",
],
)
pyx_library(
name = "_streaming",
srcs = glob([
"python/ray/streaming/_streaming.pyx",
"python/ray/__init__.py",
"python/ray/_raylet.pxd",
"python/ray/includes/*.pxd",
"python/ray/includes/*.pxi",
"python/ray/streaming/__init__.pxd",
"python/ray/streaming/includes/*.pxd",
"python/ray/streaming/includes/*.pxi",
]),
deps = [
"//streaming:streaming_lib",
],
)
@ -922,6 +984,7 @@ genrule(
name = "ray_pkg",
srcs = [
"python/ray/_raylet.so",
"python/ray/streaming/_streaming.so",
"//:python_sources",
"//:all_py_proto",
"//:redis-server",
@ -930,12 +993,14 @@ genrule(
"//:raylet",
"//:raylet_monitor",
"@plasma//:plasma_store_server",
"//streaming:copy_streaming_py_proto",
],
outs = ["ray_pkg.out"],
cmd = """
set -x &&
WORK_DIR=$$(pwd) &&
cp -f $(location python/ray/_raylet.so) "$$WORK_DIR/python/ray" &&
cp -f $(location python/ray/streaming/_streaming.so) $$WORK_DIR/python/ray/streaming &&
mkdir -p "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" &&
cp -f $(location //:redis-server) "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" &&
cp -f $(location //:redis-cli) "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" &&


@ -64,3 +64,9 @@ def define_java_module(
"{auto_gen_header}": "<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->",
},
)
def if_linux_x86_64(a):
return select({
"//:linux_x86_64": a,
"//conditions:default": [],
})


@ -44,6 +44,7 @@ while [[ $# > 0 ]]; do
done
pushd $ROOT_DIR/../..
BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel WORKSPACE"
BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel
streaming/BUILD.bazel WORKSPACE"
buildifier -mode=$RUN_TYPE -diff_command="diff -u" $BAZEL_FILES
popd


@ -38,6 +38,8 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 0
RAY_CI_LINUX_WHEELS_AFFECTED = 0
RAY_CI_MACOS_WHEELS_AFFECTED = 0
RAY_CI_STREAMING_CPP_AFFECTED = 0
RAY_CI_STREAMING_PYTHON_AFFECTED = 0
if os.environ["TRAVIS_EVENT_TYPE"] == "pull_request":
@ -71,6 +73,7 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
elif changed_file.startswith("java/"):
RAY_CI_JAVA_AFFECTED = 1
elif any(
@ -86,6 +89,13 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_CPP_AFFECTED = 1
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
elif changed_file.startswith("streaming/src"):
RAY_CI_STREAMING_CPP_AFFECTED = 1
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
elif changed_file.startswith("streaming/python"):
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
else:
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1
@ -94,6 +104,7 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_CPP_AFFECTED = 1
else:
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1
@ -102,6 +113,7 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_CPP_AFFECTED = 1
# Log the modified environment variables visible in console.
for output_stream in [sys.stdout, sys.stderr]:
@ -116,3 +128,7 @@ if __name__ == "__main__":
.format(RAY_CI_LINUX_WHEELS_AFFECTED))
_print("export RAY_CI_MACOS_WHEELS_AFFECTED={}"
.format(RAY_CI_MACOS_WHEELS_AFFECTED))
_print("export RAY_CI_STREAMING_CPP_AFFECTED={}"
.format(RAY_CI_STREAMING_CPP_AFFECTED))
_print("export RAY_CI_STREAMING_PYTHON_AFFECTED={}"
.format(RAY_CI_STREAMING_PYTHON_AFFECTED))


@ -79,14 +79,14 @@ format_changed() {
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
if which flake8 >/dev/null; then
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
fi
fi
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
if which flake8 >/dev/null; then
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605
fi
fi


@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import os
from os.path import dirname
import sys
# MUST add pickle5 to the import path because it will be imported by some
@ -19,6 +20,14 @@ pickle5_path = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "pickle5_files")
sys.path.insert(0, pickle5_path)
# Expose ray ABI symbols which may be depended on by other shared
# libraries such as _streaming.so. See BUILD.bazel:_raylet
so_path = os.path.join(dirname(__file__), "_raylet.so")
if os.path.exists(so_path):
import ctypes
from ctypes import CDLL
CDLL(so_path, ctypes.RTLD_GLOBAL)
# MUST import ray._raylet before pyarrow to initialize some global variables.
# It seems the library related to memory allocation in pyarrow will destroy the
# initialization of grpc if we import pyarrow at first.
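The load order this relies on is easy to get wrong, so here is a minimal sketch (the paths and the final import are assumptions about the installed layout, not part of this diff): `_raylet.so` must be loaded with `RTLD_GLOBAL` before the streaming extension is imported, so the ray ABI symbols that `_streaming.so` leaves undefined can be resolved process-wide.

```python
# Sketch only: mirrors the RTLD_GLOBAL load above; runs against an installed ray.
import ctypes
import os

import ray  # ray/__init__.py performs the CDLL(..., RTLD_GLOBAL) call shown above

raylet_so = os.path.join(os.path.dirname(ray.__file__), "_raylet.so")
ctypes.CDLL(raylet_so, ctypes.RTLD_GLOBAL)  # idempotent; ray ABI is now process-wide

# _streaming.so is built without its own copy of the ray symbols (see the
# `linkshared` comment in streaming/BUILD.bazel); the dynamic linker resolves
# them against the globally loaded _raylet.so at import time.
import ray.streaming._streaming  # noqa: E402,F401
```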

70
python/ray/_raylet.pxd Normal file

@ -0,0 +1,70 @@
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3
from libcpp cimport bool as c_bool
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
from libcpp.memory cimport (
shared_ptr,
unique_ptr
)
from ray.includes.common cimport (
CBuffer,
CRayObject
)
from ray.includes.libcoreworker cimport CCoreWorker
from ray.includes.unique_ids cimport (
CObjectID,
CActorID
)
cdef class Buffer:
cdef:
shared_ptr[CBuffer] buffer
Py_ssize_t shape
Py_ssize_t strides
@staticmethod
cdef make(const shared_ptr[CBuffer]& buffer)
cdef class BaseID:
# To avoid the error of "Python int too large to convert to C ssize_t",
# here `cdef size_t` is required.
cdef size_t hash(self)
cdef class ObjectID(BaseID):
cdef:
CObjectID data
object buffer_ref
# Flag indicating whether or not this object ID was added to the set
# of active IDs in the core worker so we know whether we should clean
# it up.
c_bool in_core_worker
cdef CObjectID native(self)
cdef class ActorID(BaseID):
cdef CActorID data
cdef CActorID native(self)
cdef size_t hash(self)
cdef class CoreWorker:
cdef:
unique_ptr[CCoreWorker] core_worker
object async_thread
object async_event_loop
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
size_t data_size, ObjectID object_id,
CObjectID *c_object_id, shared_ptr[CBuffer] *data)
# TODO: handle noreturn better
cdef store_task_outputs(
self, worker, outputs, const c_vector[CObjectID] return_ids,
c_vector[shared_ptr[CRayObject]] *returns)
cdef c_vector[c_string] string_vector_from_list(list string_list)


@ -41,6 +41,7 @@ from libcpp.vector cimport vector as c_vector
from cython.operator import dereference, postincrement
from ray.includes.common cimport (
CBuffer,
CAddress,
CLanguage,
CRayObject,
@ -346,13 +347,29 @@ cdef c_vector[c_string] string_vector_from_list(list string_list):
return out
cdef:
c_string pickle_metadata_str = PICKLE_BUFFER_METADATA
shared_ptr[CBuffer] pickle_metadata = dynamic_pointer_cast[
CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(pickle_metadata_str.data()),
pickle_metadata_str.size(), True))
c_string raw_meta_str = RAW_BUFFER_METADATA
shared_ptr[CBuffer] raw_metadata = dynamic_pointer_cast[
CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(raw_meta_str.data()),
raw_meta_str.size(), True))
cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
cdef:
c_string pickled_str
c_string metadata_str = PICKLE_BUFFER_METADATA
const unsigned char[:] buffer
size_t size
shared_ptr[CBuffer] arg_data
shared_ptr[CBuffer] arg_metadata
# TODO be consistent with store_task_outputs
for arg in args:
if isinstance(arg, ObjectID):
args_vector.push_back(
@ -360,23 +377,25 @@ cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
elif not ray._raylet.check_simple_value(arg):
args_vector.push_back(
CTaskArg.PassByReference((<ObjectID>ray.put(arg)).native()))
else:
pickled_str = pickle.dumps(
arg, protocol=pickle.HIGHEST_PROTOCOL)
# TODO(edoakes): This makes a copy that could be avoided.
elif type(arg) is bytes:
buffer = arg
size = buffer.nbytes
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(pickled_str.data()),
pickled_str.size(),
True))
arg_metadata = dynamic_pointer_cast[
CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(
metadata_str.data()), metadata_str.size(), True))
<uint8_t*>(&buffer[0]), size, True))
args_vector.push_back(
CTaskArg.PassByValue(
make_shared[CRayObject](arg_data, arg_metadata)))
make_shared[CRayObject](arg_data, raw_metadata)))
else:
buffer = pickle.dumps(
arg, protocol=pickle.HIGHEST_PROTOCOL)
size = buffer.nbytes
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(&buffer[0]), size, True))
args_vector.push_back(
CTaskArg.PassByValue(
make_shared[CRayObject](arg_data, pickle_metadata)))
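The net effect of the hunk above is a four-way dispatch on each task argument. A rough Python-level sketch of that logic follows; the class and helpers are plain-Python stand-ins for the Cython types, not the actual API.

```python
import pickle


class FakeObjectID(bytes):
    """Stand-in for ray.ObjectID in this sketch."""


def is_simple_value(arg):
    # Stand-in for ray._raylet.check_simple_value: small primitives are inlined.
    return arg is None or isinstance(arg, (bool, int, float, str, bytes))


def prepare_arg(arg, put_fn):
    """Mirror the by-reference / by-value split performed by prepare_args."""
    if isinstance(arg, FakeObjectID):
        return ("by_reference", arg)              # already in the object store
    if not is_simple_value(arg):
        return ("by_reference", put_fn(arg))      # put first, pass the new ID
    if type(arg) is bytes:
        return ("by_value", arg, "RAW")           # inline, tagged with raw metadata
    serialized = pickle.dumps(arg, protocol=pickle.HIGHEST_PROTOCOL)
    return ("by_value", serialized, "PICKLE")     # inline, tagged with pickle metadata
```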
cdef class RayletClient:
@ -738,10 +757,6 @@ cdef write_serialized_object(
cdef class CoreWorker:
cdef:
unique_ptr[CCoreWorker] core_worker
object async_thread
object async_event_loop
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
@ -1085,7 +1100,6 @@ cdef class CoreWorker:
c_vector[shared_ptr[CRayObject]] *returns):
cdef:
c_vector[size_t] data_sizes
c_string metadata_str
c_vector[shared_ptr[CBuffer]] metadatas
if return_ids.size() == 0:


@ -1,216 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import threading
import time
import ray
from ray.experimental import internal_kv
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
def plasma_prefetch(object_id):
"""Tells plasma to prefetch the given object_id."""
local_sched_client = ray.worker.global_worker.raylet_client
ray_obj_id = ray.ObjectID(object_id)
local_sched_client.fetch_or_reconstruct([ray_obj_id], True)
# TODO: doing the timer in Python land is a bit slow
class FlushThread(threading.Thread):
"""A thread that flushes periodically to plasma.
Attributes:
interval: The flush timeout per batch.
flush_fn: The flush function.
"""
def __init__(self, interval, flush_fn):
threading.Thread.__init__(self)
self.interval = interval # Interval is the max_batch_time
self.flush_fn = flush_fn
self.daemon = True
def run(self):
while True:
time.sleep(self.interval) # Flushing period
self.flush_fn()
class BatchedQueue(object):
"""A batched queue for actor to actor communication.
Attributes:
max_size (int): The maximum size of the queue in number of batches
(if exceeded, backpressure kicks in)
max_batch_size (int): The size of each batch in number of records.
max_batch_time (float): The flush timeout per batch.
prefetch_depth (int): The number of batches to prefetch from plasma.
background_flush (bool): Denotes whether a daemon flush thread should
be used (True) to flush batches to plasma.
base (ndarray): A unique signature for the queue.
read_ack_key (bytes): The signature of the queue in bytes.
prefetch_batch_offset (int): The number of the last read prefetched
batch.
read_batch_offset (int): The number of the last read batch.
read_item_offset (int): The number of the last read record inside a
batch.
write_batch_offset (int): The number of the last written batch.
write_item_offset (int): The number of the last written item inside a
batch.
write_buffer (list): The write buffer, i.e. an in-memory batch.
last_flush_time (float): The time the last flushing to plasma took
place.
cached_remote_offset (int): The number of the last read batch as
recorded by the writer after the previous flush.
flush_lock (RLock): A python lock used for flushing batches to plasma.
flush_thread (Threading): The python thread used for flushing batches
to plasma.
"""
def __init__(self,
max_size=999999,
max_batch_size=99999,
max_batch_time=0.01,
prefetch_depth=10,
background_flush=True):
self.max_size = max_size
self.max_batch_size = max_batch_size
self.max_batch_time = max_batch_time
self.prefetch_depth = prefetch_depth
self.background_flush = background_flush
# Common queue metadata -- This serves as the unique id of the queue
self.base = np.random.randint(0, 2**32 - 1, size=5, dtype="uint32")
self.base[-2] = 0
self.base[-1] = 0
self.read_ack_key = np.ndarray.tobytes(self.base)
# Reader state
self.prefetch_batch_offset = 0
self.read_item_offset = 0
self.read_batch_offset = 0
self.read_buffer = []
# Writer state
self.write_item_offset = 0
self.write_batch_offset = 0
self.write_buffer = []
self.last_flush_time = 0.0
self.cached_remote_offset = 0
self.flush_lock = threading.RLock()
self.flush_thread = FlushThread(self.max_batch_time,
self._flush_writes)
def __getstate__(self):
state = dict(self.__dict__)
del state["flush_lock"]
del state["flush_thread"]
del state["write_buffer"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
# This is to enable writing functionality in
# case the queue is not created by the writer
# The reason is that python locks cannot be serialized
def enable_writes(self):
"""Restores the state of the batched queue for writing."""
self.write_buffer = []
self.flush_lock = threading.RLock()
self.flush_thread = FlushThread(self.max_batch_time,
self._flush_writes)
# Batch ids consist of a unique queue id used as prefix along with
# two numbers generated using the batch offset in the queue
def _batch_id(self, batch_offset):
oid = self.base.copy()
oid[-2] = batch_offset // 2**32
oid[-1] = batch_offset % 2**32
return np.ndarray.tobytes(oid)
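For illustration, here is how a batch offset is folded into the queue's 20-byte signature to form a plasma object id; this is a small standalone sketch of `_batch_id` above, not new functionality.

```python
# Standalone sketch of _batch_id: 5 uint32 words = 20 bytes, the ObjectID size.
import numpy as np

base = np.random.randint(0, 2**32 - 1, size=5, dtype="uint32")
base[-2] = 0  # reserved for the high 32 bits of the batch offset
base[-1] = 0  # reserved for the low 32 bits of the batch offset

batch_offset = 2**32 + 7          # an arbitrary large offset
oid = base.copy()
oid[-2] = batch_offset // 2**32   # -> 1
oid[-1] = batch_offset % 2**32    # -> 7
batch_id = np.ndarray.tobytes(oid)
assert len(batch_id) == 20        # usable as ray.ObjectID(batch_id)
```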
def _flush_writes(self):
with self.flush_lock:
if not self.write_buffer:
return
batch_id = self._batch_id(self.write_batch_offset)
ray.worker.global_worker.put_object(self.write_buffer,
ray.ObjectID(batch_id))
logger.debug("[writer] Flush batch {} offset {} size {}".format(
self.write_batch_offset, self.write_item_offset,
len(self.write_buffer)))
self.write_buffer = []
self.write_batch_offset += 1
self._wait_for_reader()
self.last_flush_time = time.time()
def _wait_for_reader(self):
"""Checks for backpressure by the downstream reader."""
if self.max_size <= 0: # Unlimited queue
return
if self.write_item_offset - self.cached_remote_offset <= self.max_size:
return # Hasn't reached max size
remote_offset = internal_kv._internal_kv_get(self.read_ack_key)
if remote_offset is None:
# logger.debug("[writer] Waiting for reader to start...")
while remote_offset is None:
time.sleep(0.01)
remote_offset = internal_kv._internal_kv_get(self.read_ack_key)
remote_offset = int(remote_offset)
if self.write_item_offset - remote_offset > self.max_size:
logger.debug(
"[writer] Waiting for reader to catch up {} to {} - {}".format(
remote_offset, self.write_item_offset, self.max_size))
while self.write_item_offset - remote_offset > self.max_size:
time.sleep(0.01)
remote_offset = int(
internal_kv._internal_kv_get(self.read_ack_key))
self.cached_remote_offset = remote_offset
def _read_next_batch(self):
while (self.prefetch_batch_offset <
self.read_batch_offset + self.prefetch_depth):
plasma_prefetch(self._batch_id(self.prefetch_batch_offset))
self.prefetch_batch_offset += 1
self.read_buffer = ray.get(
ray.ObjectID(self._batch_id(self.read_batch_offset)))
self.read_batch_offset += 1
logger.debug("[reader] Fetched batch {} offset {} size {}".format(
self.read_batch_offset, self.read_item_offset,
len(self.read_buffer)))
self._ack_reads(self.read_item_offset + len(self.read_buffer))
# Reader acks the key it reads so that writer knows reader's offset.
# This is to cap queue size and simulate backpressure
def _ack_reads(self, offset):
if self.max_size > 0:
internal_kv._internal_kv_put(
self.read_ack_key, offset, overwrite=True)
def put_next(self, item):
with self.flush_lock:
if self.background_flush and not self.flush_thread.is_alive():
logger.debug("[writer] Starting batch flush thread")
self.flush_thread.start()
self.write_buffer.append(item)
self.write_item_offset += 1
if not self.last_flush_time:
self.last_flush_time = time.time()
delay = time.time() - self.last_flush_time
if (len(self.write_buffer) > self.max_batch_size
or delay > self.max_batch_time):
self._flush_writes()
def read_next(self):
if not self.read_buffer:
self._read_next_batch()
assert self.read_buffer
self.read_item_offset += 1
return self.read_buffer.pop(0)
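A minimal, single-process illustration of the queue API being removed here (the benchmark in the next file drives it from separate actors); it assumes a local `ray.init()` and the pre-removal `ray.experimental.streaming` package.

```python
# Illustrative only: writer and reader share one BatchedQueue in one process.
import ray
from ray.experimental.streaming.batched_queue import BatchedQueue

ray.init()

queue = BatchedQueue(max_size=100, max_batch_size=10, max_batch_time=0.01)

# Writer side: records are buffered and flushed to plasma one batch at a time.
for i in range(25):
    queue.put_next(i)
queue._flush_writes()  # push the final partial batch

# Reader side: batches are fetched in order, and the reader acks its offset
# through internal_kv so the writer can apply backpressure.
values = [queue.read_next() for _ in range(25)]
assert values == list(range(25))
```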


@ -1,182 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import logging
import time
import ray
from ray.experimental.streaming.batched_queue import BatchedQueue
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument(
"--rounds", default=10, help="the number of experiment rounds")
parser.add_argument(
"--num-queues", default=1, help="the number of queues in the chain")
parser.add_argument(
"--queue-size", default=10000, help="the queue size in number of batches")
parser.add_argument(
"--batch-size", default=1000, help="the batch size in number of elements")
parser.add_argument(
"--flush-timeout", default=0.001, help="the timeout to flush a batch")
parser.add_argument(
"--prefetch-depth",
default=10,
help="the number of batches to prefetch from plasma")
parser.add_argument(
"--background-flush",
default=False,
help="whether to flush in the backrgound or not")
parser.add_argument(
"--max-throughput",
default="inf",
help="maximum read throughput (elements/s)")
@ray.remote
class Node(object):
"""An actor that reads from an input queue and writes to an output queue.
Attributes:
id (int): The id of the actor.
queue (BatchedQueue): The input queue.
out_queue (BatchedQueue): The output queue.
max_reads_per_second (int): The max read throughput (default: inf).
num_reads (int): Number of elements read.
num_writes (int): Number of elements written.
"""
def __init__(self,
id,
in_queue,
out_queue,
max_reads_per_second=float("inf")):
self.id = id
self.queue = in_queue
self.out_queue = out_queue
self.max_reads_per_second = max_reads_per_second
self.num_reads = 0
self.num_writes = 0
self.start = time.time()
def read_write_forever(self):
debug_log = "[actor {}] Reads throttled to {} reads/s"
log = ""
if self.out_queue is not None:
self.out_queue.enable_writes()
log += "[actor {}] Reads/Writes per second {}"
else: # It's just a reader
log += "[actor {}] Reads per second {}"
# Start spinning
expected_value = 0
while True:
start = time.time()
N = 100000
for _ in range(N):
x = self.queue.read_next()
assert x == expected_value, (x, expected_value)
expected_value += 1
self.num_reads += 1
if self.out_queue is not None:
self.out_queue.put_next(x)
self.num_writes += 1
while (self.num_reads / (time.time() - self.start) >
self.max_reads_per_second):
logger.debug(
debug_log.format(self.id, self.max_reads_per_second))
time.sleep(0.1)
logger.info(log.format(self.id, N / (time.time() - start)))
# Flush any remaining elements
if self.out_queue is not None:
self.out_queue._flush_writes()
def test_max_throughput(rounds,
max_queue_size,
max_batch_size,
batch_timeout,
prefetch_depth,
background_flush,
num_queues,
max_reads_per_second=float("inf")):
assert num_queues >= 1
first_queue = BatchedQueue(
max_size=max_queue_size,
max_batch_size=max_batch_size,
max_batch_time=batch_timeout,
prefetch_depth=prefetch_depth,
background_flush=background_flush)
previous_queue = first_queue
for i in range(num_queues):
# Construct the batched queue
in_queue = previous_queue
out_queue = None
if i < num_queues - 1:
out_queue = BatchedQueue(
max_size=max_queue_size,
max_batch_size=max_batch_size,
max_batch_time=batch_timeout,
prefetch_depth=prefetch_depth,
background_flush=background_flush)
node = Node.remote(i, in_queue, out_queue, max_reads_per_second)
node.read_write_forever.remote()
previous_queue = out_queue
value = 0
# Feed the chain
for round in range(rounds):
logger.info("Round {}".format(round))
N = 100000
start = time.time()
for i in range(N):
first_queue.put_next(value)
value += 1
log = "[writer] Puts per second {}"
logger.info(log.format(N / (time.time() - start)))
first_queue._flush_writes()
if __name__ == "__main__":
ray.init()
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
args = parser.parse_args()
rounds = int(args.rounds)
max_queue_size = int(args.queue_size)
max_batch_size = int(args.batch_size)
batch_timeout = float(args.flush_timeout)
prefetch_depth = int(args.prefetch_depth)
background_flush = bool(args.background_flush)
num_queues = int(args.num_queues)
max_reads_per_second = float(args.max_throughput)
logger.info("== Parameters ==")
logger.info("Rounds: {}".format(rounds))
logger.info("Max queue size: {}".format(max_queue_size))
logger.info("Max batch size: {}".format(max_batch_size))
logger.info("Batch timeout: {}".format(batch_timeout))
logger.info("Prefetch depth: {}".format(prefetch_depth))
logger.info("Background flush: {}".format(background_flush))
logger.info("Max read throughput: {}".format(max_reads_per_second))
# Estimate the ideal throughput
value = 0
start = time.time()
for round in range(rounds):
N = 100000
for _ in range(N):
value += 1
logger.info("Ideal throughput: {}".format(value / (time.time() - start)))
logger.info("== Testing max throughput ==")
start = time.time()
test_max_throughput(rounds, max_queue_size, max_batch_size, batch_timeout,
prefetch_depth, background_flush, num_queues,
max_reads_per_second)
logger.info("Elapsed time: {}".format(time.time() - start))


@ -1,359 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import logging
import sys
from ray.experimental.streaming.operator import PStrategy
from ray.experimental.streaming.batched_queue import BatchedQueue
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Forward and broadcast stream partitioning strategies
forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast]
# Used to choose output channel in case of hash-based shuffling
def _hash(value):
if isinstance(value, int):
return value
try:
return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16)
except AttributeError:
return int(hashlib.sha1(value).hexdigest(), 16)
# A data channel is a batched queue between two
# operator instances in a streaming environment
class DataChannel(object):
"""A data channel for actor-to-actor communication.
Attributes:
env (Environment): The environment the channel belongs to.
src_operator_id (UUID): The id of the source operator of the channel.
dst_operator_id (UUID): The id of the destination operator of the
channel.
src_instance_id (int): The id of the source instance.
dst_instance_id (int): The id of the destination instance.
queue (BatchedQueue): The batched queue used for data movement.
"""
def __init__(self, env, src_operator_id, dst_operator_id, src_instance_id,
dst_instance_id):
self.env = env
self.src_operator_id = src_operator_id
self.dst_operator_id = dst_operator_id
self.src_instance_id = src_instance_id
self.dst_instance_id = dst_instance_id
self.queue = BatchedQueue(
max_size=self.env.config.queue_config.max_size,
max_batch_size=self.env.config.queue_config.max_batch_size,
max_batch_time=self.env.config.queue_config.max_batch_time,
prefetch_depth=self.env.config.queue_config.prefetch_depth,
background_flush=self.env.config.queue_config.background_flush)
def __repr__(self):
return "({},{},{},{})".format(
self.src_operator_id, self.dst_operator_id, self.src_instance_id,
self.dst_instance_id)
# Pulls and merges data from multiple input channels
class DataInput(object):
"""An input gate of an operator instance.
The input gate pulls records from all input channels in a round-robin
fashion.
Attributes:
input_channels (list): The list of input channels.
channel_index (int): The index of the next channel to pull from.
max_index (int): The number of input channels.
closed (list): A list of flags indicating whether an input channel
has been marked as 'closed'.
all_closed (bool): Denotes whether all input channels have been
closed (True) or not (False).
"""
def __init__(self, channels):
self.input_channels = channels
self.channel_index = 0
self.max_index = len(channels)
self.closed = [False] * len(
self.input_channels) # Tracks the channels that have been closed
self.all_closed = False
# Fetches records from input channels in a round-robin fashion
# TODO (john): Make sure the instance is not blocked on any of its input
# channels
# TODO (john): In case of input skew, it might be better to pull from
# the largest queue more often
def _pull(self):
while True:
if self.max_index == 0:
# TODO (john): We should detect this earlier
return None
# Channel to pull from
channel = self.input_channels[self.channel_index]
self.channel_index += 1
if self.channel_index == self.max_index: # Reset channel index
self.channel_index = 0
if self.closed[self.channel_index - 1]:
continue # Channel has been 'closed', check next
record = channel.queue.read_next()
logger.debug("Actor ({},{}) pulled '{}'.".format(
channel.src_operator_id, channel.src_instance_id, record))
if record is None:
# Mark channel as 'closed' and pull from the next open one
self.closed[self.channel_index - 1] = True
self.all_closed = True
for flag in self.closed:
if flag is False:
self.all_closed = False
break
if not self.all_closed:
continue
# Returns 'None' iff all input channels are 'closed'
return record
# Selects output channel(s) and pushes data
class DataOutput(object):
"""An output gate of an operator instance.
The output gate pushes records to output channels according to the
user-defined partitioning scheme.
Attributes:
partitioning_schemes (dict): A mapping from destination operator ids
to partitioning schemes (see: PScheme in operator.py).
forward_channels (list): A list of channels to forward records.
shuffle_channels (list(list)): A list of output channels to shuffle
records grouped by destination operator.
shuffle_key_channels (list(list)): A list of output channels to
shuffle records by a key grouped by destination operator.
shuffle_exists (bool): A flag indicating that there exists at least
one shuffle_channel.
shuffle_key_exists (bool): A flag indicating that there exists at
least one shuffle_key_channel.
"""
def __init__(self, channels, partitioning_schemes):
self.key_selector = None
self.round_robin_indexes = [0]
self.partitioning_schemes = partitioning_schemes
# Prepare output -- collect channels by type
self.forward_channels = [] # Forward and broadcast channels
slots = sum(1 for scheme in self.partitioning_schemes.values()
if scheme.strategy == PStrategy.RoundRobin)
self.round_robin_channels = [[]] * slots # RoundRobin channels
self.round_robin_indexes = [-1] * slots
slots = sum(1 for scheme in self.partitioning_schemes.values()
if scheme.strategy == PStrategy.Shuffle)
# Flag used to avoid hashing when there is no shuffling
self.shuffle_exists = slots > 0
self.shuffle_channels = [[]] * slots # Shuffle channels
slots = sum(1 for scheme in self.partitioning_schemes.values()
if scheme.strategy == PStrategy.ShuffleByKey)
# Flag used to avoid hashing when there is no shuffling by key
self.shuffle_key_exists = slots > 0
self.shuffle_key_channels = [[]] * slots # Shuffle by key channels
# Distinct shuffle destinations
shuffle_destinations = {}
# Distinct shuffle by key destinations
shuffle_by_key_destinations = {}
# Distinct round robin destinations
round_robin_destinations = {}
index_1 = 0
index_2 = 0
index_3 = 0
for channel in channels:
p_scheme = self.partitioning_schemes[channel.dst_operator_id]
strategy = p_scheme.strategy
if strategy in forward_broadcast_strategies:
self.forward_channels.append(channel)
elif strategy == PStrategy.Shuffle:
pos = shuffle_destinations.setdefault(channel.dst_operator_id,
index_1)
self.shuffle_channels[pos].append(channel)
if pos == index_1:
index_1 += 1
elif strategy == PStrategy.ShuffleByKey:
pos = shuffle_by_key_destinations.setdefault(
channel.dst_operator_id, index_2)
self.shuffle_key_channels[pos].append(channel)
if pos == index_2:
index_2 += 1
elif strategy == PStrategy.RoundRobin:
pos = round_robin_destinations.setdefault(
channel.dst_operator_id, index_3)
self.round_robin_channels[pos].append(channel)
if pos == index_3:
index_3 += 1
else: # TODO (john): Add support for other strategies
sys.exit("Unrecognized or unsupported partitioning strategy.")
# A KeyedDataStream can only be shuffled by key
assert not (self.shuffle_exists and self.shuffle_key_exists)
# Flushes any remaining records in the output channels
# 'close' indicates whether we should also 'close' the channel (True)
# by propagating 'None'
# or just flush the remaining records to plasma (False)
def _flush(self, close=False):
"""Flushes remaining output records in the output queues to plasma.
None is used as a special type of record that is propagated from sources
to sinks to signal the end of data in a stream.
Arguments:
close (bool): A flag denoting whether the channel should also be
marked as 'closed' (True) or not (False) after flushing.
"""
for channel in self.forward_channels:
if close is True:
channel.queue.put_next(None)
channel.queue._flush_writes()
for channels in self.shuffle_channels:
for channel in channels:
if close is True:
channel.queue.put_next(None)
channel.queue._flush_writes()
for channels in self.shuffle_key_channels:
for channel in channels:
if close is True:
channel.queue.put_next(None)
channel.queue._flush_writes()
for channels in self.round_robin_channels:
for channel in channels:
if close is True:
channel.queue.put_next(None)
channel.queue._flush_writes()
# TODO (john): Add more channel types
# Returns all destination actor ids
def _destination_actor_ids(self):
destinations = []
for channel in self.forward_channels:
destinations.append((channel.dst_operator_id,
channel.dst_instance_id))
for channels in self.shuffle_channels:
for channel in channels:
destinations.append((channel.dst_operator_id,
channel.dst_instance_id))
for channels in self.shuffle_key_channels:
for channel in channels:
destinations.append((channel.dst_operator_id,
channel.dst_instance_id))
for channels in self.round_robin_channels:
for channel in channels:
destinations.append((channel.dst_operator_id,
channel.dst_instance_id))
# TODO (john): Add more channel types
return destinations
# Pushes the record to the output
# Each individual output queue flushes batches to plasma periodically
# based on 'batch_max_size' and 'batch_max_time'
def _push(self, record):
# Forward record
for channel in self.forward_channels:
logger.debug("[writer] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
# Forward record
index = 0
for channels in self.round_robin_channels:
self.round_robin_indexes[index] += 1
if self.round_robin_indexes[index] == len(channels):
self.round_robin_indexes[index] = 0 # Reset index
channel = channels[self.round_robin_indexes[index]]
logger.debug("[writer] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
index += 1
# Hash-based shuffling by key
if self.shuffle_key_exists:
key, _ = record
h = _hash(key)
for channels in self.shuffle_key_channels:
num_instances = len(channels) # Downstream instances
channel = channels[h % num_instances]
logger.debug(
"[key_shuffle] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
elif self.shuffle_exists: # Hash-based shuffling per destination
h = _hash(record)
for channels in self.shuffle_channels:
num_instances = len(channels) # Downstream instances
channel = channels[h % num_instances]
logger.debug("[shuffle] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
else: # TODO (john): Handle rescaling
pass
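Both shuffle branches above reduce to one modular-arithmetic routing step; a tiny self-contained example (the helper is a copy of `_hash` defined near the top of this file):

```python
import hashlib


def _hash(value):  # copy of the helper defined near the top of this file
    if isinstance(value, int):
        return value
    return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16)


record = ("ray", 1)                  # a (key, value) pair produced by key_by
key, _ = record
num_instances = 3                    # len(channels) for one destination operator
target = _hash(key) % num_instances  # index of the downstream instance that gets it
```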
# Pushes a list of records to the output
# Each individual output queue flushes batches to plasma periodically
# based on 'batch_max_size' and 'batch_max_time'
def _push_all(self, records):
# Forward records
for record in records:
for channel in self.forward_channels:
logger.debug("[writer] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
# Hash-based shuffling by key per destination
if self.shuffle_key_exists:
for record in records:
key, _ = record
h = _hash(key)
for channels in self.shuffle_key_channels:
num_instances = len(channels) # Downstream instances
channel = channels[h % num_instances]
logger.debug(
"[key_shuffle] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
elif self.shuffle_exists: # Hash-based shuffling per destination
for record in records:
h = _hash(record)
for channels in self.shuffle_channels:
num_instances = len(channels) # Downstream instances
channel = channels[h % num_instances]
logger.debug(
"[shuffle] Push record '{}' to channel {}".format(
record, channel))
channel.queue.put_next(record)
else: # TODO (john): Handle rescaling
pass
# Batched queue configuration
class QueueConfig(object):
"""The configuration of a batched queue.
Attributes:
max_size (int): The maximum size of the queue in number of batches
(if exceeded, backpressure kicks in).
max_batch_size (int): The size of each batch in number of records.
max_batch_time (float): The flush timeout per batch.
prefetch_depth (int): The number of batches to prefetch from plasma.
background_flush (bool): Denotes whether a daemon flush thread should
be used (True) to flush batches to plasma.
"""
def __init__(self,
max_size=999999,
max_batch_size=99999,
max_batch_time=0.01,
prefetch_depth=10,
background_flush=False):
self.max_size = max_size
self.max_batch_size = max_batch_size
self.max_batch_time = max_batch_time
self.prefetch_depth = prefetch_depth
self.background_flush = background_flush


@ -1,365 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import sys
import time
import types
import ray
logger = logging.getLogger(__name__)
logger.setLevel("DEBUG")
#
# Each Ray actor corresponds to an operator instance in the physical dataflow
# Actors communicate using batched queues as data channels (no standing TCP
# connections)
# Currently, batched queues are based on Eric's implementation (see:
# batched_queue.py)
def _identity(element):
return element
# TODO (john): Specify the interface of state keepers
class OperatorInstance(object):
"""A streaming operator instance.
Attributes:
instance_id (UUID): The id of the instance.
input (DataInput): The input gate that manages input channels of
the instance (see: DataInput in communication.py).
output (DataOutput): The output gate that manages output channels of
the instance (see: DataOutput in communication.py).
state_keepers (list): A list of actor handles to query the state of
the operator instance.
"""
def __init__(self, instance_id, input_gate, output_gate,
state_keeper=None):
self.key_index = None # Index for key selection
self.key_attribute = None # Attribute name for key selection
self.instance_id = instance_id
self.input = input_gate
self.output = output_gate
# Handle(s) to one or more user-defined actors
# that can retrieve actor's state
self.state_keeper = state_keeper
# Enable writes
for channel in self.output.forward_channels:
channel.queue.enable_writes()
for channels in self.output.shuffle_channels:
for channel in channels:
channel.queue.enable_writes()
for channels in self.output.shuffle_key_channels:
for channel in channels:
channel.queue.enable_writes()
for channels in self.output.round_robin_channels:
for channel in channels:
channel.queue.enable_writes()
# TODO (john): Add more channel types here
# Registers actor's handle so that the actor can schedule itself
def register_handle(self, actor_handle):
self.this_actor = actor_handle
# Used for index-based key extraction, e.g. for tuples
def index_based_selector(self, record):
return record[self.key_index]
# Used for attribute-based key extraction, e.g. for classes
def attribute_based_selector(self, record):
return vars(record)[self.key_attribute]
# Starts the actor
def start(self):
pass
# A source actor that reads a text file line by line
@ray.remote
class ReadTextFile(OperatorInstance):
"""A source operator instance that reads a text file line by line.
Attributes:
filepath (string): The path to the input file.
"""
def __init__(self,
instance_id,
operator_metadata,
input_gate,
output_gate,
state_keepers=None):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate,
state_keepers)
self.filepath = operator_metadata.other_args
# TODO (john): Handle possible exception here
self.reader = open(self.filepath, "r")
# Read input file line by line
def start(self):
while True:
record = self.reader.readline()
# Reader returns empty string ('') on EOF
if not record:
# Flush any remaining records to plasma and close the file
self.output._flush(close=True)
self.reader.close()
return
self.output._push(
record[:-1]) # Push after removing newline characters
# Map actor
@ray.remote
class Map(OperatorInstance):
"""A map operator instance that applies a user-defined
stream transformation.
A map produces exactly one output record for each record in
the input stream.
Attributes:
map_fn (function): The user-defined function.
"""
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
self.map_fn = operator_metadata.logic
# Applies the mapper to each record of the input stream(s)
# and pushes resulting records to the output stream(s)
def start(self):
start = time.time()
elements = 0
while True:
record = self.input._pull()
if record is None:
self.output._flush(close=True)
logger.debug("[map {}] read/writes per second: {}".format(
self.instance_id, elements / (time.time() - start)))
return
self.output._push(self.map_fn(record))
elements += 1
# Flatmap actor
@ray.remote
class FlatMap(OperatorInstance):
"""A map operator instance that applies a user-defined
stream transformation.
A flatmap produces one or more output records for each record in
the input stream.
Attributes:
flatmap_fn (function): The user-defined function.
"""
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
self.flatmap_fn = operator_metadata.logic
# Applies the splitter to the records of the input stream(s)
# and pushes resulting records to the output stream(s)
def start(self):
while True:
record = self.input._pull()
if record is None:
self.output._flush(close=True)
return
self.output._push_all(self.flatmap_fn(record))
# Filter actor
@ray.remote
class Filter(OperatorInstance):
"""A filter operator instance that applies a user-defined filter to
each record of the stream.
Output records are those that pass the filter, i.e. those for which
the filter function returns True.
Attributes:
filter_fn (function): The user-defined boolean function.
"""
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
self.filter_fn = operator_metadata.logic
# Applies the filter to the records of the input stream(s)
# and pushes resulting records to the output stream(s)
def start(self):
while True:
record = self.input._pull()
if record is None: # Close channel and return
self.output._flush(close=True)
return
if self.filter_fn(record):
self.output._push(record)
# Inspect actor
@ray.remote
class Inspect(OperatorInstance):
"""A inspect operator instance that inspects the content of the stream.
Inspect is useful for printing the records in the stream.
Attributes:
inspect_fn (function): The user-defined inspect logic.
"""
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
self.inspect_fn = operator_metadata.logic
# Applies the inspect logic (e.g. print) to the records of
# the input stream(s) and leaves the stream unaffected by simply
# pushing the records to the output stream(s)
def start(self):
while True:
record = self.input._pull()
if record is None:
self.output._flush(close=True)
return
self.output._push(record)
self.inspect_fn(record)
# Reduce actor
@ray.remote
class Reduce(OperatorInstance):
"""A reduce operator instance that combines a new value for a key
with the last reduced one according to a user-defined logic.
Attributes:
reduce_fn (function): The user-defined reduce logic.
value_attribute (int): The index of the value to reduce
(assuming tuple records).
state (dict): A mapping from keys to values.
"""
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate,
operator_metadata.state_actor)
self.reduce_fn = operator_metadata.logic
# Set the attribute selector
self.attribute_selector = operator_metadata.other_args
if self.attribute_selector is None:
self.attribute_selector = _identity
elif isinstance(self.attribute_selector, int):
self.key_index = self.attribute_selector
self.attribute_selector = self.index_based_selector
elif isinstance(self.attribute_selector, str):
self.key_attribute = self.attribute_selector
self.attribute_selector = self.attribute_based_selector
elif not isinstance(self.attribute_selector, types.FunctionType):
sys.exit("Unrecognized or unsupported key selector.")
self.state = {} # key -> value
# Combines the input value for a key with the last reduced
# value for that key to produce a new value.
# Outputs the result as (key,new value)
def start(self):
while True:
record = self.input._pull()
if record is None:
self.output._flush(close=True)
del self.state
return
key, rest = record
new_value = self.attribute_selector(rest)
# TODO (john): Is there a way to update state with
# a single dictionary lookup?
try:
old_value = self.state[key]
new_value = self.reduce_fn(old_value, new_value)
self.state[key] = new_value
except KeyError: # Key does not exist in state
self.state.setdefault(key, new_value)
self.output._push((key, new_value))
# Returns the state of the actor
def get_state(self):
return self.state
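A concrete, hypothetical run of the reduce loop above, using addition as the user-defined `reduce_fn` (a running word count):

```python
# Illustrative only: emulates Reduce.start() on a small in-memory keyed stream.
def reduce_fn(old, new):
    return old + new


state = {}  # key -> last reduced value
stream = [("ray", 1), ("streaming", 1), ("ray", 1)]

for key, value in stream:
    if key in state:
        state[key] = reduce_fn(state[key], value)
    else:
        state[key] = value
    # The operator would push (key, state[key]) downstream at each step:
    # ("ray", 1), ("streaming", 1), ("ray", 2)

assert state == {"ray": 2, "streaming": 1}
```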
@ray.remote
class KeyBy(OperatorInstance):
"""A key_by operator instance that physically partitions the
stream based on a key.
Attributes:
key_attribute (int): The index of the key to partition the stream by
(assuming tuple records).
"""
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
# Set the key selector
self.key_selector = operator_metadata.other_args
if isinstance(self.key_selector, int):
self.key_index = self.key_selector
self.key_selector = self.index_based_selector
elif isinstance(self.key_selector, str):
self.key_attribute = self.key_selector
self.key_selector = self.attribute_based_selector
elif not isinstance(self.key_selector, types.FunctionType):
sys.exit("Unrecognized or unsupported key selector.")
# The actual partitioning is done by the output gate
def start(self):
while True:
record = self.input._pull()
if record is None:
self.output._flush(close=True)
return
key = self.key_selector(record)
self.output._push((key, record))
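The three selector forms accepted above (an index, an attribute name, or a callable) all reduce to a function from record to key; a small illustrative comparison:

```python
# Illustrative: the three key-selector forms KeyBy accepts.
class Event(object):
    def __init__(self, user, value):
        self.user = user
        self.value = value


tuple_record = ("alice", 42)
object_record = Event("alice", 42)

key_by_index = tuple_record[0]                        # int -> index_based_selector
key_by_attribute = vars(object_record)["user"]        # str -> attribute_based_selector
key_by_function = (lambda r: r.user)(object_record)   # callable used as-is

assert key_by_index == key_by_attribute == key_by_function == "alice"
```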
# A custom source actor
@ray.remote
class Source(OperatorInstance):
def __init__(self, instance_id, operator_metadata, input_gate,
output_gate):
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
# The user-defined source with a get_next() method
self.source = operator_metadata.other_args
# Starts the source by calling get_next() repeatedly
def start(self):
start = time.time()
elements = 0
while True:
next = self.source.get_next()
if next is None:
self.output._flush(close=True)
logger.debug("[writer {}] puts per second: {}".format(
self.instance_id, elements / (time.time() - start)))
return
self.output._push(next)
elements += 1
# TODO(john): Time window actor (uses system time)
@ray.remote
class TimeWindow(OperatorInstance):
def __init__(self, queue, width):
self.width = width # In milliseconds
def time_window(self):
while True:
pass


@ -15,11 +15,6 @@ cdef class Buffer:
See https://docs.python.org/3/c-api/buffer.html for details.
"""
cdef:
shared_ptr[CBuffer] buffer
Py_ssize_t shape
Py_ssize_t strides
@staticmethod
cdef make(const shared_ptr[CBuffer]& buffer):
cdef Buffer self = Buffer.__new__(Buffer)


@ -19,7 +19,7 @@ from ray.includes.unique_ids cimport (
CObjectID,
CTaskID,
CUniqueID,
CWorkerID,
CWorkerID
)
import ray
@ -40,8 +40,6 @@ cdef extern from "ray/common/constants.h" nogil:
cdef class BaseID:
# To avoid the error of "Python int too large to convert to C ssize_t",
# here `cdef size_t` is required.
cdef size_t hash(self):
pass
@ -129,13 +127,6 @@ cdef class UniqueID(BaseID):
cdef class ObjectID(BaseID):
cdef:
CObjectID data
object buffer_ref
# Flag indicating whether or not this object ID was added to the set
# of active IDs in the core worker so we know whether we should clean
# it up.
c_bool in_core_worker
def __init__(self, id):
check_id(id)
@ -332,8 +323,6 @@ cdef class WorkerID(UniqueID):
return <CWorkerID>self.data
cdef class ActorID(BaseID):
cdef CActorID data
def __init__(self, id):
check_id(id, CActorID.Size())
self.data = CActorID.FromBinary(<c_string>id)

1
python/ray/streaming Symbolic link

@ -0,0 +1 @@
../../streaming/python/


@ -233,14 +233,6 @@ py_test(
deps = ["//:ray_lib"],
)
py_test(
name = "test_logical_graph",
size = "small",
srcs = ["test_logical_graph.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_memory_limits",
size = "medium",


@ -177,6 +177,7 @@ requires = [
"six >= 1.0.0",
"faulthandler;python_version<'3.3'",
"protobuf >= 3.8.0",
"cloudpickle",
]
setup(


@ -603,11 +603,13 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
TaskID::ForNormalTask(worker_context_.GetCurrentJobID(),
worker_context_.GetCurrentTaskID(), next_task_index);
const std::unordered_map<std::string, double> required_resources;
// TODO(ekl) offload task building onto a thread pool for performance
BuildCommonTaskSpec(
builder, worker_context_.GetCurrentJobID(), task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_,
function, args, task_options.num_returns, task_options.resources, {},
function, args, task_options.num_returns, task_options.resources,
required_resources,
task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET,
return_ids);
TaskSpecification task_spec = builder.Build();
@ -681,10 +683,11 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
const TaskID actor_task_id = TaskID::ForActorTask(
worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
next_task_index, actor_handle->GetActorID());
const std::unordered_map<std::string, double> required_resources;
BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
rpc_address_, function, args, num_returns, task_options.resources,
{}, transport_type, return_ids);
required_resources, transport_type, return_ids);
const ObjectID new_cursor = return_ids->back();
actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor);


@ -0,0 +1,27 @@
# This file defines the C++ symbols that need to be exported (aka ABI, application binary interface).
# These symbols will be used by other libraries (e.g., streaming).
# Note: This file is used for macOS only, and should be kept in sync with `ray_version_script.lds`.
# Ray ABI is not finalized, the exact set of exported (C/C++) APIs is subject to change.
# common
*ray*Language*
*ray*RayObject*
*ray*Status*
*ray*RayFunction*
*ray*TaskArg*
*ray*TaskOptions*
*ray*Buffer*
*ray*LocalMemoryBuffer*
# util
*ray*RayLog*
*ray*RayLogLevel*
# id
*ray*MurmurHash64A*
*ray*JobID*
*ray*TaskID*
*ray*ActorID*
*ray*ObjectID*
# Others
*ray*CoreWorker*
*PyInit*
*init_raylet*
*Java*


@ -0,0 +1,31 @@
# This file defines the C++ symbols that need to be exported (aka ABI, application binary interface).
# These symbols will be used by other libraries (e.g., streaming).
# Note: This file is used for linux only, and should be kept in sync with `ray_exported_symbols.lds`.
# Ray ABI is not finalized, the exact set of exported (C/C++) APIs is subject to change.
VERSION_1.0 {
global:
# common
*ray*Language*;
*ray*RayObject*;
*ray*Status*;
*ray*RayFunction*;
*ray*TaskArg*;
*ray*TaskOptions*;
*ray*Buffer*;
*ray*LocalMemoryBuffer*;
# util
*ray*RayLog*;
*ray*RayLogLevel*;
# id
*ray*MurmurHash64A*;
*ray*JobID*;
*ray*TaskID*;
*ray*ActorID*;
*ray*ObjectID*;
# Others
*ray*CoreWorker*;
*PyInit*;
*init_raylet*;
*Java*;
local: *;
};

235
streaming/BUILD.bazel Normal file

@ -0,0 +1,235 @@
# Bazel build
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library")
load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile")
proto_library(
name = "streaming_proto",
srcs = ["src/protobuf/streaming.proto"],
visibility = ["//visibility:public"],
)
cc_proto_library(
name = "streaming_cc_proto",
deps = [":streaming_proto"],
)
proto_library(
name = "streaming_queue_proto",
srcs = ["src/protobuf/streaming_queue.proto"],
)
cc_proto_library(
name = "streaming_queue_cc_proto",
deps = ["streaming_queue_proto"],
)
# Use `linkshared` to ensure ray related symbols are not packed into streaming libs
# to avoid duplicate symbols. At runtime we expose ray related symbols, which can
# be linked into streaming libs by the dynamic linker. See bazel rule `//:_raylet`.
cc_binary(
name = "ray_util.so",
linkshared = 1,
visibility = ["//visibility:public"],
deps = ["//:ray_util"],
)
cc_binary(
name = "ray_common.so",
linkshared = 1,
visibility = ["//visibility:public"],
deps = ["//:ray_common"],
)
cc_binary(
name = "core_worker_lib.so",
linkshared = 1,
deps = ["//:core_worker_lib"],
)
cc_library(
name = "streaming_util",
srcs = glob([
"src/util/*.cc",
]),
hdrs = glob([
"src/util/*.h",
]),
includes = [
"src",
],
visibility = ["//visibility:public"],
deps = [
"ray_util.so",
"@boost//:any",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "streaming_config",
srcs = glob([
"src/config/*.cc",
]),
hdrs = glob([
"src/config/*.h",
]),
deps = [
"ray_common.so",
":streaming_cc_proto",
":streaming_util",
],
)
cc_library(
name = "streaming_message",
srcs = glob([
"src/message/*.cc",
]),
hdrs = glob([
"src/message/*.h",
]),
deps = [
"ray_common.so",
":streaming_config",
":streaming_util",
],
)
cc_library(
name = "streaming_queue",
srcs = glob([
"src/queue/*.cc",
]),
hdrs = glob([
"src/queue/*.h",
]),
deps = [
"core_worker_lib.so",
"ray_common.so",
"ray_util.so",
":streaming_config",
":streaming_message",
":streaming_queue_cc_proto",
":streaming_util",
"@boost//:asio",
"@boost//:thread",
],
)
cc_library(
name = "streaming_lib",
srcs = glob([
"src/*.cc",
]),
hdrs = glob([
"src/*.h",
"src/queue/*.h",
"src/test/*.h",
]),
includes = ["src"],
visibility = ["//visibility:public"],
deps = [
"ray_common.so",
"ray_util.so",
":streaming_config",
":streaming_message",
":streaming_queue",
":streaming_util",
"@boost//:circular_buffer",
],
)
test_common_deps = [
":streaming_lib",
"//:ray_common",
"//:ray_util",
"//:core_worker_lib",
]
# streaming queue mock actor binary
cc_binary(
name = "streaming_test_worker",
srcs = glob(["src/test/*.h"]) + [
"src/test/mock_actor.cc",
],
includes = [
"streaming/src/test",
],
deps = test_common_deps,
)
# use src/test/run_streaming_queue_test.sh to run this test
cc_binary(
name = "streaming_queue_tests",
srcs = glob(["src/test/*.h"]) + [
"src/test/streaming_queue_tests.cc",
],
deps = test_common_deps,
)
cc_test(
name = "streaming_message_ring_buffer_tests",
srcs = [
"src/test/ring_buffer_tests.cc",
],
includes = [
"streaming/src/test",
],
deps = test_common_deps,
)
cc_test(
name = "streaming_message_serialization_tests",
srcs = [
"src/test/message_serialization_tests.cc",
],
deps = test_common_deps,
)
cc_test(
name = "streaming_mock_transfer",
srcs = [
"src/test/mock_transfer_tests.cc",
],
deps = test_common_deps,
)
cc_test(
name = "streaming_util_tests",
srcs = [
"src/test/streaming_util_tests.cc",
],
deps = test_common_deps,
)
python_proto_compile(
name = "streaming_py_proto",
deps = ["//streaming:streaming_proto"],
)
genrule(
name = "copy_streaming_py_proto",
srcs = [
":streaming_py_proto",
],
outs = [
"copy_streaming_py_proto.out",
],
cmd = """
set -e
set -x
WORK_DIR=$$(pwd)
# Copy generated files.
GENERATED_DIR=$$WORK_DIR/streaming/python/generated
rm -rf $$GENERATED_DIR
mkdir -p $$GENERATED_DIR
for f in $(locations //streaming:streaming_py_proto); do
cp $$f $$GENERATED_DIR
done
echo $$(date) > $@
""",
local = 1,
visibility = ["//visibility:public"],
)

28
streaming/README.md Normal file
View file

@ -0,0 +1,28 @@
# Ray Streaming
1. Build streaming java
* build ray
* `sh build.sh -l java`
* `cd java && mvn clean install -Dmaven.test.skip=true`
* build streaming
* `cd ray/streaming/java && bazel build all_modules`
* `mvn clean install -Dmaven.test.skip=true`
2. Building ray also builds ray streaming python.
3. Run examples
```bash
# c++ test
cd streaming/ && bazel test ...
sh src/test/run_streaming_queue_test.sh
cd ..
# python test
cd python/ray/streaming/
pushd examples
python simple.py --input-file toy.txt
popd
pushd tests
pytest .
popd
```
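
For orientation, a hedged sketch of what a minimal Python job built on this API looks like (condensed from the word-count example under `python/ray/streaming/examples`; the input file name and lambdas are illustrative):

```python
# Minimal sketch of a streaming job, condensed from the bundled examples.
import ray
from ray.streaming.config import Config
from ray.streaming.streaming import Conf, Environment

ray.init()
env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
env.read_text_file("toy.txt") \
    .shuffle() \
    .flat_map(lambda line: line.split()) \
    .set_parallelism(2) \
    .inspect(lambda word: print("word", word))
env.execute()      # deploys the dataflow as Ray actors
env.wait_finish()  # blocks until all sources are exhausted
```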

View file

@ -0,0 +1,3 @@
# flake8: noqa
# Ray should be imported before streaming
import ray

View file

@ -0,0 +1,6 @@
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3
include "includes/transfer.pxi"

View file

@ -0,0 +1,283 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import logging
import pickle
import sys
import time
import ray
import ray.streaming.runtime.transfer as transfer
from ray.streaming.config import Config
from ray.streaming.operator import PStrategy
from ray.streaming.runtime.transfer import ChannelID
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Forward and broadcast stream partitioning strategies
forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast]
# Used to choose output channel in case of hash-based shuffling
def _hash(value):
if isinstance(value, int):
return value
try:
return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16)
except AttributeError:
return int(hashlib.sha1(value).hexdigest(), 16)
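For illustration, a hedged sketch of how this hash is used further down in `DataOutput.push` to pick one downstream instance (the channel names and key are made up):

```python
# Hypothetical illustration: the same key is always routed to the same
# downstream instance, mirroring the shuffle-by-key branch of DataOutput.push.
downstream_channels = ["channel-0", "channel-1", "channel-2", "channel-3"]  # made up
key = "user-42"
target = downstream_channels[_hash(key) % len(downstream_channels)]
```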
class DataChannel(object):
"""A data channel for actor-to-actor communication.
Attributes:
env (Environment): The environment the channel belongs to.
src_operator_id (UUID): The id of the source operator of the channel.
src_instance_index (int): The index of the source instance.
dst_operator_id (UUID): The id of the destination operator of the
channel.
dst_instance_index (int): The index of the destination instance.
"""
def __init__(self, src_operator_id, src_instance_index, dst_operator_id,
dst_instance_index, str_qid):
self.src_operator_id = src_operator_id
self.src_instance_index = src_instance_index
self.dst_operator_id = dst_operator_id
self.dst_instance_index = dst_instance_index
self.str_qid = str_qid
self.qid = ChannelID(str_qid)
def __repr__(self):
return "(src({},{}),dst({},{}), qid({}))".format(
self.src_operator_id, self.src_instance_index,
self.dst_operator_id, self.dst_instance_index, self.str_qid)
_CLOSE_FLAG = b" "
# Pulls and merges data from multiple input channels
class DataInput(object):
"""An input gate of an operator instance.
The input gate pulls records from all input channels in a round-robin
fashion.
Attributes:
input_channels (list): The list of input channels.
channel_index (int): The index of the next channel to pull from.
max_index (int): The number of input channels.
closed (dict): Maps a channel id to a flag indicating whether the
input channel has been marked as 'closed'.
all_closed (bool): Denotes whether all input channels have been
closed (True) or not (False).
"""
def __init__(self, env, channels):
assert len(channels) > 0
self.env = env
self.reader = None # created in `init` method
self.input_channels = channels
self.channel_index = 0
self.max_index = len(channels)
# Tracks the channels that have been closed. qid: close status
self.closed = {}
def init(self):
channels = [c.str_qid for c in self.input_channels]
input_actors = []
for c in self.input_channels:
actor = self.env.execution_graph.get_actor(c.src_operator_id,
c.src_instance_index)
input_actors.append(actor)
logger.info("DataInput input_actors %s", input_actors)
conf = {
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
.current_driver_id,
Config.CHANNEL_TYPE: self.env.config.channel_type
}
self.reader = transfer.DataReader(channels, input_actors, conf)
def pull(self):
# pull from channel
item = self.reader.read(100)
while item is None:
time.sleep(0.001)
item = self.reader.read(100)
msg_data = item.body()
if msg_data == _CLOSE_FLAG:
self.closed[item.channel_id] = True
if len(self.closed) == len(self.input_channels):
return None
else:
return self.pull()
else:
return pickle.loads(msg_data)
def close(self):
self.reader.stop()
# Selects output channel(s) and pushes data
class DataOutput(object):
"""An output gate of an operator instance.
The output gate pushes records to output channels according to the
user-defined partitioning scheme.
Attributes:
partitioning_schemes (dict): A mapping from destination operator ids
to partitioning schemes (see: PScheme in operator.py).
forward_channels (list): A list of channels to forward records.
shuffle_channels (list(list)): A list of output channels to shuffle
records grouped by destination operator.
shuffle_key_channels (list(list)): A list of output channels to
shuffle records by a key grouped by destination operator.
shuffle_exists (bool): A flag indicating that there exists at least
one shuffle_channel.
shuffle_key_exists (bool): A flag indicating that there exists at
least one shuffle_key_channel.
"""
def __init__(self, env, channels, partitioning_schemes):
assert len(channels) > 0
self.env = env
self.writer = None # created in `init` method
self.channels = channels
self.key_selector = None
self.round_robin_indexes = [0]
self.partitioning_schemes = partitioning_schemes
# Prepare output -- collect channels by type
self.forward_channels = [] # Forward and broadcast channels
slots = sum(1 for scheme in self.partitioning_schemes.values()
if scheme.strategy == PStrategy.RoundRobin)
self.round_robin_channels = [[] for _ in range(slots)] # RoundRobin channels, one independent list per destination
self.round_robin_indexes = [-1] * slots
slots = sum(1 for scheme in self.partitioning_schemes.values()
if scheme.strategy == PStrategy.Shuffle)
# Flag used to avoid hashing when there is no shuffling
self.shuffle_exists = slots > 0
self.shuffle_channels = [[] for _ in range(slots)] # Shuffle channels, one independent list per destination
slots = sum(1 for scheme in self.partitioning_schemes.values()
if scheme.strategy == PStrategy.ShuffleByKey)
# Flag used to avoid hashing when there is no shuffling by key
self.shuffle_key_exists = slots > 0
self.shuffle_key_channels = [[] for _ in range(slots)] # Shuffle by key channels, one independent list per destination
# Distinct shuffle destinations
shuffle_destinations = {}
# Distinct shuffle by key destinations
shuffle_by_key_destinations = {}
# Distinct round robin destinations
round_robin_destinations = {}
index_1 = 0
index_2 = 0
index_3 = 0
for channel in channels:
p_scheme = self.partitioning_schemes[channel.dst_operator_id]
strategy = p_scheme.strategy
if strategy in forward_broadcast_strategies:
self.forward_channels.append(channel)
elif strategy == PStrategy.Shuffle:
pos = shuffle_destinations.setdefault(channel.dst_operator_id,
index_1)
self.shuffle_channels[pos].append(channel)
if pos == index_1:
index_1 += 1
elif strategy == PStrategy.ShuffleByKey:
pos = shuffle_by_key_destinations.setdefault(
channel.dst_operator_id, index_2)
self.shuffle_key_channels[pos].append(channel)
if pos == index_2:
index_2 += 1
elif strategy == PStrategy.RoundRobin:
pos = round_robin_destinations.setdefault(
channel.dst_operator_id, index_3)
self.round_robin_channels[pos].append(channel)
if pos == index_3:
index_3 += 1
else: # TODO (john): Add support for other strategies
sys.exit("Unrecognized or unsupported partitioning strategy.")
# A KeyedDataStream can only be shuffled by key
assert not (self.shuffle_exists and self.shuffle_key_exists)
def init(self):
"""init DataOutput which creates DataWriter"""
channel_ids = [c.str_qid for c in self.channels]
to_actors = []
for c in self.channels:
actor = self.env.execution_graph.get_actor(c.dst_operator_id,
c.dst_instance_index)
to_actors.append(actor)
logger.info("DataOutput output_actors %s", to_actors)
conf = {
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
.current_driver_id,
Config.CHANNEL_TYPE: self.env.config.channel_type
}
self.writer = transfer.DataWriter(channel_ids, to_actors, conf)
def close(self):
"""Close the channel (True) by propagating _CLOSE_FLAG
_CLOSE_FLAG is used as special type of record that is propagated from
sources to sink to notify that the end of data in a stream.
"""
for c in self.channels:
self.writer.write(c.qid, _CLOSE_FLAG)
# must ensure DataWriter has sent the close flag to the peer actors before stopping
self.writer.stop()
def push(self, record):
target_channels = []
# Forward record
for c in self.forward_channels:
logger.debug("[writer] Push record '{}' to channel {}".format(
record, c))
target_channels.append(c)
# Round-robin record
index = 0
for channels in self.round_robin_channels:
self.round_robin_indexes[index] += 1
if self.round_robin_indexes[index] == len(channels):
self.round_robin_indexes[index] = 0 # Reset index
c = channels[self.round_robin_indexes[index]]
logger.debug("[writer] Push record '{}' to channel {}".format(
record, c))
target_channels.append(c)
index += 1
# Hash-based shuffling by key
if self.shuffle_key_exists:
key, _ = record
h = _hash(key)
for channels in self.shuffle_key_channels:
num_instances = len(channels) # Downstream instances
c = channels[h % num_instances]
logger.debug(
"[key_shuffle] Push record '{}' to channel {}".format(
record, c))
target_channels.append(c)
elif self.shuffle_exists: # Hash-based shuffling per destination
h = _hash(record)
for channels in self.shuffle_channels:
num_instances = len(channels) # Downstream instances
c = channels[h % num_instances]
logger.debug("[shuffle] Push record '{}' to channel {}".format(
record, c))
target_channels.append(c)
else: # TODO (john): Handle rescaling
pass
msg_data = pickle.dumps(record)
for c in target_channels:
# send data to channel
self.writer.write(c.qid, msg_data)
def push_all(self, records):
for record in records:
self.push(record)
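To make the expected record shapes concrete, a hedged usage sketch of an output gate (following the push logic above; `output_gate` is assumed to be an initialized DataOutput):

```python
# Hypothetical usage of an initialized output gate (shapes follow push() above).
output_gate.push(("hello", 1))    # shuffle-by-key expects (key, value) records
output_gate.push("a raw line")    # forward/broadcast/shuffle accept any picklable record
output_gate.push_all(["a", "b"])  # convenience wrapper around push()
output_gate.close()               # propagates _CLOSE_FLAG to every output channel
```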

View file

@ -0,0 +1,23 @@
class Config:
STREAMING_JOB_NAME = "streaming.job.name"
STREAMING_OP_NAME = "streaming.op_name"
TASK_JOB_ID = "streaming.task_job_id"
STREAMING_WORKER_NAME = "streaming.worker_name"
# channel
CHANNEL_TYPE = "channel_type"
MEMORY_CHANNEL = "memory_channel"
NATIVE_CHANNEL = "native_channel"
CHANNEL_SIZE = "channel_size"
CHANNEL_SIZE_DEFAULT = 10**8
IS_RECREATE = "streaming.is_recreate"
# return from StreamingReader.getBundle if only empty message read in this
# interval.
TIMER_INTERVAL_MS = "timer_interval_ms"
STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity"
# write an empty message if there is no data to be written in this
# interval.
STREAMING_EMPTY_MESSAGE_INTERVAL = "streaming.empty_message_interval"
# operator type
OPERATOR_TYPE = "operator_type"
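For reference, a hedged example of the conf dict these keys end up in (mirroring how DataInput/DataOutput assemble their conf in communication.py; the values are illustrative):

```python
# Hypothetical conf dict handed to DataReader/DataWriter (values illustrative).
conf = {
    Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL,        # or MEMORY_CHANNEL for local/mock runs
    Config.CHANNEL_SIZE: Config.CHANNEL_SIZE_DEFAULT,  # per-channel buffer size
    Config.TIMER_INTERVAL_MS: 100,                     # return from read if only empty messages
}
```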

View file

@ -7,9 +7,7 @@ import logging
import time
import ray
from ray.experimental.streaming.streaming import Environment
from ray.experimental.streaming.batched_queue import BatchedQueue
from ray.experimental.streaming.operator import OpType, PStrategy
from ray.streaming.streaming import Environment
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@ -48,9 +46,6 @@ if __name__ == "__main__":
ray.init()
ray.register_custom_serializer(Record, use_dict=True)
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
ray.register_custom_serializer(OpType, use_pickle=True)
ray.register_custom_serializer(PStrategy, use_pickle=True)
# A Ray streaming environment with the default configuration
env = Environment()

View file

@ -7,9 +7,8 @@ import logging
import time
import ray
from ray.experimental.streaming.streaming import Environment
from ray.experimental.streaming.batched_queue import BatchedQueue
from ray.experimental.streaming.operator import OpType, PStrategy
from ray.streaming.config import Config
from ray.streaming.streaming import Environment, Conf
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@ -33,27 +32,25 @@ if __name__ == "__main__":
args = parser.parse_args()
ray.init()
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
ray.register_custom_serializer(OpType, use_pickle=True)
ray.register_custom_serializer(PStrategy, use_pickle=True)
ray.init(local_mode=False)
# A Ray streaming environment with the default configuration
env = Environment()
env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
# Stream represents the output of the filter and
# can be forked into other dataflows
stream = env.read_text_file(args.input_file) \
.shuffle() \
.flat_map(splitter) \
.set_parallelism(4) \
.set_parallelism(2) \
.filter(filter_fn) \
.set_parallelism(2) \
.inspect(print) # Prints the contents of the
.inspect(lambda x: print("result", x)) # Prints the contents of the
# stream to stdout
start = time.time()
env_handle = env.execute()
ray.get(env_handle) # Stay alive until execution finishes
env.wait_finish()
end = time.time()
logger.info("Elapsed time: {} secs".format(end - start))
logger.debug("Output stream id: {}".format(stream.id))

View file

@ -5,12 +5,10 @@ from __future__ import print_function
import argparse
import logging
import time
import wikipedia
import ray
from ray.experimental.streaming.streaming import Environment
from ray.experimental.streaming.batched_queue import BatchedQueue
from ray.experimental.streaming.operator import OpType, PStrategy
import wikipedia
from ray.streaming.streaming import Environment
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@ -86,9 +84,6 @@ if __name__ == "__main__":
titles_file = str(args.titles_file)
ray.init()
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
ray.register_custom_serializer(OpType, use_pickle=True)
ray.register_custom_serializer(PStrategy, use_pickle=True)
# A Ray streaming environment with the default configuration
env = Environment()
@ -108,6 +103,7 @@ if __name__ == "__main__":
start = time.time()
env_handle = env.execute() # Deploys and executes the dataflow
ray.get(env_handle) # Stay alive until execution finishes
env.wait_finish()
end = time.time()
logger.info("Elapsed time: {} secs".format(end - start))
logger.debug("Output stream id: {}".format(stream.id))

View file

View file

@ -0,0 +1,153 @@
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3
# flake8: noqa
from libc.stdint cimport *
from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr
from libcpp.vector cimport vector as c_vector
from libcpp.list cimport list as c_list
from cpython cimport PyObject
cimport cpython
cdef inline object PyObject_to_object(PyObject* o):
# Cast to "object" increments reference count
cdef object result = <object> o
cpython.Py_DECREF(result)
return result
from ray.includes.common cimport (
CLanguage,
CRayObject,
CRayStatus,
CRayFunction
)
from ray.includes.unique_ids cimport (
CActorID,
CJobID,
CTaskID,
CObjectID,
)
from ray.includes.libcoreworker cimport CCoreWorker
cdef extern from "status.h" namespace "ray::streaming" nogil:
cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus":
pass
cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK"
cdef CStreamingStatus StatusReconstructTimeOut "ray::streaming::StreamingStatus::ReconstructTimeOut"
cdef CStreamingStatus StatusQueueIdNotFound "ray::streaming::StreamingStatus::QueueIdNotFound"
cdef CStreamingStatus StatusResubscribeFailed "ray::streaming::StreamingStatus::ResubscribeFailed"
cdef CStreamingStatus StatusEmptyRingBuffer "ray::streaming::StreamingStatus::EmptyRingBuffer"
cdef CStreamingStatus StatusFullChannel "ray::streaming::StreamingStatus::FullChannel"
cdef CStreamingStatus StatusNoSuchItem "ray::streaming::StreamingStatus::NoSuchItem"
cdef CStreamingStatus StatusInitQueueFailed "ray::streaming::StreamingStatus::InitQueueFailed"
cdef CStreamingStatus StatusGetBundleTimeOut "ray::streaming::StreamingStatus::GetBundleTimeOut"
cdef CStreamingStatus StatusSkipSendEmptyMessage "ray::streaming::StreamingStatus::SkipSendEmptyMessage"
cdef CStreamingStatus StatusInterrupted "ray::streaming::StreamingStatus::Interrupted"
cdef CStreamingStatus StatusWaitQueueTimeOut "ray::streaming::StreamingStatus::WaitQueueTimeOut"
cdef CStreamingStatus StatusOutOfMemory "ray::streaming::StreamingStatus::OutOfMemory"
cdef CStreamingStatus StatusInvalid "ray::streaming::StreamingStatus::Invalid"
cdef CStreamingStatus StatusUnknownError "ray::streaming::StreamingStatus::UnknownError"
cdef CStreamingStatus StatusTailStatus "ray::streaming::StreamingStatus::TailStatus"
cdef cppclass CStreamingCommon "ray::streaming::StreamingCommon":
void SetConfig(const uint8_t *, uint32_t size)
cdef extern from "runtime_context.h" namespace "ray::streaming" nogil:
cdef cppclass CRuntimeContext "ray::streaming::RuntimeContext":
CRuntimeContext()
void SetConfig(const uint8_t *data, uint32_t size)
inline void MarkMockTest()
inline c_bool IsMockTest()
cdef extern from "message/message.h" namespace "ray::streaming" nogil:
cdef cppclass CStreamingMessageType "ray::streaming::StreamingMessageType":
pass
cdef CStreamingMessageType MessageTypeBarrier "ray::streaming::StreamingMessageType::Barrier"
cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message"
cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage":
inline uint8_t *RawData() const
inline uint32_t GetDataSize() const
inline CStreamingMessageType GetMessageType() const
inline uint64_t GetMessageSeqId() const
cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
cdef cppclass CStreamingMessageBundleType "ray::streaming::StreamingMessageBundleType":
pass
cdef CStreamingMessageBundleType BundleTypeEmpty "ray::streaming::StreamingMessageBundleType::Empty"
cdef CStreamingMessageBundleType BundleTypeBarrier "ray::streaming::StreamingMessageBundleType::Barrier"
cdef CStreamingMessageBundleType BundleTypeBundle "ray::streaming::StreamingMessageBundleType::Bundle"
cdef cppclass CStreamingMessageBundleMeta "ray::streaming::StreamingMessageBundleMeta":
CStreamingMessageBundleMeta()
inline uint64_t GetMessageBundleTs() const
inline uint64_t GetLastMessageId() const
inline uint32_t GetMessageListSize() const
inline CStreamingMessageBundleType GetBundleType() const
inline c_bool IsBarrier()
inline c_bool IsBundle()
ctypedef shared_ptr[CStreamingMessageBundleMeta] CStreamingMessageBundleMetaPtr
uint32_t kMessageBundleHeaderSize "ray::streaming::kMessageBundleHeaderSize"
cdef cppclass CStreamingMessageBundle "ray::streaming::StreamingMessageBundle"(CStreamingMessageBundleMeta):
@staticmethod
void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums,
c_list[shared_ptr[CStreamingMessage]] &msg_list);
cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil:
cdef cppclass CReaderClient "ray::streaming::ReaderClient":
CReaderClient(CCoreWorker *core_worker,
CRayFunction &async_func,
CRayFunction &sync_func)
void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer);
shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
cdef cppclass CWriterClient "ray::streaming::WriterClient":
CWriterClient(CCoreWorker *core_worker,
CRayFunction &async_func,
CRayFunction &sync_func)
void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer);
shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
cdef extern from "data_reader.h" namespace "ray::streaming" nogil:
cdef cppclass CDataBundle "ray::streaming::DataBundle":
uint8_t *data
uint32_t data_size
CObjectID c_from "from"
uint64_t seq_id
CStreamingMessageBundleMetaPtr meta
cdef cppclass CDataReader "ray::streaming::DataReader"(CStreamingCommon):
CDataReader(shared_ptr[CRuntimeContext] &runtime_context)
void Init(const c_vector[CObjectID] &input_ids,
const c_vector[CActorID] &actor_ids,
const c_vector[uint64_t] &seq_ids,
const c_vector[uint64_t] &msg_ids,
int64_t timer_interval);
CStreamingStatus GetBundle(const uint32_t timeout_ms,
shared_ptr[CDataBundle] &message)
void Stop()
cdef extern from "data_writer.h" namespace "ray::streaming" nogil:
cdef cppclass CDataWriter "ray::streaming::DataWriter"(CStreamingCommon):
CDataWriter(shared_ptr[CRuntimeContext] &runtime_context)
CStreamingStatus Init(const c_vector[CObjectID] &channel_ids,
const c_vector[CActorID] &actor_ids,
const c_vector[uint64_t] &message_ids,
const c_vector[uint64_t] &queue_size_vec);
long WriteMessageToBufferRing(
const CObjectID &q_id, uint8_t *data, uint32_t data_size)
void Run()
void Stop()
cdef extern from "ray/common/buffer.h" nogil:
cdef cppclass CLocalMemoryBuffer "ray::LocalMemoryBuffer":
uint8_t *Data() const
size_t Size() const

View file

@ -0,0 +1,323 @@
# flake8: noqa
from libc.stdint cimport *
from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
from libcpp.list cimport list as c_list
from ray.includes.common cimport (
CRayFunction,
LANGUAGE_PYTHON,
CBuffer
)
from ray.includes.unique_ids cimport (
CActorID,
CObjectID
)
from ray._raylet cimport (
Buffer,
CoreWorker,
ActorID,
ObjectID,
string_vector_from_list
)
from ray.includes.libcoreworker cimport CCoreWorker
cimport ray.streaming.includes.libstreaming as libstreaming
from ray.streaming.includes.libstreaming cimport (
CStreamingStatus,
CStreamingMessage,
CStreamingMessageBundle,
CRuntimeContext,
CDataBundle,
CDataWriter,
CDataReader,
CReaderClient,
CWriterClient,
CLocalMemoryBuffer,
)
import logging
from ray.function_manager import FunctionDescriptor
channel_logger = logging.getLogger(__name__)
cdef class ReaderClient:
cdef:
CReaderClient *client
def __cinit__(self,
CoreWorker worker,
async_func: FunctionDescriptor,
sync_func: FunctionDescriptor):
cdef:
CCoreWorker *core_worker = worker.core_worker.get()
CRayFunction async_native_func
CRayFunction sync_native_func
async_native_func = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list()))
sync_native_func = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list()))
self.client = new CReaderClient(core_worker, async_native_func, sync_native_func)
def __dealloc__(self):
del self.client
self.client = NULL
def on_reader_message(self, const unsigned char[:] value):
cdef:
size_t size = value.nbytes
shared_ptr[CLocalMemoryBuffer] local_buf = \
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
with nogil:
self.client.OnReaderMessage(local_buf)
def on_reader_message_sync(self, const unsigned char[:] value):
cdef:
size_t size = value.nbytes
shared_ptr[CLocalMemoryBuffer] local_buf = \
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
shared_ptr[CLocalMemoryBuffer] result_buffer
with nogil:
result_buffer = self.client.OnReaderMessageSync(local_buf)
return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))
cdef class WriterClient:
cdef:
CWriterClient * client
def __cinit__(self,
CoreWorker worker,
async_func: FunctionDescriptor,
sync_func: FunctionDescriptor):
cdef:
CCoreWorker *core_worker = worker.core_worker.get()
CRayFunction async_native_func
CRayFunction sync_native_func
async_native_func = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list()))
sync_native_func = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list()))
self.client = new CWriterClient(core_worker, async_native_func, sync_native_func)
def __dealloc__(self):
del self.client
self.client = NULL
def on_writer_message(self, const unsigned char[:] value):
cdef:
size_t size = value.nbytes
shared_ptr[CLocalMemoryBuffer] local_buf = \
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
with nogil:
self.client.OnWriterMessage(local_buf)
def on_writer_message_sync(self, const unsigned char[:] value):
cdef:
size_t size = value.nbytes
shared_ptr[CLocalMemoryBuffer] local_buf = \
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
shared_ptr[CLocalMemoryBuffer] result_buffer
with nogil:
result_buffer = self.client.OnWriterMessageSync(local_buf)
return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))
cdef class DataWriter:
cdef:
CDataWriter *writer
def __init__(self):
raise Exception("use create() to create DataWriter")
@staticmethod
def create(list py_output_channels,
list output_actor_ids: list[ActorID],
uint64_t queue_size,
list py_msg_ids,
bytes config_bytes,
c_bool is_mock):
cdef:
c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
c_vector[CActorID] actor_ids
c_vector[uint64_t] msg_ids
CDataWriter *c_writer
cdef const unsigned char[:] config_data
for actor_id in output_actor_ids:
actor_ids.push_back((<ActorID>actor_id).data)
for py_msg_id in py_msg_ids:
msg_ids.push_back(<uint64_t>py_msg_id)
cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
if is_mock:
ctx.get().MarkMockTest()
if config_bytes:
config_data = config_bytes
channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
ctx.get().SetConfig(<uint8_t *>(&config_data[0]), config_data.nbytes)
c_writer = new CDataWriter(ctx)
cdef:
c_vector[CObjectID] remain_id_vec
c_vector[uint64_t] queue_size_vec
for i in range(channel_ids.size()):
queue_size_vec.push_back(queue_size)
cdef CStreamingStatus status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec)
if remain_id_vec.size() != 0:
channel_logger.warning("failed queue amounts => %s", remain_id_vec.size())
if <uint32_t>status != <uint32_t> libstreaming.StatusOK:
msg = "initialize writer failed, status={}".format(<uint32_t>status)
channel_logger.error(msg)
del c_writer
import ray.streaming.runtime.transfer as transfer
raise transfer.ChannelInitException(msg, qid_vector_to_list(remain_id_vec))
c_writer.Run()
channel_logger.info("create native writer succeed")
cdef DataWriter writer = DataWriter.__new__(DataWriter)
writer.writer = c_writer
return writer
def __dealloc__(self):
if self.writer != NULL:
del self.writer
channel_logger.info("deleted DataWriter")
self.writer = NULL
def write(self, ObjectID qid, const unsigned char[:] value):
"""support zero-copy bytes, bytearray, array of unsigned char"""
cdef:
CObjectID native_id = qid.data
uint64_t msg_id
uint8_t *data = <uint8_t *>(&value[0])
uint32_t size = value.nbytes
with nogil:
msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size)
return msg_id
def stop(self):
self.writer.Stop()
channel_logger.info("stopped DataWriter")
cdef class DataReader:
cdef:
CDataReader *reader
readonly bytes meta
readonly bytes data
def __init__(self):
raise Exception("use create() to create DataReader")
@staticmethod
def create(list py_input_queues,
list input_actor_ids: list[ActorID],
list py_seq_ids,
list py_msg_ids,
int64_t timer_interval,
c_bool is_recreate,
bytes config_bytes,
c_bool is_mock):
cdef:
c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
c_vector[CActorID] actor_ids
c_vector[uint64_t] seq_ids
c_vector[uint64_t] msg_ids
CDataReader *c_reader
cdef const unsigned char[:] config_data
for actor_id in input_actor_ids:
actor_ids.push_back((<ActorID>actor_id).data)
for py_seq_id in py_seq_ids:
seq_ids.push_back(<uint64_t>py_seq_id)
for py_msg_id in py_msg_ids:
msg_ids.push_back(<uint64_t>py_msg_id)
cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
if config_bytes:
config_data = config_bytes
channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
ctx.get().SetConfig(<uint8_t *>(&(config_data[0])), config_data.nbytes)
if is_mock:
ctx.get().MarkMockTest()
c_reader = new CDataReader(ctx)
c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval)
channel_logger.info("create native reader succeed")
cdef DataReader reader = DataReader.__new__(DataReader)
reader.reader = c_reader
return reader
def __dealloc__(self):
if self.reader != NULL:
del self.reader
channel_logger.info("deleted DataReader")
self.reader = NULL
def read(self, uint32_t timeout_millis):
cdef:
shared_ptr[CDataBundle] bundle
CStreamingStatus status
with nogil:
status = self.reader.GetBundle(timeout_millis, bundle)
cdef uint32_t bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())
if <uint32_t> status != <uint32_t> libstreaming.StatusOK:
if <uint32_t> status == <uint32_t> libstreaming.StatusInterrupted:
# avoid cyclic import
import ray.streaming.runtime.transfer as transfer
raise transfer.ChannelInterruptException("reader interrupted")
elif <uint32_t> status == <uint32_t> libstreaming.StatusInitQueueFailed:
raise Exception("init channel failed")
elif <uint32_t> status == <uint32_t> libstreaming.StatusWaitQueueTimeOut:
raise Exception("wait channel object timeout")
cdef:
uint32_t msg_nums
CObjectID queue_id
c_list[shared_ptr[CStreamingMessage]] msg_list
list msgs = []
uint64_t timestamp
uint64_t msg_id
if bundle_type == <uint32_t> libstreaming.BundleTypeBundle:
msg_nums = bundle.get().meta.get().GetMessageListSize()
CStreamingMessageBundle.GetMessageListFromRawData(
bundle.get().data + libstreaming.kMessageBundleHeaderSize,
bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
msg_nums,
msg_list)
timestamp = bundle.get().meta.get().GetMessageBundleTs()
for msg in msg_list:
msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()]
qid_bytes = queue_id.Binary()
msg_id = msg.get().GetMessageSeqId()
msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
return msgs
elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
return []
else:
raise Exception("Unsupported bundle type {}".format(bundle_type))
def stop(self):
self.reader.Stop()
channel_logger.info("stopped DataReader")
cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *:
assert len(py_queue_ids) > 0
cdef:
c_vector[CObjectID] queue_id_vec
c_string q_id_data
for q_id in py_queue_ids:
q_id_data = q_id
assert q_id_data.size() == CObjectID.Size()
obj_id = CObjectID.FromBinary(q_id_data)
queue_id_vec.push_back(obj_id)
return queue_id_vec
cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec):
queues = []
for obj_id in queue_id_vec:
queues.append(obj_id.Binary())
return queues

View file

@ -0,0 +1,124 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import pickle
import threading
import ray
import ray.streaming._streaming as _streaming
from ray.streaming.config import Config
from ray.function_manager import FunctionDescriptor
from ray.streaming.communication import DataInput, DataOutput
logger = logging.getLogger(__name__)
@ray.remote
class JobWorker(object):
"""A streaming job worker.
Attributes:
worker_id: The id of the instance.
input_channels (list): The input channels of the instance (wrapped
by a DataInput gate, see: communication.py).
output_channels (list): The output channels of the instance (wrapped
by a DataOutput gate, see: communication.py).
operator (Operator): The logical operator executed by the operator
instance.
"""
def __init__(self, worker_id, operator, input_channels, output_channels):
self.env = None
self.worker_id = worker_id
self.operator = operator
processor_name = operator.processor_class.__name__
processor_instance = operator.processor_class(operator)
self.processor_name = processor_name
self.processor_instance = processor_instance
self.input_channels = input_channels
self.output_channels = output_channels
self.input_gate = None
self.output_gate = None
self.reader_client = None
self.writer_client = None
def init(self, env):
"""init streaming actor"""
env = pickle.loads(env)
self.env = env
logger.info("init operator instance %s", self.processor_name)
if env.config.channel_type == Config.NATIVE_CHANNEL:
core_worker = ray.worker.global_worker.core_worker
reader_async_func = FunctionDescriptor(
__name__, self.on_reader_message.__name__,
self.__class__.__name__)
reader_sync_func = FunctionDescriptor(
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
self.reader_client = _streaming.ReaderClient(
core_worker, reader_async_func, reader_sync_func)
writer_async_func = FunctionDescriptor(
__name__, self.on_writer_message.__name__,
self.__class__.__name__)
writer_sync_func = FunctionDescriptor(
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
self.writer_client = _streaming.WriterClient(
core_worker, writer_async_func, writer_sync_func)
if len(self.input_channels) > 0:
self.input_gate = DataInput(env, self.input_channels)
self.input_gate.init()
if len(self.output_channels) > 0:
self.output_gate = DataOutput(
env, self.output_channels,
self.operator.partitioning_strategies)
self.output_gate.init()
logger.info("init operator instance %s succeed", self.processor_name)
return True
# Starts the actor
def start(self):
self.t = threading.Thread(target=self.run, daemon=True)
self.t.start()
actor_id = ray.worker.global_worker.actor_id
logger.info("%s %s started, actor id %s", self.__class__.__name__,
self.processor_name, actor_id)
def run(self):
logger.info("%s start running", self.processor_name)
self.processor_instance.run(self.input_gate, self.output_gate)
logger.info("%s finished running", self.processor_name)
self.close()
def close(self):
if self.input_gate:
self.input_gate.close()
if self.output_gate:
self.output_gate.close()
def is_finished(self):
return not self.t.is_alive()
def on_reader_message(self, buffer: bytes):
"""used in direct call mode"""
self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes):
"""used in direct call mode"""
if self.reader_client is None:
return b" " * 4 # special flag to indicate this actor not ready
result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes()
def on_writer_message(self, buffer: bytes):
"""used in direct call mode"""
self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes):
"""used in direct call mode"""
if self.writer_client is None:
return b" " * 4 # special flag to indicate this actor not ready
result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes()

View file

@ -5,6 +5,8 @@ from __future__ import print_function
import enum
import logging
import cloudpickle
logger = logging.getLogger(__name__)
logger.setLevel("DEBUG")
@ -52,16 +54,18 @@ class OpType(enum.Enum):
class Operator(object):
def __init__(self,
id,
type,
op_type,
processor_class,
name="",
logic=None,
num_instances=1,
other=None,
state_actor=None):
self.id = id
self.type = type
self.type = op_type
self.processor_class = processor_class
self.name = name
self.logic = logic # The operator's logic
self._logic = cloudpickle.dumps(logic) # The operator's logic
self.num_instances = num_instances
# One partitioning strategy per downstream operator (default: forward)
self.partitioning_strategies = {}
@ -96,10 +100,14 @@ class Operator(object):
self.partitioning_strategies = strategies
def print(self):
log = "Operator<\nID = {}\nName = {}\nType = {}\n"
log = "Operator<\nID = {}\nName = {}\nprocessor_class = {}\n"
log += "Logic = {}\nNumber_of_Instances = {}\n"
log += "Partitioning_Scheme = {}\nOther_Args = {}>\n"
logger.debug(
log.format(self.id, self.name, self.type, self.logic,
log.format(self.id, self.name, self.processor_class, self.logic,
self.num_instances, self.partitioning_strategies,
self.other_args))
@property
def logic(self):
return cloudpickle.loads(self._logic)
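For context, a hedged sketch of why `logic` is stored as cloudpickle bytes rather than a plain attribute: the standard pickle module cannot serialize lambdas, while cloudpickle can, so user-defined logic survives being shipped to remote actors.

```python
# Minimal sketch: cloudpickle round-trips a lambda, plain pickle cannot.
import cloudpickle

logic = lambda x: x * 2    # hypothetical user-defined operator logic
restored = cloudpickle.loads(cloudpickle.dumps(logic))
assert restored(21) == 42
```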

View file

@ -0,0 +1,226 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import sys
import time
import types
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
def _identity(element):
return element
class ReadTextFile:
"""A source operator instance that reads a text file line by line.
Attributes:
filepath (string): The path to the input file.
"""
def __init__(self, operator):
self.filepath = operator.other_args
# TODO (john): Handle possible exception here
self.reader = open(self.filepath, "r")
# Read input file line by line
def run(self, input_gate, output_gate):
while True:
record = self.reader.readline()
# Reader returns empty string ('') on EOF
if not record:
self.reader.close()
return
output_gate.push(
record[:-1]) # Push after removing newline characters
class Map:
"""A map operator instance that applies a user-defined
stream transformation.
A map produces exactly one output record for each record in
the input stream.
"""
def __init__(self, operator):
self.map_fn = operator.logic
# Applies the mapper each record of the input stream(s)
# and pushes resulting records to the output stream(s)
def run(self, input_gate, output_gate):
elements = 0
while True:
record = input_gate.pull()
if record is None:
return
output_gate.push(self.map_fn(record))
elements += 1
class FlatMap:
"""A map operator instance that applies a user-defined
stream transformation.
A flatmap produces one or more output records for each record in
the input stream.
Attributes:
flatmap_fn (function): The user-defined function.
"""
def __init__(self, operator):
self.flatmap_fn = operator.logic
# Applies the splitter to the records of the input stream(s)
# and pushes resulting records to the output stream(s)
def run(self, input_gate, output_gate):
while True:
record = input_gate.pull()
if record is None:
return
output_gate.push_all(self.flatmap_fn(record))
class Filter:
"""A filter operator instance that applies a user-defined filter to
each record of the stream.
Output records are those that pass the filter, i.e. those for which
the filter function returns True.
Attributes:
filter_fn (function): The user-defined boolean function.
"""
def __init__(self, operator):
self.filter_fn = operator.logic
# Applies the filter to the records of the input stream(s)
# and pushes resulting records to the output stream(s)
def run(self, input_gate, output_gate):
while True:
record = input_gate.pull()
if record is None:
return
if self.filter_fn(record):
output_gate.push(record)
class Inspect:
"""A inspect operator instance that inspects the content of the stream.
Inspect is useful for printing the records in the stream.
"""
def __init__(self, operator):
self.inspect_fn = operator.logic
def run(self, input_gate, output_gate):
# Applies the inspect logic (e.g. print) to the records of
# the input stream(s)
# and leaves stream unaffected by simply pushing the records to
# the output stream(s)
while True:
record = input_gate.pull()
if record is None:
return
if output_gate:
output_gate.push(record)
self.inspect_fn(record)
class Reduce:
"""A reduce operator instance that combines a new value for a key
with the last reduced one according to a user-defined logic.
"""
def __init__(self, operator):
self.reduce_fn = operator.logic
# Set the attribute selector
self.attribute_selector = operator.other_args
if self.attribute_selector is None:
self.attribute_selector = _identity
elif isinstance(self.attribute_selector, int):
self.key_index = self.attribute_selector
self.attribute_selector =\
lambda record: record[self.key_index]
elif isinstance(self.attribute_selector, str):
attribute_name = self.attribute_selector
self.attribute_selector =\
lambda record: vars(record)[attribute_name]
elif not isinstance(self.attribute_selector, types.FunctionType):
sys.exit("Unrecognized or unsupported key selector.")
self.state = {} # key -> value
# Combines the input value for a key with the last reduced
# value for that key to produce a new value.
# Outputs the result as (key,new value)
def run(self, input_gate, output_gate):
while True:
record = input_gate.pull()
if record is None:
return
key, rest = record
new_value = self.attribute_selector(rest)
# TODO (john): Is there a way to update state with
# a single dictionary lookup?
try:
old_value = self.state[key]
new_value = self.reduce_fn(old_value, new_value)
self.state[key] = new_value
except KeyError: # Key does not exist in state
self.state.setdefault(key, new_value)
output_gate.push((key, new_value))
# Returns the state of the actor
def get_state(self):
return self.state
class KeyBy:
"""A key_by operator instance that physically partitions the
stream based on a key.
"""
def __init__(self, operator):
# Set the key selector
self.key_selector = operator.other_args
if isinstance(self.key_selector, int):
key_index = self.key_selector
self.key_selector = lambda r: r[key_index]
elif isinstance(self.key_selector, str):
key_attribute = self.key_selector
self.key_selector = lambda record: vars(record)[key_attribute]
elif not isinstance(self.key_selector, types.FunctionType):
sys.exit("Unrecognized or unsupported key selector.")
# The actual partitioning is done by the output gate
def run(self, input_gate, output_gate):
while True:
record = input_gate.pull()
if record is None:
return
key = self.key_selector(record)
output_gate.push((key, record))
# A custom source actor
class Source:
def __init__(self, operator):
# The user-defined source with a get_next() method
self.source = operator.logic
# Starts the source by calling get_next() repeatedly
def run(self, input_gate, output_gate):
start = time.time()
elements = 0
while True:
record = self.source.get_next()
if not record:
logger.debug("[writer] puts per second: {}".format(
elements / (time.time() - start)))
return
output_gate.push(record)
elements += 1
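For illustration, a hedged sketch of the kind of user-defined source object this processor expects: anything exposing a `get_next()` method that returns the next record, or a falsy value once the stream is exhausted.

```python
# Hypothetical user-defined source consumed by the Source processor above.
class ListSource:
    def __init__(self, records):
        self.records = list(records)

    def get_next(self):
        # Return the next record, or None once all records have been emitted.
        return self.records.pop(0) if self.records else None
```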

View file

View file

@ -0,0 +1,291 @@
import logging
import random
from queue import Queue
from typing import List
import ray
import ray.streaming._streaming as _streaming
import ray.streaming.generated.streaming_pb2 as streaming_pb
from ray.actor import ActorHandle, ActorID
from ray.streaming.config import Config
CHANNEL_ID_LEN = 20
class ChannelID:
"""
ChannelID is used to identify a transfer channel between
an upstream worker and a downstream worker.
"""
def __init__(self, channel_id_str: str):
"""
Args:
channel_id_str: string representation of channel id
"""
self.channel_id_str = channel_id_str
self.object_qid = ray.ObjectID(channel_id_str_to_bytes(channel_id_str))
def __eq__(self, other):
if other is None:
return False
if type(other) is ChannelID:
return self.channel_id_str == other.channel_id_str
else:
return False
def __hash__(self):
return hash(self.channel_id_str)
def __repr__(self):
return self.channel_id_str
@staticmethod
def gen_random_id():
"""Generate a random channel id string
"""
res = ""
for i in range(CHANNEL_ID_LEN * 2):
res += chr(random.randint(0, 5) + ord("A"))
return res
@staticmethod
def gen_id(from_index, to_index, ts):
"""Generate channel id, which is 20 character"""
channel_id = bytearray(20)
for i in range(11, 7, -1):
channel_id[i] = ts & 0xff
ts >>= 8
channel_id[16] = (from_index & 0xffff) >> 8
channel_id[17] = (from_index & 0xff)
channel_id[18] = (to_index & 0xffff) >> 8
channel_id[19] = (to_index & 0xff)
return channel_bytes_to_str(bytes(channel_id))
def channel_id_str_to_bytes(channel_id_str):
"""
Args:
channel_id_str: string representation of channel id
Returns:
bytes representation of channel id
"""
assert type(channel_id_str) in [str, bytes]
if isinstance(channel_id_str, bytes):
return channel_id_str
qid_bytes = bytes.fromhex(channel_id_str)
assert len(qid_bytes) == CHANNEL_ID_LEN
return qid_bytes
def channel_bytes_to_str(id_bytes):
"""
Args:
id_bytes: bytes representation of channel id
Returns:
string representation of channel id
"""
assert type(id_bytes) in [str, bytes]
if isinstance(id_bytes, str):
return id_bytes
return bytes.hex(id_bytes)
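A hedged round-trip sketch of these helpers (the indices and timestamp are made up):

```python
# Hypothetical round trip: task indices + build time -> 20-byte channel id.
qid_str = ChannelID.gen_id(from_index=1, to_index=2, ts=1575960000000)
assert len(qid_str) == CHANNEL_ID_LEN * 2    # 40 hex characters
qid_bytes = channel_id_str_to_bytes(qid_str)  # 20 raw bytes
assert channel_bytes_to_str(qid_bytes) == qid_str
```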
class DataMessage:
"""
DataMessage represents data between upstream and downstream operator
"""
def __init__(self,
body,
timestamp,
channel_id,
message_id_,
is_empty_message=False):
self.__body = body
self.__timestamp = timestamp
self.__channel_id = channel_id
self.__message_id = message_id_
self.__is_empty_message = is_empty_message
def __len__(self):
return len(self.__body)
def body(self):
"""Message data"""
return self.__body
def timestamp(self):
"""Get timestamp when item is written by upstream DataWriter
"""
return self.__timestamp
def channel_id(self):
"""Get string id of channel where data is coming from
"""
return self.__channel_id
def is_empty_message(self):
"""Whether this message is an empty message.
Upstream DataWriter will send an empty message when there is no data
to write in the specified interval.
"""
return self.__is_empty_message
@property
def message_id(self):
return self.__message_id
logger = logging.getLogger(__name__)
class DataWriter:
"""Data Writer is a wrapper of streaming c++ DataWriter, which sends data
to downstream workers
"""
def __init__(self, output_channels, to_actors: List[ActorHandle],
conf: dict):
"""Get DataWriter of output channels
Args:
output_channels: output channel ids
to_actors: downstream output actors
Returns:
DataWriter
"""
assert len(output_channels) > 0
py_output_channels = [
channel_id_str_to_bytes(qid_str) for qid_str in output_channels
]
output_actor_ids: List[ActorID] = [
handle._ray_actor_id for handle in to_actors
]
channel_size = conf.get(Config.CHANNEL_SIZE,
Config.CHANNEL_SIZE_DEFAULT)
py_msg_ids = [0 for _ in range(len(output_channels))]
config_bytes = _to_native_conf(conf)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.writer = _streaming.DataWriter.create(
py_output_channels, output_actor_ids, channel_size, py_msg_ids,
config_bytes, is_mock)
logger.info("create DataWriter succeed")
def write(self, channel_id: ChannelID, item: bytes):
"""Write data into native channel
Args:
channel_id: channel id
item: bytes data
Returns:
msg_id
"""
assert type(item) == bytes
msg_id = self.writer.write(channel_id.object_qid, item)
return msg_id
def stop(self):
logger.info("stopping channel writer.")
self.writer.stop()
# destruct DataWriter
self.writer = None
def close(self):
logger.info("closing channel writer.")
class DataReader:
"""Data Reader is wrapper of streaming c++ DataReader, which read data
from channels of upstream workers
"""
def __init__(self, input_channels: List, from_actors: List[ActorHandle],
conf: dict):
"""Get DataReader of input channels
Args:
input_channels: input channels
from_actors: upstream input actors
Returns:
DataReader
"""
assert len(input_channels) > 0
py_input_channels = [
channel_id_str_to_bytes(qid_str) for qid_str in input_channels
]
input_actor_ids: List[ActorID] = [
handle._ray_actor_id for handle in from_actors
]
py_seq_ids = [0 for _ in range(len(input_channels))]
py_msg_ids = [0 for _ in range(len(input_channels))]
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
is_recreate = bool(conf.get(Config.IS_RECREATE, False))
config_bytes = _to_native_conf(conf)
self.__queue = Queue(10000)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.reader = _streaming.DataReader.create(
py_input_channels, input_actor_ids, py_seq_ids, py_msg_ids,
timer_interval, is_recreate, config_bytes, is_mock)
logger.info("create DataReader succeed")
def read(self, timeout_millis):
"""Read data from channel
Args:
timeout_millis: how long to block, in milliseconds, when there is
no data in the channel
Returns:
channel item
"""
if self.__queue.empty():
msgs = self.reader.read(timeout_millis)
for msg in msgs:
msg_bytes, msg_id, timestamp, qid_bytes = msg
data_msg = DataMessage(msg_bytes, timestamp,
channel_bytes_to_str(qid_bytes), msg_id)
self.__queue.put(data_msg)
if self.__queue.empty():
return None
return self.__queue.get()
def stop(self):
logger.info("stopping Data Reader.")
self.reader.stop()
# destruct DataReader
self.reader = None
def close(self):
logger.info("closing Data Reader.")
def _to_native_conf(conf):
config = streaming_pb.StreamingConfig()
if Config.STREAMING_JOB_NAME in conf:
config.job_name = conf[Config.STREAMING_JOB_NAME]
if Config.TASK_JOB_ID in conf:
job_id = conf[Config.TASK_JOB_ID]
config.task_job_id = job_id.hex()
if Config.STREAMING_WORKER_NAME in conf:
config.worker_name = conf[Config.STREAMING_WORKER_NAME]
if Config.STREAMING_OP_NAME in conf:
config.op_name = conf[Config.STREAMING_OP_NAME]
# TODO set operator type
if Config.STREAMING_RING_BUFFER_CAPACITY in conf:
config.ring_buffer_capacity = \
conf[Config.STREAMING_RING_BUFFER_CAPACITY]
if Config.STREAMING_EMPTY_MESSAGE_INTERVAL in conf:
config.empty_message_interval = \
conf[Config.STREAMING_EMPTY_MESSAGE_INTERVAL]
logger.info("conf: %s", str(config))
return config.SerializeToString()
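For illustration, a hedged sketch of feeding a plain dict through this helper (keys come from Config; values are made up, and TASK_JOB_ID is omitted because it expects a JobID object):

```python
# Hypothetical conversion of a conf dict into serialized StreamingConfig bytes.
conf = {
    Config.STREAMING_JOB_NAME: "wordcount",       # made-up job name
    Config.STREAMING_RING_BUFFER_CAPACITY: 1000,  # made-up capacity
}
config_bytes = _to_native_conf(conf)              # bytes of a StreamingConfig protobuf
```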
class ChannelInitException(Exception):
def __init__(self, msg, abnormal_channels):
self.abnormal_channels = abnormal_channels
self.msg = msg
class ChannelInterruptException(Exception):
def __init__(self, msg=None):
self.msg = msg

View file

@ -3,26 +3,24 @@ from __future__ import division
from __future__ import print_function
import logging
import pickle
import sys
import uuid
import time
import networkx as nx
from ray.experimental.streaming.communication import DataChannel, DataInput
from ray.experimental.streaming.communication import DataOutput, QueueConfig
from ray.experimental.streaming.operator import Operator, OpType
from ray.experimental.streaming.operator import PScheme, PStrategy
import ray.experimental.streaming.operator_instance as operator_instance
import ray
import ray.streaming.processor as processor
import ray.streaming.runtime.transfer as transfer
from ray.streaming.communication import DataChannel
from ray.streaming.config import Config
from ray.streaming.jobworker import JobWorker
from ray.streaming.operator import Operator, OpType
from ray.streaming.operator import PScheme, PStrategy
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
# Generates UUIDs
def _generate_uuid():
return uuid.uuid4()
# Rolling sum's logic
def _sum(value_1, value_2):
return value_1 + value_2
@ -36,28 +34,224 @@ all_to_all_strategies = [
# Environment configuration
class Config(object):
class Conf(object):
"""Environment configuration.
This class includes all information about the configuration of the
streaming environment.
Attributes:
queue_config (QueueConfig): Batched Queue configuration
(see: communication.py)
A batched queue configuration includes the max queue size,
the size of each batch (in number of elements), the batch flush
timeout, and the number of batches to prefetch from plasma
parallelism (int): The number of instances (actors) for each logical
dataflow operator (default: 1)
"""
def __init__(self, parallelism=1):
self.queue_config = QueueConfig()
def __init__(self, parallelism=1, channel_type=Config.MEMORY_CHANNEL):
self.parallelism = parallelism
self.channel_type = channel_type
# ...
class ExecutionGraph:
def __init__(self, env):
self.env = env
self.physical_topo = nx.DiGraph() # DAG
# Handles to all actors in the physical dataflow
self.actor_handles = []
# (op_id, op_instance_index) -> ActorID
self.actors_map = {}
# execution graph build time: milliseconds since epoch
self.build_time = 0
self.task_id_counter = 0
self.task_ids = {}
self.input_channels = {} # operator id -> input channels
self.output_channels = {} # operator id -> output channels
# Constructs and deploys a Ray actor of a specific type
# TODO (john): Actor placement information should be specified in
# the environment's configuration
def __generate_actor(self, instance_index, operator, input_channels,
output_channels):
"""Generates an actor that will execute a particular instance of
the logical operator
Attributes:
instance_index: The index of the instance the actor will execute.
operator: The metadata of the logical operator.
input_channels: The input channels of the instance.
output_channels: The output channels of the instance.
"""
worker_id = (operator.id, instance_index)
# Record the physical dataflow graph (for debugging purposes)
self.__add_channel(worker_id, output_channels)
# Note: direct call actors only support pass-by-value arguments
return JobWorker._remote(
args=[worker_id, operator, input_channels, output_channels],
is_direct_call=True)
# Constructs and deploys a Ray actor for each instance of
# the given operator
def __generate_actors(self, operator, upstream_channels,
downstream_channels):
"""Generates one actor for each instance of the given logical
operator.
Attributes:
operator (Operator): The logical operator metadata.
upstream_channels (list): A list of all upstream channels for
all instances of the operator.
downstream_channels (list): A list of all downstream channels
for all instances of the operator.
"""
num_instances = operator.num_instances
logger.info("Generating {} actors of type {}...".format(
num_instances, operator.type))
handles = []
for i in range(num_instances):
# Collect input and output channels for the particular instance
ip = [c for c in upstream_channels if c.dst_instance_index == i]
op = [c for c in downstream_channels if c.src_instance_index == i]
log = "Constructed {} input and {} output channels "
log += "for the {}-th instance of the {} operator."
logger.debug(log.format(len(ip), len(op), i, operator.type))
handle = self.__generate_actor(i, operator, ip, op)
if handle:
handles.append(handle)
self.actors_map[(operator.id, i)] = handle
return handles
# Adds a channel/edge to the physical dataflow graph
def __add_channel(self, actor_id, output_channels):
for c in output_channels:
dest_actor_id = (c.dst_operator_id, c.dst_instance_index)
self.physical_topo.add_edge(actor_id, dest_actor_id)
# Generates all required data channels between an operator
# and its downstream operators
def _generate_channels(self, operator):
"""Generates all output data channels
(see: DataChannel in communication.py) for all instances of
the given logical operator.
The function constructs one data channel for each pair of
communicating operator instances (instance_1,instance_2),
where instance_1 is an instance of the given operator and instance_2
is an instance of a direct downstream operator.
The number of total channels generated depends on the partitioning
strategy specified by the user.
"""
channels = {} # destination operator id -> channels
strategies = operator.partitioning_strategies
for dst_operator, p_scheme in strategies.items():
num_dest_instances = self.env.operators[dst_operator].num_instances
entry = channels.setdefault(dst_operator, [])
if p_scheme.strategy == PStrategy.Forward:
for i in range(operator.num_instances):
# ID of destination instance to connect
id = i % num_dest_instances
qid = self._gen_str_qid(operator.id, i, dst_operator, id)
c = DataChannel(operator.id, i, dst_operator, id, qid)
entry.append(c)
elif p_scheme.strategy in all_to_all_strategies:
for i in range(operator.num_instances):
for j in range(num_dest_instances):
qid = self._gen_str_qid(operator.id, i, dst_operator,
j)
c = DataChannel(operator.id, i, dst_operator, j, qid)
entry.append(c)
else:
# TODO (john): Add support for other partitioning strategies
sys.exit("Unrecognized or unsupported partitioning strategy.")
return channels
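To make the channel counts concrete, a hedged arithmetic sketch (instance counts are made up): with a Forward strategy each of the N source instances gets exactly one channel, while an all-to-all strategy such as Shuffle creates one channel per (source instance, destination instance) pair.

```python
# Hypothetical channel counts produced by _generate_channels.
src_instances, dst_instances = 2, 3
forward_count = src_instances                     # Forward: 2 channels
all_to_all_count = src_instances * dst_instances  # Shuffle & co.: 6 channels
```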
def _gen_str_qid(self, src_operator_id, src_instance_index,
dst_operator_id, dst_instance_index):
from_task_id = self.env.execution_graph.get_task_id(
src_operator_id, src_instance_index)
to_task_id = self.env.execution_graph.get_task_id(
dst_operator_id, dst_instance_index)
return transfer.ChannelID.gen_id(from_task_id, to_task_id,
self.build_time)
def _gen_task_id(self):
task_id = self.task_id_counter
self.task_id_counter += 1
return task_id
def get_task_id(self, op_id, op_instance_id):
return self.task_ids[(op_id, op_instance_id)]
def get_actor(self, op_id, op_instance_id):
return self.actors_map[(op_id, op_instance_id)]
# Prints the physical dataflow graph
def print_physical_graph(self):
logger.info("===================================")
logger.info("======Physical Dataflow Graph======")
logger.info("===================================")
# Print all data channels between operator instances
log = "(Source Operator ID,Source Operator Name,Source Instance ID)"
log += " --> "
log += "(Destination Operator ID,Destination Operator Name,"
log += "Destination Instance ID)"
logger.info(log)
for src_actor_id, dst_actor_id in self.physical_topo.edges:
src_operator_id, src_instance_index = src_actor_id
dst_operator_id, dst_instance_index = dst_actor_id
logger.info("({},{},{}) --> ({},{},{})".format(
src_operator_id, self.env.operators[src_operator_id].name,
src_instance_index, dst_operator_id,
self.env.operators[dst_operator_id].name, dst_instance_index))
def build_graph(self):
self.build_channels()
# to support cyclic reference serialization
try:
ray.register_custom_serializer(Environment, use_pickle=True)
ray.register_custom_serializer(ExecutionGraph, use_pickle=True)
ray.register_custom_serializer(OpType, use_pickle=True)
ray.register_custom_serializer(PStrategy, use_pickle=True)
except Exception:
# local mode can't use pickle
pass
# Each operator instance is implemented as a Ray actor
# Actors are deployed in topological order, as we traverse the
# logical dataflow from sources to sinks.
for node in nx.topological_sort(self.env.logical_topo):
operator = self.env.operators[node]
# Instantiate Ray actors
handles = self.__generate_actors(
operator, self.input_channels.get(node, []),
self.output_channels.get(node, []))
if handles:
self.actor_handles.extend(handles)
def build_channels(self):
self.build_time = int(time.time() * 1000)
# gen auto-incremented unique task id for every operator instance
for node in nx.topological_sort(self.env.logical_topo):
operator = self.env.operators[node]
for i in range(operator.num_instances):
operator_instance_id = (operator.id, i)
self.task_ids[operator_instance_id] = self._gen_task_id()
channels = {}
for node in nx.topological_sort(self.env.logical_topo):
operator = self.env.operators[node]
# Generate downstream data channels
downstream_channels = self._generate_channels(operator)
channels[node] = downstream_channels
# op_id -> channels
input_channels = {}
output_channels = {}
for op_id, all_downstream_channels in channels.items():
for dst_op_channels in all_downstream_channels.values():
for c in dst_op_channels:
dst = input_channels.setdefault(c.dst_operator_id, [])
dst.append(c)
src = output_channels.setdefault(c.src_operator_id, [])
src.append(c)
self.input_channels = input_channels
self.output_channels = output_channels
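# Rough sketch of the resulting maps (illustrative only): for a channel c that
# flows from operator A to operator B, build_channels() records c both in
# output_channels[A] and in input_channels[B]; __generate_actors() later picks
# the per-instance subsets by src/dst instance index.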
# The execution environment for a streaming job
class Environment(object):
"""A streaming environment.
@ -81,176 +275,18 @@ class Environment(object):
the streaming dataflow.
"""
def __init__(self, config=Config()):
def __init__(self, config=Conf()):
self.logical_topo = nx.DiGraph() # DAG
self.physical_topo = nx.DiGraph() # DAG
self.operators = {} # operator id --> operator object
self.config = config # Environment's configuration
self.topo_cleaned = False
# Handles to all actors in the physical dataflow
self.actor_handles = []
self.operator_id_counter = 0
self.execution_graph = None # set when executed
# Constructs and deploys a Ray actor of a specific type
# TODO (john): Actor placement information should be specified in
# the environment's configuration
def __generate_actor(self, instance_id, operator, input, output):
"""Generates an actor that will execute a particular instance of
the logical operator
Attributes:
instance_id (UUID): The id of the instance the actor will execute.
operator (Operator): The metadata of the logical operator.
input (DataInput): The input gate that manages input channels of
the instance (see: DataInput in communication.py).
output (DataOutput): The output gate that manages output channels
of the instance (see: DataOutput in communication.py).
"""
actor_id = (operator.id, instance_id)
# Record the physical dataflow graph (for debugging purposes)
self.__add_channel(actor_id, input, output)
# Select actor to construct
if operator.type == OpType.Source:
source = operator_instance.Source.remote(actor_id, operator, input,
output)
source.register_handle.remote(source)
return source.start.remote()
elif operator.type == OpType.Map:
map = operator_instance.Map.remote(actor_id, operator, input,
output)
map.register_handle.remote(map)
return map.start.remote()
elif operator.type == OpType.FlatMap:
flatmap = operator_instance.FlatMap.remote(actor_id, operator,
input, output)
flatmap.register_handle.remote(flatmap)
return flatmap.start.remote()
elif operator.type == OpType.Filter:
filter = operator_instance.Filter.remote(actor_id, operator, input,
output)
filter.register_handle.remote(filter)
return filter.start.remote()
elif operator.type == OpType.Reduce:
reduce = operator_instance.Reduce.remote(actor_id, operator, input,
output)
reduce.register_handle.remote(reduce)
return reduce.start.remote()
elif operator.type == OpType.TimeWindow:
pass
elif operator.type == OpType.KeyBy:
keyby = operator_instance.KeyBy.remote(actor_id, operator, input,
output)
keyby.register_handle.remote(keyby)
return keyby.start.remote()
elif operator.type == OpType.Sum:
sum = operator_instance.Reduce.remote(actor_id, operator, input,
output)
# Register target handle at state actor
state_actor = operator.state_actor
if state_actor is not None:
state_actor.register_target.remote(sum)
# Register own handle
sum.register_handle.remote(sum)
return sum.start.remote()
elif operator.type == OpType.Sink:
pass
elif operator.type == OpType.Inspect:
inspect = operator_instance.Inspect.remote(actor_id, operator,
input, output)
inspect.register_handle.remote(inspect)
return inspect.start.remote()
elif operator.type == OpType.ReadTextFile:
# TODO (john): Colocate the source with the input file
read = operator_instance.ReadTextFile.remote(
actor_id, operator, input, output)
read.register_handle.remote(read)
return read.start.remote()
else: # TODO (john): Add support for other types of operators
sys.exit("Unrecognized or unsupported {} operator type.".format(
operator.type))
# Constructs and deploys a Ray actor for each instance of
# the given operator
def __generate_actors(self, operator, upstream_channels,
downstream_channels):
"""Generates one actor for each instance of the given logical
operator.
Attributes:
operator (Operator): The logical operator metadata.
upstream_channels (list): A list of all upstream channels for
all instances of the operator.
downstream_channels (list): A list of all downstream channels
for all instances of the operator.
"""
num_instances = operator.num_instances
logger.info("Generating {} actors of type {}...".format(
num_instances, operator.type))
in_channels = upstream_channels.pop(
operator.id) if upstream_channels else []
handles = []
for i in range(num_instances):
# Collect input and output channels for the particular instance
ip = [
channel for channel in in_channels
if channel.dst_instance_id == i
] if in_channels else []
op = [
channel for channels_list in downstream_channels.values()
for channel in channels_list if channel.src_instance_id == i
]
log = "Constructed {} input and {} output channels "
log += "for the {}-th instance of the {} operator."
logger.debug(log.format(len(ip), len(op), i, operator.type))
input_gate = DataInput(ip)
output_gate = DataOutput(op, operator.partitioning_strategies)
handle = self.__generate_actor(i, operator, input_gate,
output_gate)
if handle:
handles.append(handle)
return handles
# Adds a channel/edge to the physical dataflow graph
def __add_channel(self, actor_id, input, output):
for dest_actor_id in output._destination_actor_ids():
self.physical_topo.add_edge(actor_id, dest_actor_id)
# Generates all required data channels between an operator
# and its downstream operators
def _generate_channels(self, operator):
"""Generates all output data channels
(see: DataChannel in communication.py) for all instances of
the given logical operator.
The function constructs one data channel for each pair of
communicating operator instances (instance_1,instance_2),
where instance_1 is an instance of the given operator and instance_2
is an instance of a direct downstream operator.
The number of total channels generated depends on the partitioning
strategy specified by the user.
"""
channels = {} # destination operator id -> channels
strategies = operator.partitioning_strategies
for dst_operator, p_scheme in strategies.items():
num_dest_instances = self.operators[dst_operator].num_instances
entry = channels.setdefault(dst_operator, [])
if p_scheme.strategy == PStrategy.Forward:
for i in range(operator.num_instances):
# ID of destination instance to connect
id = i % num_dest_instances
channel = DataChannel(self, operator.id, dst_operator, i,
id)
entry.append(channel)
elif p_scheme.strategy in all_to_all_strategies:
for i in range(operator.num_instances):
for j in range(num_dest_instances):
channel = DataChannel(self, operator.id, dst_operator,
i, j)
entry.append(channel)
else:
# TODO (john): Add support for other partitioning strategies
sys.exit("Unrecognized or unsupported partitioning strategy.")
return channels
def gen_operator_id(self):
op_id = self.operator_id_counter
self.operator_id_counter += 1
return op_id
# An edge denotes a flow of data between logical operators
# and may correspond to multiple data channels in the physical dataflow
@ -275,19 +311,15 @@ class Environment(object):
def set_parallelism(self, parallelism):
self.config.parallelism = parallelism
# Sets batched queue configuration for the environment
def set_queue_config(self, queue_config):
self.config.queue_config = queue_config
# Creates and registers a user-defined data source
# TODO (john): There should be different types of sources, e.g. sources
# reading from Kafka, text files, etc.
# TODO (john): Handle case where environment parallelism is set
def source(self, source):
source_id = _generate_uuid()
source_id = self.gen_operator_id()
source_stream = DataStream(self, source_id)
self.operators[source_id] = Operator(
source_id, OpType.Source, "Source", other=source)
source_id, OpType.Source, processor.Source, "Source", logic=source)
return source_stream
# Creates and registers a new data source that reads a
@ -296,10 +328,14 @@ class Environment(object):
# e.g. sources reading from Kafka, text files, etc.
# TODO (john): Handle case where environment parallelism is set
def read_text_file(self, filepath):
source_id = _generate_uuid()
source_id = self.gen_operator_id()
source_stream = DataStream(self, source_id)
self.operators[source_id] = Operator(
source_id, OpType.ReadTextFile, "Read Text File", other=filepath)
source_id,
OpType.ReadTextFile,
processor.ReadTextFile,
"Read Text File",
other=filepath)
return source_stream
# Constructs and deploys the physical dataflow
@ -312,24 +348,27 @@ class Environment(object):
# upstream instances, some of the downstream instances will not be
# used at all
# Each operator instance is implemented as a Ray actor
# Actors are deployed in topological order, as we traverse the
# logical dataflow from sources to sinks. At each step, data
# producers wait for acknowledge from consumers before starting
# generating data.
upstream_channels = {}
for node in nx.topological_sort(self.logical_topo):
operator = self.operators[node]
# Generate downstream data channels
downstream_channels = self._generate_channels(operator)
# Instantiate Ray actors
handles = self.__generate_actors(operator, upstream_channels,
downstream_channels)
if handles:
self.actor_handles.extend(handles)
upstream_channels.update(downstream_channels)
logger.debug("Running...")
return self.actor_handles
self.execution_graph = ExecutionGraph(self)
self.execution_graph.build_graph()
logger.info("init...")
# init
init_waits = []
for actor_handle in self.execution_graph.actor_handles:
init_waits.append(actor_handle.init.remote(pickle.dumps(self)))
for wait in init_waits:
assert ray.get(wait) is True
logger.info("running...")
# start
exec_handles = []
for actor_handle in self.execution_graph.actor_handles:
exec_handles.append(actor_handle.start.remote())
return exec_handles
def wait_finish(self):
for actor_handle in self.execution_graph.actor_handles:
while not ray.get(actor_handle.is_finished.remote()):
time.sleep(1)
# Prints the logical dataflow graph
def print_logical_graph(self):
@ -349,25 +388,6 @@ class Environment(object):
for downstream_node in downstream_neighbors:
self.operators[downstream_node].print()
# Prints the physical dataflow graph
def print_physical_graph(self):
logger.info("===================================")
logger.info("======Physical Dataflow Graph======")
logger.info("===================================")
# Print all data channels between operator instances
log = "(Source Operator ID,Source Operator Name,Source Instance ID)"
log += " --> "
log += "(Destination Operator ID,Destination Operator Name,"
log += "Destination Instance ID)"
logger.info(log)
for src_actor_id, dst_actor_id in self.physical_topo.edges:
src_operator_id, src_instance_id = src_actor_id
dst_operator_id, dst_instance_id = dst_actor_id
logger.info("({},{},{}) --> ({},{},{})".format(
src_operator_id, self.operators[src_operator_id].name,
src_instance_id, dst_operator_id,
self.operators[dst_operator_id].name, dst_instance_id))
# TODO (john): We also need KeyedDataStream and WindowedDataStream as
# subclasses of DataStream to prevent ill-defined logical dataflows
@ -389,14 +409,16 @@ class DataStream(object):
is_partitioned (bool): Denotes if there is a partitioning strategy
(e.g. shuffle) for the stream or not (default strategy: Forward).
"""
stream_id_counter = 0
def __init__(self,
environment,
source_id=None,
dest_id=None,
is_partitioned=False):
self.id = _generate_uuid()
self.env = environment
self.id = DataStream.stream_id_counter
DataStream.stream_id_counter += 1
self.src_operator_id = source_id
self.dst_operator_id = dest_id
# True if a partitioning strategy for this stream exists,
@ -448,17 +470,17 @@ class DataStream(object):
src_operator = self.env.operators[self.src_operator_id]
if self.is_partitioned is True:
partitioning, _ = src_operator._get_partition_strategy(self.id)
src_operator._set_partition_strategy(_generate_uuid(),
partitioning, operator.id)
src_operator._set_partition_strategy(self.id, partitioning,
operator.id)
elif src_operator.type == OpType.KeyBy:
# Set the output partitioning strategy to shuffle by key
partitioning = PScheme(PStrategy.ShuffleByKey)
src_operator._set_partition_strategy(_generate_uuid(),
partitioning, operator.id)
src_operator._set_partition_strategy(self.id, partitioning,
operator.id)
else: # No partitioning strategy has been defined - set default
partitioning = PScheme(PStrategy.Forward)
src_operator._set_partition_strategy(_generate_uuid(),
partitioning, operator.id)
src_operator._set_partition_strategy(self.id, partitioning,
operator.id)
return self.__expand()
# Sets the level of parallelism for an operator, i.e. its total
@ -525,8 +547,9 @@ class DataStream(object):
map_fn (function): The user-defined logic of the map.
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.Map,
processor.Map,
name,
map_fn,
num_instances=self.env.config.parallelism)
@ -541,8 +564,9 @@ class DataStream(object):
(e.g. split()).
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.FlatMap,
processor.FlatMap,
"FlatMap",
flatmap_fn,
num_instances=self.env.config.parallelism)
@ -558,8 +582,9 @@ class DataStream(object):
(assuming tuple records).
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.KeyBy,
processor.KeyBy,
"KeyBy",
other=key_selector,
num_instances=self.env.config.parallelism)
@ -574,8 +599,9 @@ class DataStream(object):
(assuming tuple records).
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.Reduce,
processor.Reduce,
"Sum",
reduce_fn,
num_instances=self.env.config.parallelism)
@ -590,8 +616,9 @@ class DataStream(object):
(assuming tuple records).
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.Sum,
processor.Reduce,
"Sum",
_sum,
other=attribute_selector,
@ -608,13 +635,7 @@ class DataStream(object):
Attributes:
window_width_ms (int): The length of the window in ms.
"""
op = Operator(
_generate_uuid(),
OpType.TimeWindow,
"TimeWindow",
num_instances=self.env.config.parallelism,
other=window_width_ms)
return self.__register(op)
raise Exception("time_window is unsupported")
# Registers filter operator to the environment
def filter(self, filter_fn):
@ -624,8 +645,9 @@ class DataStream(object):
filter_fn (function): The user-defined filter function.
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.Filter,
processor.Filter,
"Filter",
filter_fn,
num_instances=self.env.config.parallelism)
@ -634,8 +656,9 @@ class DataStream(object):
# TODO (john): Registers window join operator to the environment
def window_join(self, other_stream, join_attribute, window_width):
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.WindowJoin,
processor.WindowJoin,
"WindowJoin",
num_instances=self.env.config.parallelism)
return self.__register(op)
@ -648,8 +671,9 @@ class DataStream(object):
inspect_logic (function): The user-defined inspect function.
"""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.Inspect,
processor.Inspect,
"Inspect",
inspect_logic,
num_instances=self.env.config.parallelism)
@ -661,8 +685,9 @@ class DataStream(object):
def sink(self):
"""Closes the stream with a sink operator."""
op = Operator(
_generate_uuid(),
self.env.gen_operator_id(),
OpType.Sink,
processor.Sink,
"Sink",
num_instances=self.env.config.parallelism)
return self.__register(op)
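# A minimal end-to-end sketch of the DataStream API above (illustrative only;
# `splitter` is a hypothetical user-defined function and a local Ray cluster
# is assumed):
#
#   env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
#   env.read_text_file("input.txt") \
#       .flat_map(splitter) \
#       .key_by(0) \
#       .sum(1) \
#       .sink()
#   handles = env.execute()
#   ray.get(handles)
#   env.wait_finish()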

View file

View file

@ -0,0 +1,127 @@
import pickle
import threading
import time
import ray
import ray.streaming._streaming as _streaming
import ray.streaming.runtime.transfer as transfer
from ray.function_manager import FunctionDescriptor
from ray.streaming.config import Config
@ray.remote
class Worker:
def __init__(self):
core_worker = ray.worker.global_worker.core_worker
writer_async_func = FunctionDescriptor(
__name__, self.on_writer_message.__name__, self.__class__.__name__)
writer_sync_func = FunctionDescriptor(
__name__, self.on_writer_message_sync.__name__,
self.__class__.__name__)
self.writer_client = _streaming.WriterClient(
core_worker, writer_async_func, writer_sync_func)
reader_async_func = FunctionDescriptor(
__name__, self.on_reader_message.__name__, self.__class__.__name__)
reader_sync_func = FunctionDescriptor(
__name__, self.on_reader_message_sync.__name__,
self.__class__.__name__)
self.reader_client = _streaming.ReaderClient(
core_worker, reader_async_func, reader_sync_func)
self.writer = None
self.output_channel_id = None
self.reader = None
def init_writer(self, output_channel, reader_actor):
conf = {
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
.current_driver_id,
Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL
}
self.writer = transfer.DataWriter([output_channel],
[pickle.loads(reader_actor)], conf)
self.output_channel_id = transfer.ChannelID(output_channel)
def init_reader(self, input_channel, writer_actor):
conf = {
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
.current_driver_id,
Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL
}
self.reader = transfer.DataReader([input_channel],
[pickle.loads(writer_actor)], conf)
def start_write(self, msg_nums):
self.t = threading.Thread(
target=self.run_writer, args=[msg_nums], daemon=True)
self.t.start()
def run_writer(self, msg_nums):
for i in range(msg_nums):
self.writer.write(self.output_channel_id, pickle.dumps(i))
print("WriterWorker done.")
def start_read(self, msg_nums):
self.t = threading.Thread(
target=self.run_reader, args=[msg_nums], daemon=True)
self.t.start()
def run_reader(self, msg_nums):
count = 0
msg = None
while count != msg_nums:
item = self.reader.read(100)
if item is None:
time.sleep(0.01)
else:
msg = pickle.loads(item.body())
count += 1
assert msg == msg_nums - 1
print("ReaderWorker done.")
def is_finished(self):
return not self.t.is_alive()
def on_reader_message(self, buffer: bytes):
"""used in direct call mode"""
self.reader_client.on_reader_message(buffer)
def on_reader_message_sync(self, buffer: bytes):
"""used in direct call mode"""
if self.reader_client is None:
return b" " * 4 # special flag to indicate this actor not ready
result = self.reader_client.on_reader_message_sync(buffer)
return result.to_pybytes()
def on_writer_message(self, buffer: bytes):
"""used in direct call mode"""
self.writer_client.on_writer_message(buffer)
def on_writer_message_sync(self, buffer: bytes):
"""used in direct call mode"""
if self.writer_client is None:
return b" " * 4 # special flag to indicate this actor not ready
result = self.writer_client.on_writer_message_sync(buffer)
return result.to_pybytes()
def test_queue():
ray.init()
writer = Worker._remote(is_direct_call=True)
reader = Worker._remote(is_direct_call=True)
channel_id_str = transfer.ChannelID.gen_random_id()
inits = [
writer.init_writer.remote(channel_id_str, pickle.dumps(reader)),
reader.init_reader.remote(channel_id_str, pickle.dumps(writer))
]
ray.get(inits)
msg_nums = 1000
print("start read/write")
reader.start_read.remote(msg_nums)
writer.start_write.remote(msg_nums)
while not ray.get(reader.is_finished.remote()):
time.sleep(0.1)
ray.shutdown()
if __name__ == "__main__":
test_queue()

View file

@ -2,8 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.experimental.streaming.streaming import Environment
from ray.experimental.streaming.operator import OpType, PStrategy
from ray.streaming.streaming import Environment, ExecutionGraph
from ray.streaming.operator import OpType, PStrategy
def test_parallelism():
@ -169,18 +169,20 @@ def _test_channels(environment, expected_channels):
if operator.type == OpType.Map:
map_id = id
# Collect channels
environment.execution_graph = ExecutionGraph(environment)
environment.execution_graph.build_channels()
channels_per_destination = []
for operator in environment.operators.values():
channels_per_destination.append(
environment._generate_channels(operator))
environment.execution_graph._generate_channels(operator))
# Check actual connectivity
actual = []
for destination in channels_per_destination:
for channels in destination.values():
for channel in channels:
src_instance_id = channel.src_instance_id
dst_instance_id = channel.dst_instance_id
connection = (src_instance_id, dst_instance_id)
src_instance_index = channel.src_instance_index
dst_instance_index = channel.dst_instance_index
connection = (src_instance_index, dst_instance_index)
assert channel.dst_operator_id == map_id, (
channel.dst_operator_id, map_id)
actual.append(connection)
@ -205,6 +207,4 @@ def test_wordcount():
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
test_channel_generation()

View file

@ -0,0 +1,20 @@
import ray
from ray.streaming.config import Config
from ray.streaming.streaming import Environment, Conf
def test_word_count():
ray.init()
env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
env.read_text_file(__file__) \
.set_parallelism(1) \
.filter(lambda x: "word" in x) \
.inspect(lambda x: print("result", x))
env_handle = env.execute()
ray.get(env_handle) # Stay alive until execution finishes
env.wait_finish()
ray.shutdown()
if __name__ == "__main__":
test_word_count()

274
streaming/src/channel.cc Normal file
View file

@ -0,0 +1,274 @@
#include "channel.h"
#include <unordered_map>
namespace ray {
namespace streaming {
ProducerChannel::ProducerChannel(std::shared_ptr<Config> &transfer_config,
ProducerChannelInfo &p_channel_info)
: transfer_config_(transfer_config), channel_info(p_channel_info) {}
ConsumerChannel::ConsumerChannel(std::shared_ptr<Config> &transfer_config,
ConsumerChannelInfo &c_channel_info)
: transfer_config_(transfer_config), channel_info(c_channel_info) {}
StreamingQueueProducer::StreamingQueueProducer(std::shared_ptr<Config> &transfer_config,
ProducerChannelInfo &p_channel_info)
: ProducerChannel(transfer_config, p_channel_info) {
STREAMING_LOG(INFO) << "Producer Init";
}
StreamingQueueProducer::~StreamingQueueProducer() {
STREAMING_LOG(INFO) << "Producer Destory";
}
StreamingStatus StreamingQueueProducer::CreateTransferChannel() {
CreateQueue();
uint64_t queue_last_seq_id = 0;
uint64_t last_message_id_in_queue = 0;
if (!last_message_id_in_queue) {
if (last_message_id_in_queue < channel_info.current_message_id) {
STREAMING_LOG(WARNING) << "last message id in queue : " << last_message_id_in_queue
<< " is less than message checkpoint loaded id : "
<< channel_info.current_message_id
<< ", an old queue object " << channel_info.channel_id
<< " was fond in store";
}
last_message_id_in_queue = channel_info.current_message_id;
}
if (queue_last_seq_id == static_cast<uint64_t>(-1)) {
queue_last_seq_id = 0;
}
channel_info.current_seq_id = queue_last_seq_id;
STREAMING_LOG(WARNING) << "existing last message id => " << last_message_id_in_queue
<< ", message id in channel => "
<< channel_info.current_message_id << ", queue last seq id => "
<< queue_last_seq_id;
channel_info.message_last_commit_id = last_message_id_in_queue;
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueProducer::CreateQueue() {
STREAMING_LOG(INFO) << "CreateQueue qid: " << channel_info.channel_id
<< " data_size: " << channel_info.queue_size;
auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService();
if (upstream_handler->UpstreamQueueExists(channel_info.channel_id)) {
RAY_LOG(INFO) << "StreamingQueueWriter::CreateQueue duplicate!!!";
return StreamingStatus::OK;
}
upstream_handler->SetPeerActorID(channel_info.channel_id, channel_info.actor_id);
queue_ = upstream_handler->CreateUpstreamQueue(
channel_info.channel_id, channel_info.actor_id, channel_info.queue_size);
STREAMING_CHECK(queue_ != nullptr);
std::vector<ObjectID> queue_ids, failed_queues;
queue_ids.push_back(channel_info.channel_id);
upstream_handler->WaitQueues(queue_ids, 10 * 1000, failed_queues);
STREAMING_LOG(INFO) << "q id => " << channel_info.channel_id << ", queue size => "
<< channel_info.queue_size;
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueProducer::DestroyTransferChannel() {
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueProducer::ClearTransferCheckpoint(
uint64_t checkpoint_id, uint64_t checkpoint_offset) {
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t channel_offset) {
queue_->SetQueueEvictionLimit(channel_offset);
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data,
uint32_t data_size) {
Status status =
PushQueueItem(channel_info.current_seq_id + 1, data, data_size, current_time_ms());
if (status.code() != StatusCode::OK) {
STREAMING_LOG(DEBUG) << channel_info.channel_id << " => Queue is full"
<< " meesage => " << status.message();
// Assume that only status OutOfMemory and OK are acceptable.
// OutOfMemory means queue is full at that moment.
STREAMING_CHECK(status.code() == StatusCode::OutOfMemory)
<< "status => " << status.message()
<< ", perhaps data block is so large that it can't be stored in"
<< ", data block size => " << data_size;
return StreamingStatus::FullChannel;
}
return StreamingStatus::OK;
}
Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data,
uint32_t data_size, uint64_t timestamp) {
STREAMING_LOG(INFO) << "StreamingQueueProducer::PushQueueItem:"
<< " qid: " << channel_info.channel_id << " seq_id: " << seq_id
<< " data_size: " << data_size;
Status status = queue_->Push(seq_id, data, data_size, timestamp, false);
if (status.IsOutOfMemory()) {
status = queue_->TryEvictItems();
if (!status.ok()) {
STREAMING_LOG(INFO) << "Evict fail.";
return status;
}
status = queue_->Push(seq_id, data, data_size, timestamp, false);
}
queue_->Send();
return status;
}
StreamingQueueConsumer::StreamingQueueConsumer(std::shared_ptr<Config> &transfer_config,
ConsumerChannelInfo &c_channel_info)
: ConsumerChannel(transfer_config, c_channel_info) {
STREAMING_LOG(INFO) << "Consumer Init";
}
StreamingQueueConsumer::~StreamingQueueConsumer() {
STREAMING_LOG(INFO) << "Consumer Destroy";
}
StreamingStatus StreamingQueueConsumer::CreateTransferChannel() {
auto downstream_handler = ray::streaming::DownstreamQueueMessageHandler::GetService();
STREAMING_LOG(INFO) << "GetQueue qid: " << channel_info.channel_id
<< " start_seq_id: " << channel_info.current_seq_id + 1;
if (downstream_handler->DownstreamQueueExists(channel_info.channel_id)) {
RAY_LOG(INFO) << "StreamingQueueReader::GetQueue duplicate!!!";
return StreamingStatus::OK;
}
downstream_handler->SetPeerActorID(channel_info.channel_id, channel_info.actor_id);
STREAMING_LOG(INFO) << "Create ReaderQueue " << channel_info.channel_id
<< " pull from start_seq_id: " << channel_info.current_seq_id + 1;
queue_ = downstream_handler->CreateDownstreamQueue(channel_info.channel_id,
channel_info.actor_id);
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueConsumer::DestroyTransferChannel() {
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueConsumer::ClearTransferCheckpoint(
uint64_t checkpoint_id, uint64_t checkpoint_offset) {
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint64_t &offset_id,
uint8_t *&data,
uint32_t &data_size,
uint32_t timeout) {
STREAMING_LOG(INFO) << "GetQueueItem qid: " << channel_info.channel_id;
STREAMING_CHECK(queue_ != nullptr);
QueueItem item = queue_->PopPendingBlockTimeout(timeout * 1000);
if (item.SeqId() == QUEUE_INVALID_SEQ_ID) {
STREAMING_LOG(INFO) << "GetQueueItem timeout.";
data = nullptr;
data_size = 0;
offset_id = QUEUE_INVALID_SEQ_ID;
return StreamingStatus::OK;
}
data = item.Buffer()->Data();
offset_id = item.SeqId();
data_size = item.Buffer()->Size();
STREAMING_LOG(DEBUG) << "GetQueueItem qid: " << channel_info.channel_id
<< " seq_id: " << offset_id << " msg_id: " << item.MaxMsgId()
<< " data_size: " << data_size;
return StreamingStatus::OK;
}
StreamingStatus StreamingQueueConsumer::NotifyChannelConsumed(uint64_t offset_id) {
STREAMING_CHECK(queue_ != nullptr);
queue_->OnConsumed(offset_id);
return StreamingStatus::OK;
}
// For mock queue transfer
struct MockQueueItem {
uint64_t seq_id;
uint32_t data_size;
std::shared_ptr<uint8_t> data;
};
struct MockQueue {
std::unordered_map<ObjectID, std::shared_ptr<AbstractRingBufferImpl<MockQueueItem>>>
message_buffer_;
std::unordered_map<ObjectID, std::shared_ptr<AbstractRingBufferImpl<MockQueueItem>>>
consumed_buffer_;
};
static MockQueue mock_queue;
StreamingStatus MockProducer::CreateTransferChannel() {
mock_queue.message_buffer_[channel_info.channel_id] =
std::make_shared<RingBufferImplThreadSafe<MockQueueItem>>(500);
mock_queue.consumed_buffer_[channel_info.channel_id] =
std::make_shared<RingBufferImplThreadSafe<MockQueueItem>>(500);
return StreamingStatus::OK;
}
StreamingStatus MockProducer::DestroyTransferChannel() {
mock_queue.message_buffer_.erase(channel_info.channel_id);
mock_queue.consumed_buffer_.erase(channel_info.channel_id);
return StreamingStatus::OK;
}
StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) {
auto &ring_buffer = mock_queue.message_buffer_[channel_info.channel_id];
if (ring_buffer->Full()) {
return StreamingStatus::OutOfMemory;
}
MockQueueItem item;
item.seq_id = channel_info.current_seq_id + 1;
item.data.reset(new uint8_t[data_size]);
item.data_size = data_size;
std::memcpy(item.data.get(), data, data_size);
ring_buffer->Push(item);
return StreamingStatus::OK;
}
StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
uint32_t &data_size,
uint32_t timeout) {
auto &channel_id = channel_info.channel_id;
if (mock_queue.message_buffer_.find(channel_id) == mock_queue.message_buffer_.end()) {
return StreamingStatus::NoSuchItem;
}
if (mock_queue.message_buffer_[channel_id]->Empty()) {
return StreamingStatus::NoSuchItem;
}
MockQueueItem item = mock_queue.message_buffer_[channel_id]->Front();
mock_queue.message_buffer_[channel_id]->Pop();
mock_queue.consumed_buffer_[channel_id]->Push(item);
offset_id = item.seq_id;
data = item.data.get();
data_size = item.data_size;
return StreamingStatus::OK;
}
StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) {
auto &channel_id = channel_info.channel_id;
auto &ring_buffer = mock_queue.consumed_buffer_[channel_id];
while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) {
ring_buffer->Pop();
}
return StreamingStatus::OK;
}
} // namespace streaming
} // namespace ray

176
streaming/src/channel.h Normal file
View file

@ -0,0 +1,176 @@
#ifndef RAY_CHANNEL_H
#define RAY_CHANNEL_H
#include "config/streaming_config.h"
#include "queue/queue_handler.h"
#include "ring_buffer.h"
#include "status.h"
#include "util/streaming_util.h"
namespace ray {
namespace streaming {
struct StreamingQueueInfo {
uint64_t first_seq_id = 0;
uint64_t last_seq_id = 0;
uint64_t target_seq_id = 0;
uint64_t consumed_seq_id = 0;
};
/// ProducerChannelInfo and ConsumerChannelInfo contain channel information and
/// metrics that help us debug and log important messages.
struct ProducerChannelInfo {
ObjectID channel_id;
StreamingRingBufferPtr writer_ring_buffer;
uint64_t current_message_id;
uint64_t current_seq_id;
uint64_t message_last_commit_id;
StreamingQueueInfo queue_info;
uint32_t queue_size;
int64_t message_pass_by_ts;
ActorID actor_id;
};
struct ConsumerChannelInfo {
ObjectID channel_id;
uint64_t current_message_id;
uint64_t current_seq_id;
uint64_t barrier_id;
uint64_t partial_barrier_id;
StreamingQueueInfo queue_info;
uint64_t last_queue_item_delay;
uint64_t last_queue_item_latency;
uint64_t last_queue_target_diff;
uint64_t get_queue_item_times;
ActorID actor_id;
};
/// Two types of channel are presented:
/// * ProducerChannel supports all write operations for the upper level.
/// * ConsumerChannel is for all reader operations.
/// They share similar interfaces:
/// * ClearTransferCheckpoint (it's empty and unsupported now; an
/// implementation will be added in a next PR)
/// * NotifyChannelConsumed (notifies the owner of the channel which range of
/// data can be released to avoid running out of memory)
/// but differ in their read/write functions (named ProduceItemToChannel and
/// ConsumeItemFromChannel).
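/// Illustrative producer-side call sequence (a sketch, not a prescribed API flow):
///   channel->CreateTransferChannel();           // create/connect the underlying queue
///   channel->ProduceItemToChannel(data, size);  // returns FullChannel when the queue is full
///   channel->NotifyChannelConsumed(offset);     // allow eviction of data up to `offset`
///   channel->DestroyTransferChannel();          // tear down on shutdown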
class ProducerChannel {
public:
explicit ProducerChannel(std::shared_ptr<Config> &transfer_config,
ProducerChannelInfo &p_channel_info);
virtual ~ProducerChannel() = default;
virtual StreamingStatus CreateTransferChannel() = 0;
virtual StreamingStatus DestroyTransferChannel() = 0;
virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
uint64_t checkpoint_offset) = 0;
virtual StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) = 0;
virtual StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) = 0;
protected:
std::shared_ptr<Config> transfer_config_;
ProducerChannelInfo &channel_info;
};
class ConsumerChannel {
public:
explicit ConsumerChannel(std::shared_ptr<Config> &transfer_config,
ConsumerChannelInfo &c_channel_info);
virtual ~ConsumerChannel() = default;
virtual StreamingStatus CreateTransferChannel() = 0;
virtual StreamingStatus DestroyTransferChannel() = 0;
virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
uint64_t checkpoint_offset) = 0;
virtual StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
uint32_t &data_size,
uint32_t timeout) = 0;
virtual StreamingStatus NotifyChannelConsumed(uint64_t offset_id) = 0;
protected:
std::shared_ptr<Config> transfer_config_;
ConsumerChannelInfo &channel_info;
};
class StreamingQueueProducer : public ProducerChannel {
public:
explicit StreamingQueueProducer(std::shared_ptr<Config> &transfer_config,
ProducerChannelInfo &p_channel_info);
~StreamingQueueProducer() override;
StreamingStatus CreateTransferChannel() override;
StreamingStatus DestroyTransferChannel() override;
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
uint64_t checkpoint_offset) override;
StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override;
StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override;
private:
StreamingStatus CreateQueue();
Status PushQueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size,
uint64_t timestamp);
private:
std::shared_ptr<WriterQueue> queue_;
};
class StreamingQueueConsumer : public ConsumerChannel {
public:
explicit StreamingQueueConsumer(std::shared_ptr<Config> &transfer_config,
ConsumerChannelInfo &c_channel_info);
~StreamingQueueConsumer() override;
StreamingStatus CreateTransferChannel() override;
StreamingStatus DestroyTransferChannel() override;
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
uint64_t checkpoint_offset) override;
StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
uint32_t &data_size, uint32_t timeout) override;
StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override;
private:
std::shared_ptr<ReaderQueue> queue_;
};
/// MockProducer and MockConsumer are independent channel implementations that
/// provide a very simple in-memory channel for unit tests or integration tests.
class MockProducer : public ProducerChannel {
public:
explicit MockProducer(std::shared_ptr<Config> &transfer_config,
ProducerChannelInfo &channel_info)
: ProducerChannel(transfer_config, channel_info){};
StreamingStatus CreateTransferChannel() override;
StreamingStatus DestroyTransferChannel() override;
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
uint64_t checkpoint_offset) override {
return StreamingStatus::OK;
}
StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override;
StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) override {
return StreamingStatus::OK;
}
};
class MockConsumer : public ConsumerChannel {
public:
explicit MockConsumer(std::shared_ptr<Config> &transfer_config,
ConsumerChannelInfo &c_channel_info)
: ConsumerChannel(transfer_config, c_channel_info){};
StreamingStatus CreateTransferChannel() override { return StreamingStatus::OK; }
StreamingStatus DestroyTransferChannel() override { return StreamingStatus::OK; }
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
uint64_t checkpoint_offset) override {
return StreamingStatus::OK;
}
StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
uint32_t &data_size, uint32_t timeout) override;
StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override;
};
} // namespace streaming
} // namespace ray
#endif // RAY_CHANNEL_H

View file

@ -0,0 +1,89 @@
#include <unistd.h>
#include "streaming_config.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
uint64_t StreamingConfig::TIME_WAIT_UINT = 1;
uint32_t StreamingConfig::DEFAULT_RING_BUFFER_CAPACITY = 500;
uint32_t StreamingConfig::DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL = 20;
// Time to force clean if barrier in queue, default 0ms
const uint32_t StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE = 2048;
void StreamingConfig::FromProto(const uint8_t *data, uint32_t size) {
proto::StreamingConfig config;
STREAMING_CHECK(config.ParseFromArray(data, size)) << "Parse streaming conf failed";
if (!config.job_name().empty()) {
SetJobName(config.job_name());
}
if (!config.task_job_id().empty()) {
STREAMING_CHECK(config.task_job_id().size() == 2 * JobID::Size());
SetTaskJobId(config.task_job_id());
}
if (!config.worker_name().empty()) {
SetWorkerName(config.worker_name());
}
if (!config.op_name().empty()) {
SetOpName(config.op_name());
}
if (config.role() != proto::OperatorType::UNKNOWN) {
SetOperatorType(config.role());
}
if (config.ring_buffer_capacity() != 0) {
SetRingBufferCapacity(config.ring_buffer_capacity());
}
if (config.empty_message_interval() != 0) {
SetEmptyMessageTimeInterval(config.empty_message_interval());
}
}
uint32_t StreamingConfig::GetRingBufferCapacity() const { return ring_buffer_capacity_; }
void StreamingConfig::SetRingBufferCapacity(uint32_t ring_buffer_capacity) {
StreamingConfig::ring_buffer_capacity_ =
std::min(ring_buffer_capacity, StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE);
}
uint32_t StreamingConfig::GetEmptyMessageTimeInterval() const {
return empty_message_time_interval_;
}
void StreamingConfig::SetEmptyMessageTimeInterval(uint32_t empty_message_time_interval) {
StreamingConfig::empty_message_time_interval_ = empty_message_time_interval;
}
streaming::proto::OperatorType StreamingConfig::GetOperatorType() const {
return operator_type_;
}
void StreamingConfig::SetOperatorType(streaming::proto::OperatorType type) {
StreamingConfig::operator_type_ = type;
}
const std::string &StreamingConfig::GetJobName() const { return job_name_; }
void StreamingConfig::SetJobName(const std::string &job_name) {
StreamingConfig::job_name_ = job_name;
}
const std::string &StreamingConfig::GetOpName() const { return op_name_; }
void StreamingConfig::SetOpName(const std::string &op_name) {
StreamingConfig::op_name_ = op_name;
}
const std::string &StreamingConfig::GetWorkerName() const { return worker_name_; }
void StreamingConfig::SetWorkerName(const std::string &worker_name) {
StreamingConfig::worker_name_ = worker_name;
}
const std::string &StreamingConfig::GetTaskJobId() const { return task_job_id_; }
void StreamingConfig::SetTaskJobId(const std::string &task_job_id) {
StreamingConfig::task_job_id_ = task_job_id;
}
} // namespace streaming
} // namespace ray

View file

@ -0,0 +1,69 @@
#ifndef RAY_STREAMING_CONFIG_H
#define RAY_STREAMING_CONFIG_H
#include <cstdint>
#include <string>
#include "protobuf/streaming.pb.h"
#include "ray/common/id.h"
namespace ray {
namespace streaming {
class StreamingConfig {
public:
static uint64_t TIME_WAIT_UINT;
static uint32_t DEFAULT_RING_BUFFER_CAPACITY;
static uint32_t DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL;
static const uint32_t MESSAGE_BUNDLE_MAX_SIZE;
private:
uint32_t ring_buffer_capacity_ = DEFAULT_RING_BUFFER_CAPACITY;
uint32_t empty_message_time_interval_ = DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL;
streaming::proto::OperatorType operator_type_ =
streaming::proto::OperatorType::TRANSFORM;
std::string job_name_ = "DEFAULT_JOB_NAME";
std::string op_name_ = "DEFAULT_OP_NAME";
std::string worker_name_ = "DEFAULT_WORKER_NAME";
std::string task_job_id_ = "ffffffff";
public:
void FromProto(const uint8_t *, uint32_t size);
const std::string &GetTaskJobId() const;
void SetTaskJobId(const std::string &task_job_id);
const std::string &GetWorkerName() const;
void SetWorkerName(const std::string &worker_name);
const std::string &GetOpName() const;
void SetOpName(const std::string &op_name);
uint32_t GetEmptyMessageTimeInterval() const;
void SetEmptyMessageTimeInterval(uint32_t empty_message_time_interval);
uint32_t GetRingBufferCapacity() const;
void SetRingBufferCapacity(uint32_t ring_buffer_capacity);
streaming::proto::OperatorType GetOperatorType() const;
void SetOperatorType(streaming::proto::OperatorType type);
const std::string &GetJobName() const;
void SetJobName(const std::string &job_name);
};
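/// Usage sketch (illustrative only): the configuration typically arrives as a
/// serialized proto::StreamingConfig message, e.g.
///   StreamingConfig conf;
///   conf.FromProto(serialized_bytes, serialized_size);
///   uint32_t capacity = conf.GetRingBufferCapacity();  // values set via
///   // SetRingBufferCapacity are capped by MESSAGE_BUNDLE_MAX_SIZE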
} // namespace streaming
} // namespace ray
#endif // RAY_STREAMING_CONFIG_H

View file

@ -0,0 +1,297 @@
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <thread>
#include "ray/util/logging.h"
#include "ray/util/util.h"
#include "data_reader.h"
#include "message/message_bundle.h"
namespace ray {
namespace streaming {
const uint32_t DataReader::kReadItemTimeout = 1000;
void DataReader::Init(const std::vector<ObjectID> &input_ids,
const std::vector<ActorID> &actor_ids,
const std::vector<uint64_t> &queue_seq_ids,
const std::vector<uint64_t> &streaming_msg_ids,
int64_t timer_interval) {
Init(input_ids, actor_ids, timer_interval);
for (size_t i = 0; i < input_ids.size(); ++i) {
auto &q_id = input_ids[i];
channel_info_map_[q_id].current_seq_id = queue_seq_ids[i];
channel_info_map_[q_id].current_message_id = streaming_msg_ids[i];
}
}
void DataReader::Init(const std::vector<ObjectID> &input_ids,
const std::vector<ActorID> &actor_ids, int64_t timer_interval) {
STREAMING_LOG(INFO) << input_ids.size() << " queue to init.";
transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, input_ids);
last_fetched_queue_item_ = nullptr;
timer_interval_ = timer_interval;
last_message_ts_ = 0;
input_queue_ids_ = input_ids;
last_message_latency_ = 0;
last_bundle_unit_ = 0;
for (size_t i = 0; i < input_ids.size(); ++i) {
ObjectID q_id = input_ids[i];
STREAMING_LOG(INFO) << "[Reader] Init queue id: " << q_id;
auto &channel_info = channel_info_map_[q_id];
channel_info.channel_id = q_id;
channel_info.actor_id = actor_ids[i];
channel_info.last_queue_item_delay = 0;
channel_info.last_queue_item_latency = 0;
channel_info.last_queue_target_diff = 0;
channel_info.get_queue_item_times = 0;
}
/// Make the input id location stable.
sort(input_queue_ids_.begin(), input_queue_ids_.end(),
[](const ObjectID &a, const ObjectID &b) { return a.Hash() < b.Hash(); });
std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(unready_queue_ids_));
InitChannel();
}
StreamingStatus DataReader::InitChannel() {
STREAMING_LOG(INFO) << "[Reader] Getting queues. total queue num "
<< input_queue_ids_.size() << ", unready queue num => "
<< unready_queue_ids_.size();
for (const auto &input_channel : unready_queue_ids_) {
auto &channel_info = channel_info_map_[input_channel];
std::shared_ptr<ConsumerChannel> channel;
if (runtime_context_->IsMockTest()) {
channel = std::make_shared<MockConsumer>(transfer_config_, channel_info);
} else {
channel = std::make_shared<StreamingQueueConsumer>(transfer_config_, channel_info);
}
channel_map_.emplace(input_channel, channel);
StreamingStatus status = channel->CreateTransferChannel();
if (StreamingStatus::OK != status) {
STREAMING_LOG(ERROR) << "Initialize queue failed, id => " << input_channel;
}
}
runtime_context_->SetRuntimeStatus(RuntimeStatus::Running);
STREAMING_LOG(INFO) << "[Reader] Reader construction done!";
return StreamingStatus::OK;
}
StreamingStatus DataReader::InitChannelMerger() {
STREAMING_LOG(INFO) << "[Reader] Initializing queue merger.";
// Init reader merger by given comparator when it's first created.
StreamingReaderMsgPtrComparator comparator;
if (!reader_merger_) {
reader_merger_.reset(
new PriorityQueue<std::shared_ptr<DataBundle>, StreamingReaderMsgPtrComparator>(
comparator));
}
// An old item in the merger vector must be evicted before a new queue item is
// pushed.
if (!unready_queue_ids_.empty() && last_fetched_queue_item_) {
STREAMING_LOG(INFO) << "pop old item from => " << last_fetched_queue_item_->from;
RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_))
last_fetched_queue_item_.reset();
}
// Create initial heap for priority queue.
for (auto &input_queue : unready_queue_ids_) {
std::shared_ptr<DataBundle> msg = std::make_shared<DataBundle>();
RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info_map_[input_queue], msg))
channel_info_map_[msg->from].current_seq_id = msg->seq_id;
channel_info_map_[msg->from].current_message_id = msg->meta->GetLastMessageId();
reader_merger_->push(msg);
}
STREAMING_LOG(INFO) << "[Reader] Initializing merger done.";
return StreamingStatus::OK;
}
StreamingStatus DataReader::GetMessageFromChannel(ConsumerChannelInfo &channel_info,
std::shared_ptr<DataBundle> &message) {
auto &qid = channel_info.channel_id;
last_read_q_id_ = qid;
STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id => " << qid;
while (RuntimeStatus::Running == runtime_context_->GetRuntimeStatus() &&
!message->data) {
auto status = channel_map_[channel_info.channel_id]->ConsumeItemFromChannel(
message->seq_id, message->data, message->data_size, kReadItemTimeout);
channel_info.get_queue_item_times++;
if (!message->data) {
STREAMING_LOG(DEBUG) << "[Reader] Queue " << qid << " status " << status
<< " get item timeout, resend notify "
<< channel_info.current_seq_id;
// TODO(lingxuan.zlx): notify consumed when it's timeout.
}
}
if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) {
return StreamingStatus::Interrupted;
}
STREAMING_LOG(DEBUG) << "[Reader] recevied queue seq id => " << message->seq_id
<< ", queue id => " << qid;
message->from = qid;
message->meta = StreamingMessageBundleMeta::FromBytes(message->data);
return StreamingStatus::OK;
}
StreamingStatus DataReader::StashNextMessage(std::shared_ptr<DataBundle> &message) {
// Push new message into priority queue and record the channel metrics in
// channel info.
std::shared_ptr<DataBundle> new_msg = std::make_shared<DataBundle>();
auto &channel_info = channel_info_map_[message->from];
reader_merger_->pop();
int64_t cur_time = current_time_ms();
RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg))
reader_merger_->push(new_msg);
channel_info.last_queue_item_delay =
new_msg->meta->GetMessageBundleTs() - message->meta->GetMessageBundleTs();
channel_info.last_queue_item_latency = current_time_ms() - cur_time;
return StreamingStatus::OK;
}
StreamingStatus DataReader::GetMergedMessageBundle(std::shared_ptr<DataBundle> &message,
bool &is_valid_break) {
int64_t cur_time = current_time_ms();
if (last_fetched_queue_item_) {
RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_))
}
message = reader_merger_->top();
last_fetched_queue_item_ = message;
auto &offset_info = channel_info_map_[message->from];
uint64_t cur_queue_previous_msg_id = offset_info.current_message_id;
STREAMING_LOG(DEBUG) << "[Reader] [Bundle] from q_id =>" << message->from << "cur => "
<< cur_queue_previous_msg_id << ", message list size"
<< message->meta->GetMessageListSize() << ", last message id =>"
<< message->meta->GetLastMessageId() << ", q seq id => "
<< message->seq_id << ", last barrier id => " << message->data_size
<< ", " << message->meta->GetMessageBundleTs();
if (message->meta->IsBundle()) {
last_message_ts_ = cur_time;
is_valid_break = true;
} else if (timer_interval_ != -1 && cur_time - last_message_ts_ > timer_interval_) {
// Throw empty message when reaching timer_interval.
last_message_ts_ = cur_time;
is_valid_break = true;
}
offset_info.current_message_id = message->meta->GetLastMessageId();
offset_info.current_seq_id = message->seq_id;
last_bundle_ts_ = message->meta->GetMessageBundleTs();
STREAMING_LOG(DEBUG) << "[Reader] [Bundle] message type =>"
<< static_cast<int>(message->meta->GetBundleType())
<< " from id => " << message->from << ", queue seq id =>"
<< message->seq_id << ", message id => "
<< message->meta->GetLastMessageId();
return StreamingStatus::OK;
}
StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms,
std::shared_ptr<DataBundle> &message) {
// Notify consumed every item in this mode.
if (last_fetched_queue_item_) {
NotifyConsumedItem(channel_info_map_[last_fetched_queue_item_->from],
last_fetched_queue_item_->seq_id);
}
/// DataBundle will be returned to the upper layer in the following cases:
/// a batch of data is returned when the real data is read, or an empty message
/// is returned to the upper layer when the given timeout period is reached to
/// avoid blocking for too long.
auto start_time = current_time_ms();
bool is_valid_break = false;
uint32_t empty_bundle_cnt = 0;
while (!is_valid_break) {
if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) {
return StreamingStatus::Interrupted;
}
auto cur_time = current_time_ms();
auto dur = cur_time - start_time;
if (dur > timeout_ms) {
return StreamingStatus::GetBundleTimeOut;
}
if (!unready_queue_ids_.empty()) {
StreamingStatus status = InitChannel();
switch (status) {
case StreamingStatus::InitQueueFailed:
break;
case StreamingStatus::WaitQueueTimeOut:
STREAMING_LOG(ERROR)
<< "Wait upstream queue timeout, maybe some actors in deadlock";
break;
default:
STREAMING_LOG(INFO) << "Init reader queue in GetBundle";
}
if (StreamingStatus::OK != status) {
return status;
}
RETURN_IF_NOT_OK(InitChannelMerger())
unready_queue_ids_.clear();
auto &merge_vec = reader_merger_->getRawVector();
for (auto &bundle : merge_vec) {
STREAMING_LOG(INFO) << "merger vector item => " << bundle->from;
}
}
RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break));
if (!is_valid_break) {
empty_bundle_cnt++;
NotifyConsumedItem(channel_info_map_[message->from], message->seq_id);
}
}
last_message_latency_ += current_time_ms() - start_time;
if (message->meta->GetMessageListSize() > 0) {
last_bundle_unit_ = message->data_size * 1.0 / message->meta->GetMessageListSize();
}
return StreamingStatus::OK;
}
void DataReader::GetOffsetInfo(
std::unordered_map<ObjectID, ConsumerChannelInfo> *&offset_map) {
offset_map = &channel_info_map_;
for (auto &offset_info : channel_info_map_) {
STREAMING_LOG(INFO) << "[Reader] [GetOffsetInfo], q id " << offset_info.first
<< ", seq id => " << offset_info.second.current_seq_id
<< ", message id => " << offset_info.second.current_message_id;
}
}
void DataReader::NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset) {
channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset);
if (offset == channel_info.queue_info.last_seq_id) {
STREAMING_LOG(DEBUG) << "notify seq id equal to last seq id => " << offset;
}
}
DataReader::DataReader(std::shared_ptr<RuntimeContext> &runtime_context)
: transfer_config_(new Config()), runtime_context_(runtime_context) {}
DataReader::~DataReader() { STREAMING_LOG(INFO) << "Streaming reader deconstruct."; }
void DataReader::Stop() {
runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted);
}
bool StreamingReaderMsgPtrComparator::operator()(const std::shared_ptr<DataBundle> &a,
const std::shared_ptr<DataBundle> &b) {
STREAMING_CHECK(a->meta);
// We use hash value of id for stability of message in sorting.
if (a->meta->GetMessageBundleTs() == b->meta->GetMessageBundleTs()) {
return a->from.Hash() > b->from.Hash();
}
return a->meta->GetMessageBundleTs() > b->meta->GetMessageBundleTs();
}
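// Note (assuming the reader's PriorityQueue follows std::priority_queue
// semantics, where top() is the "greatest" element under the comparator):
// returning true when a's bundle timestamp is larger ranks `a` lower, so the
// bundle with the smallest timestamp is merged out first, with the channel-id
// hash as a stable tie-breaker.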
} // namespace streaming
} // namespace ray

127
streaming/src/data_reader.h Normal file
View file

@ -0,0 +1,127 @@
#ifndef RAY_DATA_READER_H
#define RAY_DATA_READER_H
#include <cstdlib>
#include <functional>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
#include "channel.h"
#include "message/message_bundle.h"
#include "message/priority_queue.h"
#include "runtime_context.h"
namespace ray {
namespace streaming {
/// DataBundle is a super-bundle that contains channel information (upstream
/// channel id & bundle metadata) and a raw buffer pointer.
struct DataBundle {
uint8_t *data = nullptr;
uint32_t data_size;
ObjectID from;
uint64_t seq_id;
StreamingMessageBundleMetaPtr meta;
};
/// StreamingReaderMsgPtrComparator implements the merge policy of the reader's priority queue.
struct StreamingReaderMsgPtrComparator {
StreamingReaderMsgPtrComparator() = default;
bool operator()(const std::shared_ptr<DataBundle> &a,
const std::shared_ptr<DataBundle> &b);
};
/// DataReader fetches data bundles from the channels of upstream workers once
/// invoked by the user thread. It first puts them into a priority queue ordered
/// by a comparator over the bundle metadata, then pops the top bundle out to the
/// user thread each time, so that message ordering can be guaranteed, which
/// will also facilitate our future implementation of fault tolerance. Finally,
/// the user thread can extract messages from the bundle and process them one by one.
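/// A rough consumer-loop sketch (illustrative only, error handling omitted):
///   reader.Init(input_ids, actor_ids, /*timer_interval=*/1000);
///   std::shared_ptr<DataBundle> bundle;
///   while (running) {
///     if (reader.GetBundle(/*timeout_ms=*/100, bundle) == StreamingStatus::OK) {
///       // unpack bundle->meta / bundle->data and process the messages
///     }
///   }
///   reader.Stop();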
class DataReader {
private:
std::vector<ObjectID> input_queue_ids_;
std::vector<ObjectID> unready_queue_ids_;
std::unique_ptr<
PriorityQueue<std::shared_ptr<DataBundle>, StreamingReaderMsgPtrComparator>>
reader_merger_;
std::shared_ptr<DataBundle> last_fetched_queue_item_;
int64_t timer_interval_;
int64_t last_bundle_ts_;
int64_t last_message_ts_;
int64_t last_message_latency_;
int64_t last_bundle_unit_;
ObjectID last_read_q_id_;
static const uint32_t kReadItemTimeout;
protected:
std::unordered_map<ObjectID, ConsumerChannelInfo> channel_info_map_;
std::unordered_map<ObjectID, std::shared_ptr<ConsumerChannel>> channel_map_;
std::shared_ptr<Config> transfer_config_;
std::shared_ptr<RuntimeContext> runtime_context_;
public:
explicit DataReader(std::shared_ptr<RuntimeContext> &runtime_context);
virtual ~DataReader();
/// During initialization, only the channel parameters and necessary member properties
/// are assigned. All channels will be connected in the first reading operation.
/// \param input_ids
/// \param actor_ids
/// \param channel_seq_ids
/// \param msg_ids
/// \param timer_interval
void Init(const std::vector<ObjectID> &input_ids, const std::vector<ActorID> &actor_ids,
const std::vector<uint64_t> &channel_seq_ids,
const std::vector<uint64_t> &msg_ids, int64_t timer_interval);
void Init(const std::vector<ObjectID> &input_ids, const std::vector<ActorID> &actor_ids,
int64_t timer_interval);
/// Get latest message from input queues.
/// \param timeout_ms
/// \param message, return the latest message
StreamingStatus GetBundle(uint32_t timeout_ms, std::shared_ptr<DataBundle> &message);
/// Get offset information about channels for checkpoint.
/// \param offset_map (return value)
void GetOffsetInfo(std::unordered_map<ObjectID, ConsumerChannelInfo> *&offset_map);
void Stop();
/// Notify the input queues to clear data whose seq id is equal to or less than offset.
/// It's used when a checkpoint is done.
/// \param channel_info
/// \param offset
///
void NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset);
private:
/// Create channels and connect to all upstream workers.
StreamingStatus InitChannel();
/// Pop one item from every channel and collect them into the merged queue.
/// Highest-priority items are then fetched one by one. Whenever an item is popped
/// from a channel, a new item from that channel must be produced as a placeholder
/// in the merged queue.
StreamingStatus InitChannelMerger();
StreamingStatus StashNextMessage(std::shared_ptr<DataBundle> &message);
StreamingStatus GetMessageFromChannel(ConsumerChannelInfo &channel_info,
std::shared_ptr<DataBundle> &message);
/// Get the top item from the priority queue.
StreamingStatus GetMergedMessageBundle(std::shared_ptr<DataBundle> &message,
bool &is_valid_break);
};
} // namespace streaming
} // namespace ray
#endif // RAY_DATA_READER_H

View file

@@ -0,0 +1,310 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <signal.h>
#include <unistd.h>
#include <chrono>
#include <functional>
#include <list>
#include <numeric>
#include "data_writer.h"
#include "util/streaming_util.h"
namespace ray {
namespace streaming {
void DataWriter::WriterLoopForward() {
STREAMING_CHECK(RuntimeStatus::Running == runtime_context_->GetRuntimeStatus());
while (true) {
int64_t min_passby_message_ts = std::numeric_limits<int64_t>::max();
uint32_t empty_message_send_count = 0;
for (auto &output_queue : output_queue_ids_) {
if (RuntimeStatus::Running != runtime_context_->GetRuntimeStatus()) {
return;
}
ProducerChannelInfo &channel_info = channel_info_map_[output_queue];
bool is_push_empty_message = false;
StreamingStatus write_status =
WriteChannelProcess(channel_info, &is_push_empty_message);
int64_t current_ts = current_time_ms();
if (StreamingStatus::OK == write_status) {
channel_info.message_pass_by_ts = current_ts;
if (is_push_empty_message) {
min_passby_message_ts =
std::min(channel_info.message_pass_by_ts, min_passby_message_ts);
empty_message_send_count++;
}
} else if (StreamingStatus::FullChannel == write_status) {
} else {
if (StreamingStatus::EmptyRingBuffer != write_status) {
STREAMING_LOG(DEBUG) << "write buffer status => "
<< static_cast<uint32_t>(write_status)
<< ", is push empty message => " << is_push_empty_message;
}
}
}
if (empty_message_send_count == output_queue_ids_.size()) {
// Sleep if an empty message was sent in all channels.
uint64_t sleep_time_ = current_time_ms() - min_passby_message_ts;
// sleep_time_ can be larger than the time interval because of network jitter.
if (sleep_time_ <= runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) {
std::this_thread::sleep_for(std::chrono::milliseconds(
runtime_context_->GetConfig().GetEmptyMessageTimeInterval() - sleep_time_));
}
}
}
}
StreamingStatus DataWriter::WriteChannelProcess(ProducerChannelInfo &channel_info,
bool *is_empty_message) {
// If there is no message in the buffer, an empty message will be sent to the downstream queue.
uint64_t buffer_remain = 0;
StreamingStatus write_queue_flag = WriteBufferToChannel(channel_info, buffer_remain);
int64_t current_ts = current_time_ms();
if (write_queue_flag == StreamingStatus::EmptyRingBuffer &&
current_ts - channel_info.message_pass_by_ts >=
runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) {
write_queue_flag = WriteEmptyMessage(channel_info);
*is_empty_message = true;
STREAMING_LOG(DEBUG) << "send empty message bundle in q_id =>"
<< channel_info.channel_id;
}
return write_queue_flag;
}
StreamingStatus DataWriter::WriteBufferToChannel(ProducerChannelInfo &channel_info,
uint64_t &buffer_remain) {
StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer;
if (!IsMessageAvailableInBuffer(channel_info)) {
return StreamingStatus::EmptyRingBuffer;
}
// Flush transient buffer to queue first.
if (buffer_ptr->IsTransientAvaliable()) {
return WriteTransientBufferToChannel(channel_info);
}
STREAMING_CHECK(CollectFromRingBuffer(channel_info, buffer_remain))
<< "empty data in ringbuffer, q id => " << channel_info.channel_id;
return WriteTransientBufferToChannel(channel_info);
}
void DataWriter::Run() {
STREAMING_LOG(INFO) << "WriterLoopForward start";
loop_thread_ = std::make_shared<std::thread>(&DataWriter::WriterLoopForward, this);
}
/// Since every memory ring buffer's size is limited, when the writing buffer is
/// full, the user thread will be blocked, which will cause backpressure
/// naturally.
uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *data,
uint32_t data_size,
StreamingMessageType message_type) {
STREAMING_LOG(DEBUG) << "WriteMessageToBufferRing q_id: " << q_id
<< " data_size: " << data_size;
// TODO(lingxuan.zlx): currently not safe for multithreaded use.
ProducerChannelInfo &channel_info = channel_info_map_[q_id];
// The write message id stands for the current latest message id, and differs from
// channel.current_message_id if it's a barrier message.
uint64_t &write_message_id = channel_info.current_message_id;
write_message_id++;
auto &ring_buffer_ptr = channel_info.writer_ring_buffer;
while (ring_buffer_ptr->IsFull() &&
runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running) {
std::this_thread::sleep_for(
std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT));
}
if (runtime_context_->GetRuntimeStatus() != RuntimeStatus::Running) {
STREAMING_LOG(WARNING) << "stop in write message to ringbuffer";
return 0;
}
ring_buffer_ptr->Push(std::make_shared<StreamingMessage>(
data, data_size, write_message_id, message_type));
return write_message_id;
}
StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &actor_id,
uint64_t channel_message_id,
uint64_t queue_size) {
ProducerChannelInfo &channel_info = channel_info_map_[q_id];
channel_info.current_message_id = channel_message_id;
channel_info.channel_id = q_id;
channel_info.actor_id = actor_id;
channel_info.queue_size = queue_size;
STREAMING_LOG(WARNING) << " Init queue [" << q_id << "]";
channel_info.writer_ring_buffer = std::make_shared<StreamingRingBuffer>(
runtime_context_->GetConfig().GetRingBufferCapacity(),
StreamingRingBufferType::SPSC);
channel_info.message_pass_by_ts = current_time_ms();
std::shared_ptr<ProducerChannel> channel;
if (runtime_context_->IsMockTest()) {
channel = std::make_shared<MockProducer>(transfer_config_, channel_info);
} else {
channel = std::make_shared<StreamingQueueProducer>(transfer_config_, channel_info);
}
channel_map_.emplace(q_id, channel);
RETURN_IF_NOT_OK(channel->CreateTransferChannel())
return StreamingStatus::OK;
}
StreamingStatus DataWriter::Init(const std::vector<ObjectID> &queue_id_vec,
const std::vector<ActorID> &actor_ids,
const std::vector<uint64_t> &channel_message_id_vec,
const std::vector<uint64_t> &queue_size_vec) {
STREAMING_CHECK(!queue_id_vec.empty() && !channel_message_id_vec.empty());
ray::JobID job_id =
JobID::FromBinary(Util::Hexqid2str(runtime_context_->GetConfig().GetTaskJobId()));
STREAMING_LOG(INFO) << "Job name => " << runtime_context_->GetConfig().GetJobName()
<< ", job id => " << job_id;
output_queue_ids_ = queue_id_vec;
transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, queue_id_vec);
for (size_t i = 0; i < queue_id_vec.size(); ++i) {
StreamingStatus status = InitChannel(queue_id_vec[i], actor_ids[i],
channel_message_id_vec[i], queue_size_vec[i]);
if (status != StreamingStatus::OK) {
return status;
}
}
runtime_context_->SetRuntimeStatus(RuntimeStatus::Running);
return StreamingStatus::OK;
}
DataWriter::DataWriter(std::shared_ptr<RuntimeContext> &runtime_context)
: transfer_config_(new Config()), runtime_context_(runtime_context) {}
DataWriter::~DataWriter() {
// Return early if the streaming writer failed to initialize.
if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Init) {
return;
}
runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted);
if (loop_thread_->joinable()) {
STREAMING_LOG(INFO) << "Writer loop thread waiting for join";
loop_thread_->join();
}
STREAMING_LOG(INFO) << "Writer client queue disconnect.";
}
bool DataWriter::IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info) {
return channel_info.writer_ring_buffer->IsTransientAvaliable() ||
!channel_info.writer_ring_buffer->IsEmpty();
}
StreamingStatus DataWriter::WriteEmptyMessage(ProducerChannelInfo &channel_info) {
auto &q_id = channel_info.channel_id;
if (channel_info.message_last_commit_id < channel_info.current_message_id) {
// Abort sending the empty message if the ring buffer is not empty now.
STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " abort to send empty, last commit id =>"
<< channel_info.message_last_commit_id << ", channel max id => "
<< channel_info.current_message_id;
return StreamingStatus::SkipSendEmptyMessage;
}
// Make an empty bundle, use old ts from reloaded meta if it's not nullptr.
StreamingMessageBundlePtr bundle_ptr = std::make_shared<StreamingMessageBundle>(
channel_info.current_message_id, current_time_ms());
auto &q_ringbuffer = channel_info.writer_ring_buffer;
q_ringbuffer->ReallocTransientBuffer(bundle_ptr->ClassBytesSize());
bundle_ptr->ToBytes(q_ringbuffer->GetTransientBufferMutable());
StreamingStatus status = channel_map_[q_id]->ProduceItemToChannel(
const_cast<uint8_t *>(q_ringbuffer->GetTransientBuffer()),
q_ringbuffer->GetTransientBufferSize());
STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " send empty message, meta info =>"
<< bundle_ptr->ToString();
q_ringbuffer->FreeTransientBuffer();
RETURN_IF_NOT_OK(status)
channel_info.current_seq_id++;
channel_info.message_pass_by_ts = current_time_ms();
return StreamingStatus::OK;
}
StreamingStatus DataWriter::WriteTransientBufferToChannel(
ProducerChannelInfo &channel_info) {
StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer;
StreamingStatus status = channel_map_[channel_info.channel_id]->ProduceItemToChannel(
buffer_ptr->GetTransientBufferMutable(), buffer_ptr->GetTransientBufferSize());
RETURN_IF_NOT_OK(status)
channel_info.current_seq_id++;
auto transient_bundle_meta =
StreamingMessageBundleMeta::FromBytes(buffer_ptr->GetTransientBuffer());
bool is_barrier_bundle = transient_bundle_meta->IsBarrier();
// Force-free the transient buffer for barrier bundles so that the large block of
// memory is not held for too long.
buffer_ptr->FreeTransientBuffer(is_barrier_bundle);
channel_info.message_last_commit_id = transient_bundle_meta->GetLastMessageId();
return StreamingStatus::OK;
}
bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info,
uint64_t &buffer_remain) {
StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer;
auto &q_id = channel_info.channel_id;
std::list<StreamingMessagePtr> message_list;
uint64_t bundle_buffer_size = 0;
const uint32_t max_queue_item_size = channel_info.queue_size;
while (message_list.size() < runtime_context_->GetConfig().GetRingBufferCapacity() &&
!buffer_ptr->IsEmpty()) {
StreamingMessagePtr &message_ptr = buffer_ptr->Front();
uint32_t message_total_size = message_ptr->ClassBytesSize();
if (!message_list.empty() &&
bundle_buffer_size + message_total_size >= max_queue_item_size) {
STREAMING_LOG(DEBUG) << "message total size " << message_total_size
<< " max queue item size => " << max_queue_item_size;
break;
}
if (!message_list.empty() &&
message_list.back()->GetMessageType() != message_ptr->GetMessageType()) {
break;
}
// ClassBytesSize = DataSize + MetaDataSize
// bundle_buffer_size += message_ptr->GetDataSize();
bundle_buffer_size += message_total_size;
message_list.push_back(message_ptr);
buffer_ptr->Pop();
buffer_remain = buffer_ptr->Size();
}
if (bundle_buffer_size >= channel_info.queue_size) {
STREAMING_LOG(ERROR) << "bundle buffer is too large to store q id => " << q_id
<< ", bundle size => " << bundle_buffer_size
<< ", queue size => " << channel_info.queue_size;
}
StreamingMessageBundlePtr bundle_ptr;
bundle_ptr = std::make_shared<StreamingMessageBundle>(
std::move(message_list), current_time_ms(), message_list.back()->GetMessageSeqId(),
StreamingMessageBundleType::Bundle, bundle_buffer_size);
buffer_ptr->ReallocTransientBuffer(bundle_ptr->ClassBytesSize());
bundle_ptr->ToBytes(buffer_ptr->GetTransientBufferMutable());
STREAMING_CHECK(bundle_ptr->ClassBytesSize() == buffer_ptr->GetTransientBufferSize());
return true;
}
void DataWriter::Stop() {
for (auto &output_queue : output_queue_ids_) {
ProducerChannelInfo &channel_info = channel_info_map_[output_queue];
while (!channel_info.writer_ring_buffer->IsEmpty()) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted);
}
} // namespace streaming
} // namespace ray

115
streaming/src/data_writer.h Normal file
View file

@@ -0,0 +1,115 @@
#ifndef RAY_DATA_WRITER_H
#define RAY_DATA_WRITER_H
#include <cstring>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include "channel.h"
#include "config/streaming_config.h"
#include "message/message_bundle.h"
#include "runtime_context.h"
namespace ray {
namespace streaming {
/// DataWriter is designed for transporting data between upstream and downstream
/// workers. After the user sends data, it is not sent downstream immediately but
/// cached in the corresponding memory ring buffer. A separate transfer thread
/// (set up in the WriterLoopForward function) collects the messages from all
/// ring buffers and writes them to the corresponding transmission channels,
/// which are backed by StreamingQueue. The advantage is that the user thread is
/// not affected by the transmission speed during data transfer. The transfer
/// thread also automatically batches the cached data from the memory buffers
/// into data bundles to reduce transmission overhead. In addition, when there
/// is no data in a ring buffer, an empty bundle is sent so that the downstream
/// worker knows about it and can process accordingly. The thread sleeps for a
/// short interval to save CPU if all ring buffers have no data at that moment.
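///
/// A minimal usage sketch (illustrative only; the channel ids, actor ids, message ids,
/// queue sizes and the payload buffer are assumed to be prepared by the caller):
///
///   std::shared_ptr<RuntimeContext> runtime_context = ...;
///   DataWriter writer(runtime_context);
///   writer.Init(channel_ids, actor_ids, channel_message_id_vec, queue_size_vec);
///   writer.Run();  // start the background transfer thread
///   writer.WriteMessageToBufferRing(channel_ids[0], data, data_size);
///   writer.Stop();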
class DataWriter {
private:
std::shared_ptr<std::thread> loop_thread_;
// Each channel has a unique identity.
std::vector<ObjectID> output_queue_ids_;
protected:
// ProducerChannel is the middle broker for data transport.
std::unordered_map<ObjectID, ProducerChannelInfo> channel_info_map_;
std::unordered_map<ObjectID, std::shared_ptr<ProducerChannel>> channel_map_;
std::shared_ptr<Config> transfer_config_;
std::shared_ptr<RuntimeContext> runtime_context_;
private:
bool IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info);
/// This function handles two scenarios. When there is data in the transient
/// buffer, the existing data is written into the channel first, otherwise a
/// certain amount of message is first collected from the buffer and serialized
/// into the transient buffer, and finally written to the channel.
/// \param channel_info
/// \param buffer_remain
StreamingStatus WriteBufferToChannel(ProducerChannelInfo &channel_info,
uint64_t &buffer_remain);
/// Start the loop forward thread for collecting messages from all channels.
/// Invoking stack:
/// WriterLoopForward
/// -- WriteChannelProcess
/// -- WriteBufferToChannel
/// -- CollectFromRingBuffer
/// -- WriteTransientBufferToChannel
/// -- WriteEmptyMessage(if WriteChannelProcess return empty state)
void WriterLoopForward();
/// Push an empty message when no valid message or bundle was produced within the
/// time interval.
/// \param channel_info
StreamingStatus WriteEmptyMessage(ProducerChannelInfo &channel_info);
/// Flush all data from transient buffer to channel for transporting.
/// \param channel_info
StreamingStatus WriteTransientBufferToChannel(ProducerChannelInfo &channel_info);
bool CollectFromRingBuffer(ProducerChannelInfo &channel_info, uint64_t &buffer_remain);
StreamingStatus WriteChannelProcess(ProducerChannelInfo &channel_info,
bool *is_empty_message);
StreamingStatus InitChannel(const ObjectID &q_id, const ActorID &actor_id,
uint64_t channel_message_id, uint64_t queue_size);
public:
explicit DataWriter(std::shared_ptr<RuntimeContext> &runtime_context);
virtual ~DataWriter();
/// Streaming writer client initialization.
/// \param channel_ids channel/queue id vector
/// \param actor_ids downstream actor ids, one per channel
/// \param channel_message_id_vec channel message ids, related to message checkpoints
/// \param queue_size_vec queue sizes (memory size, not length)
StreamingStatus Init(const std::vector<ObjectID> &channel_ids,
const std::vector<ActorID> &actor_ids,
const std::vector<uint64_t> &channel_message_id_vec,
const std::vector<uint64_t> &queue_size_vec);
/// To increase throughput, we employ an output buffer for message merging, which
/// means many messages are merged into one message bundle and no message is pushed
/// into the queue directly until the daemon thread does so.
/// Additionally, writing blocks intentionally when the ring buffer is full.
/// \param q_id
/// \param data
/// \param data_size
/// \param message_type
/// \return message seq id
uint64_t WriteMessageToBufferRing(
const ObjectID &q_id, uint8_t *data, uint32_t data_size,
StreamingMessageType message_type = StreamingMessageType::Message);
void Run();
void Stop();
};
} // namespace streaming
} // namespace ray
#endif // RAY_DATA_WRITER_H

View file

@@ -0,0 +1,90 @@
#include <utility>
#include <cstring>
#include <string>
#include "message.h"
#include "ray/common/status.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
StreamingMessage::StreamingMessage(std::shared_ptr<uint8_t> &data, uint32_t data_size,
uint64_t seq_id, StreamingMessageType message_type)
: message_data_(data),
data_size_(data_size),
message_type_(message_type),
message_id_(seq_id) {}
StreamingMessage::StreamingMessage(std::shared_ptr<uint8_t> &&data, uint32_t data_size,
uint64_t seq_id, StreamingMessageType message_type)
: message_data_(data),
data_size_(data_size),
message_type_(message_type),
message_id_(seq_id) {}
StreamingMessage::StreamingMessage(const uint8_t *data, uint32_t data_size,
uint64_t seq_id, StreamingMessageType message_type)
: data_size_(data_size), message_type_(message_type), message_id_(seq_id) {
message_data_.reset(new uint8_t[data_size], std::default_delete<uint8_t[]>());
std::memcpy(message_data_.get(), data, data_size_);
}
StreamingMessage::StreamingMessage(const StreamingMessage &msg) {
data_size_ = msg.data_size_;
message_data_ = msg.message_data_;
message_id_ = msg.message_id_;
message_type_ = msg.message_type_;
}
StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes,
bool verifer_check) {
uint32_t byte_offset = 0;
uint32_t data_size = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
byte_offset += sizeof(data_size);
uint64_t seq_id = *reinterpret_cast<const uint64_t *>(bytes + byte_offset);
byte_offset += sizeof(seq_id);
StreamingMessageType msg_type =
*reinterpret_cast<const StreamingMessageType *>(bytes + byte_offset);
byte_offset += sizeof(msg_type);
auto buf = new uint8_t[data_size];
std::memcpy(buf, bytes + byte_offset, data_size);
auto data_ptr = std::shared_ptr<uint8_t>(buf, std::default_delete<uint8_t[]>());
return std::make_shared<StreamingMessage>(data_ptr, data_size, seq_id, msg_type);
}
void StreamingMessage::ToBytes(uint8_t *serlizable_data) {
uint32_t byte_offset = 0;
std::memcpy(serlizable_data + byte_offset, reinterpret_cast<char *>(&data_size_),
sizeof(data_size_));
byte_offset += sizeof(data_size_);
std::memcpy(serlizable_data + byte_offset, reinterpret_cast<char *>(&message_id_),
sizeof(message_id_));
byte_offset += sizeof(message_id_);
std::memcpy(serlizable_data + byte_offset, reinterpret_cast<char *>(&message_type_),
sizeof(message_type_));
byte_offset += sizeof(message_type_);
std::memcpy(serlizable_data + byte_offset,
reinterpret_cast<char *>(message_data_.get()), data_size_);
byte_offset += data_size_;
STREAMING_CHECK(byte_offset == this->ClassBytesSize());
}
bool StreamingMessage::operator==(const StreamingMessage &message) const {
return GetDataSize() == message.GetDataSize() &&
GetMessageSeqId() == message.GetMessageSeqId() &&
GetMessageType() == message.GetMessageType() &&
!std::memcmp(RawData(), message.RawData(), data_size_);
}
} // namespace streaming
} // namespace ray

View file

@@ -0,0 +1,93 @@
#ifndef RAY_MESSAGE_H
#define RAY_MESSAGE_H
#include <memory>
namespace ray {
namespace streaming {
class StreamingMessage;
typedef std::shared_ptr<StreamingMessage> StreamingMessagePtr;
enum class StreamingMessageType : uint32_t {
Barrier = 1,
Message = 2,
MIN = Barrier,
MAX = Message
};
constexpr uint32_t kMessageHeaderSize =
sizeof(uint32_t) + sizeof(uint64_t) + sizeof(StreamingMessageType);
/// All messages should be wrapped by this protocol.
/// DataSize is the length of the raw data; message ids increase from 1.
/// MessageType is used for barrier transport and checkpointing.
/// +----------------+
/// | DataSize=U32 |
/// +----------------+
/// | MessageId=U64 |
/// +----------------+
/// | MessageType=U32|
/// +----------------+
/// | Data=var |
/// +----------------+
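///
/// For example, a message carrying a 100-byte payload serializes to
/// kMessageHeaderSize (4 + 8 + 4 = 16) + 100 = 116 bytes. An illustrative
/// round-trip (the `payload` pointer is assumed to be prepared by the caller):
///
///   StreamingMessage msg(payload, /*data_size=*/100, /*seq_id=*/1,
///                        StreamingMessageType::Message);
///   std::vector<uint8_t> bytes(msg.ClassBytesSize());
///   msg.ToBytes(bytes.data());
///   StreamingMessagePtr copy = StreamingMessage::FromBytes(bytes.data());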
class StreamingMessage {
private:
std::shared_ptr<uint8_t> message_data_;
uint32_t data_size_;
StreamingMessageType message_type_;
uint64_t message_id_;
public:
/// Copy raw data from outside shared buffer.
/// \param data raw data from user buffer
/// \param data_size raw data size
/// \param seq_id message id
/// \param message_type
StreamingMessage(std::shared_ptr<uint8_t> &data, uint32_t data_size, uint64_t seq_id,
StreamingMessageType message_type);
/// Move outside raw data into the message data.
/// \param data raw data from user buffer
/// \param data_size raw data size
/// \param seq_id message id
/// \param message_type
StreamingMessage(std::shared_ptr<uint8_t> &&data, uint32_t data_size, uint64_t seq_id,
StreamingMessageType message_type);
/// Copy raw data from outside buffer.
/// \param data raw data from user buffer
/// \param data_size raw data size
/// \param seq_id message id
/// \param message_type
StreamingMessage(const uint8_t *data, uint32_t data_size, uint64_t seq_id,
StreamingMessageType message_type);
StreamingMessage(const StreamingMessage &);
StreamingMessage operator=(const StreamingMessage &) = delete;
virtual ~StreamingMessage() = default;
inline uint8_t *RawData() const { return message_data_.get(); }
inline uint32_t GetDataSize() const { return data_size_; }
inline StreamingMessageType GetMessageType() const { return message_type_; }
inline uint64_t GetMessageSeqId() const { return message_id_; }
inline bool IsMessage() { return StreamingMessageType::Message == message_type_; }
inline bool IsBarrier() { return StreamingMessageType::Barrier == message_type_; }
bool operator==(const StreamingMessage &) const;
virtual void ToBytes(uint8_t *data);
static StreamingMessagePtr FromBytes(const uint8_t *data, bool verifer_check = true);
inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + data_size_; }
};
} // namespace streaming
} // namespace ray
#endif // RAY_MESSAGE_H

View file

@@ -0,0 +1,236 @@
#include <cstring>
#include <string>
#include "ray/common/status.h"
#include "config/streaming_config.h"
#include "message_bundle.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
StreamingMessageBundle::StreamingMessageBundle(uint64_t last_offset_seq_id,
uint64_t message_bundle_ts)
: StreamingMessageBundleMeta(message_bundle_ts, last_offset_seq_id, 0,
StreamingMessageBundleType::Empty) {
this->raw_bundle_size_ = 0;
}
StreamingMessageBundleMeta::StreamingMessageBundleMeta(
uint64_t message_bundle_ts, uint64_t last_offset_seq_id, uint32_t message_list_size,
StreamingMessageBundleType bundle_type)
: message_bundle_ts_(message_bundle_ts),
last_message_id_(last_offset_seq_id),
message_list_size_(message_list_size),
bundle_type_(bundle_type) {
STREAMING_CHECK(message_list_size <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE);
}
void StreamingMessageBundleMeta::ToBytes(uint8_t *bytes) {
uint32_t byte_offset = 0;
uint32_t magicNum = StreamingMessageBundleMeta::StreamingMessageBundleMagicNum;
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&magicNum),
sizeof(uint32_t));
byte_offset += sizeof(uint32_t);
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&message_bundle_ts_),
sizeof(uint64_t));
byte_offset += sizeof(uint64_t);
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&last_message_id_),
sizeof(uint64_t));
byte_offset += sizeof(uint64_t);
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&message_list_size_),
sizeof(uint32_t));
byte_offset += sizeof(uint32_t);
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&bundle_type_),
sizeof(StreamingMessageBundleType));
byte_offset += sizeof(StreamingMessageBundleType);
}
StreamingMessageBundleMetaPtr StreamingMessageBundleMeta::FromBytes(const uint8_t *bytes,
bool check) {
STREAMING_CHECK(bytes);
uint32_t byte_offset = 0;
const uint32_t magic_num = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
if (magic_num != StreamingMessageBundleMagicNum) {
STREAMING_LOG(INFO) << "Magic Number => " << magic_num;
}
STREAMING_CHECK(magic_num == StreamingMessageBundleMagicNum);
byte_offset += sizeof(uint32_t);
uint64_t message_bundle_ts = *reinterpret_cast<const uint64_t *>(bytes + byte_offset);
byte_offset += sizeof(uint64_t);
uint64_t last_message_id = *reinterpret_cast<const uint64_t *>(bytes + byte_offset);
byte_offset += sizeof(uint64_t);
uint32_t messageListSize = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
byte_offset += sizeof(uint32_t);
STREAMING_LOG(DEBUG) << "ts => " << message_bundle_ts << " last message id => "
<< last_message_id << " message size => " << messageListSize;
STREAMING_CHECK(messageListSize <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE);
StreamingMessageBundleType messageBundleType =
*reinterpret_cast<const StreamingMessageBundleType *>(bytes + byte_offset);
byte_offset += sizeof(StreamingMessageBundleType);
auto result = std::make_shared<StreamingMessageBundleMeta>(
message_bundle_ts, last_message_id, messageListSize, messageBundleType);
STREAMING_CHECK(byte_offset == result->ClassBytesSize());
return result;
}
bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta &meta) const {
return this->message_list_size_ == meta.GetMessageListSize() &&
this->message_bundle_ts_ == meta.GetMessageBundleTs() &&
this->bundle_type_ == meta.GetBundleType() &&
this->last_message_id_ == meta.GetLastMessageId();
}
bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta *meta) const {
return operator==(*meta);
}
StreamingMessageBundleMeta::StreamingMessageBundleMeta(
StreamingMessageBundleMeta *meta_ptr) {
bundle_type_ = meta_ptr->bundle_type_;
last_message_id_ = meta_ptr->last_message_id_;
message_bundle_ts_ = meta_ptr->message_bundle_ts_;
message_list_size_ = meta_ptr->message_list_size_;
}
StreamingMessageBundleMeta::StreamingMessageBundleMeta()
: bundle_type_(StreamingMessageBundleType::Empty) {}
StreamingMessageBundle::StreamingMessageBundle(
std::list<StreamingMessagePtr> &&message_list, uint64_t message_ts,
uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type,
uint32_t raw_data_size)
: StreamingMessageBundleMeta(message_ts, last_offset_seq_id, message_list.size(),
bundle_type),
raw_bundle_size_(raw_data_size),
message_list_(message_list) {
if (bundle_type_ != StreamingMessageBundleType::Empty) {
if (!raw_bundle_size_) {
raw_bundle_size_ = std::accumulate(
message_list_.begin(), message_list_.end(), 0,
[](uint32_t x, StreamingMessagePtr &y) { return x + y->ClassBytesSize(); });
}
}
}
StreamingMessageBundle::StreamingMessageBundle(
std::list<StreamingMessagePtr> &message_list, uint64_t message_ts,
uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type,
uint32_t raw_data_size)
: StreamingMessageBundle(std::list<StreamingMessagePtr>(message_list), message_ts,
last_offset_seq_id, bundle_type, raw_data_size) {}
StreamingMessageBundle::StreamingMessageBundle(StreamingMessageBundle &bundle) {
message_bundle_ts_ = bundle.message_bundle_ts_;
message_list_size_ = bundle.message_list_size_;
raw_bundle_size_ = bundle.raw_bundle_size_;
bundle_type_ = bundle.bundle_type_;
last_message_id_ = bundle.last_message_id_;
message_list_ = bundle.message_list_;
}
void StreamingMessageBundle::ToBytes(uint8_t *bytes) {
uint32_t byte_offset = 0;
StreamingMessageBundleMeta::ToBytes(bytes + byte_offset);
byte_offset += StreamingMessageBundleMeta::ClassBytesSize();
std::memcpy(bytes + byte_offset, reinterpret_cast<char *>(&raw_bundle_size_),
sizeof(uint32_t));
byte_offset += sizeof(uint32_t);
if (raw_bundle_size_ > 0) {
ConvertMessageListToRawData(message_list_, raw_bundle_size_, bytes + byte_offset);
}
}
StreamingMessageBundlePtr StreamingMessageBundle::FromBytes(const uint8_t *bytes,
bool verifer_check) {
uint32_t byte_offset = 0;
StreamingMessageBundleMetaPtr meta_ptr =
StreamingMessageBundleMeta::FromBytes(bytes + byte_offset);
byte_offset += meta_ptr->ClassBytesSize();
uint32_t raw_data_size = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
byte_offset += sizeof(uint32_t);
std::list<StreamingMessagePtr> message_list;
// Only message bundles own raw data.
if (meta_ptr->GetBundleType() != StreamingMessageBundleType::Empty) {
GetMessageListFromRawData(bytes + byte_offset, raw_data_size,
meta_ptr->GetMessageListSize(), message_list);
byte_offset += raw_data_size;
}
auto result = std::make_shared<StreamingMessageBundle>(
message_list, meta_ptr->GetMessageBundleTs(), meta_ptr->GetLastMessageId(),
meta_ptr->GetBundleType());
STREAMING_CHECK(byte_offset == result->ClassBytesSize());
return result;
}
void StreamingMessageBundle::GetMessageListFromRawData(
const uint8_t *bytes, uint32_t byte_size, uint32_t message_list_size,
std::list<StreamingMessagePtr> &message_list) {
uint32_t byte_offset = 0;
// Only message bundles own raw data.
for (size_t i = 0; i < message_list_size; ++i) {
StreamingMessagePtr item = StreamingMessage::FromBytes(bytes + byte_offset);
message_list.push_back(item);
byte_offset += item->ClassBytesSize();
}
STREAMING_CHECK(byte_offset == byte_size);
}
void StreamingMessageBundle::GetMessageList(
std::list<StreamingMessagePtr> &message_list) {
message_list = message_list_;
}
void StreamingMessageBundle::ConvertMessageListToRawData(
const std::list<StreamingMessagePtr> &message_list, uint32_t raw_data_size,
uint8_t *raw_data) {
uint32_t byte_offset = 0;
for (auto &message : message_list) {
message->ToBytes(raw_data + byte_offset);
byte_offset += message->ClassBytesSize();
}
STREAMING_CHECK(byte_offset == raw_data_size);
}
bool StreamingMessageBundle::operator==(StreamingMessageBundle &bundle) const {
if (!(StreamingMessageBundleMeta::operator==(&bundle) &&
this->GetRawBundleSize() == bundle.GetRawBundleSize() &&
this->GetMessageListSize() == bundle.GetMessageListSize())) {
return false;
}
auto it1 = message_list_.begin();
auto it2 = bundle.message_list_.begin();
while (it1 != message_list_.end() && it2 != bundle.message_list_.end()) {
if (!((*it1).get()->operator==(*(*it2).get()))) {
return false;
}
it1++;
it2++;
}
return true;
}
bool StreamingMessageBundle::operator==(StreamingMessageBundle *bundle) const {
return this->operator==(*bundle);
}
} // namespace streaming
} // namespace ray

View file

@@ -0,0 +1,164 @@
#ifndef RAY_MESSAGE_BUNDLE_H
#define RAY_MESSAGE_BUNDLE_H
#include <ctime>
#include <list>
#include <numeric>
#include "message.h"
namespace ray {
namespace streaming {
enum class StreamingMessageBundleType : uint32_t {
Empty = 1,
Barrier = 2,
Bundle = 3,
MIN = Empty,
MAX = Bundle
};
class StreamingMessageBundleMeta;
class StreamingMessageBundle;
typedef std::shared_ptr<StreamingMessageBundle> StreamingMessageBundlePtr;
typedef std::shared_ptr<StreamingMessageBundleMeta> StreamingMessageBundleMetaPtr;
constexpr uint32_t kMessageBundleMetaHeaderSize = sizeof(uint32_t) + sizeof(uint32_t) +
sizeof(uint64_t) + sizeof(uint64_t) +
sizeof(StreamingMessageBundleType);
constexpr uint32_t kMessageBundleHeaderSize =
kMessageBundleMetaHeaderSize + sizeof(uint32_t);
class StreamingMessageBundleMeta {
public:
static const uint32_t StreamingMessageBundleMagicNum = 0xCAFEBABA;
protected:
uint64_t message_bundle_ts_;
uint64_t last_message_id_;
uint32_t message_list_size_;
StreamingMessageBundleType bundle_type_;
public:
explicit StreamingMessageBundleMeta(uint64_t, uint64_t, uint32_t,
StreamingMessageBundleType);
explicit StreamingMessageBundleMeta(StreamingMessageBundleMeta *);
explicit StreamingMessageBundleMeta();
virtual ~StreamingMessageBundleMeta(){};
bool operator==(StreamingMessageBundleMeta &) const;
bool operator==(StreamingMessageBundleMeta *) const;
inline uint64_t GetMessageBundleTs() const { return message_bundle_ts_; }
inline uint64_t GetLastMessageId() const { return last_message_id_; }
inline uint32_t GetMessageListSize() const { return message_list_size_; }
inline StreamingMessageBundleType GetBundleType() const { return bundle_type_; }
inline bool IsBarrier() { return StreamingMessageBundleType::Barrier == bundle_type_; }
inline bool IsBundle() { return StreamingMessageBundleType::Bundle == bundle_type_; }
virtual void ToBytes(uint8_t *data);
static StreamingMessageBundleMetaPtr FromBytes(const uint8_t *data,
bool verifer_check = true);
inline virtual uint32_t ClassBytesSize() { return kMessageBundleMetaHeaderSize; }
std::string ToString() {
return std::to_string(last_message_id_) + "," + std::to_string(message_list_size_) +
"," + std::to_string(message_bundle_ts_) + "," +
std::to_string(static_cast<uint32_t>(bundle_type_));
}
};
/// StreamingMessageBundle inherits from the metadata class (StreamingMessageBundleMeta)
/// and uses the following protocol:
/// MagicNum = 0xcafebaba
/// Timestamp: 64-bit timestamp (milliseconds since 1970)
/// LastMessageId: the last message id of the bundle, in (0, +INF)
/// MessageListSize: the number of messages in the bundle
/// BundleType: bundle = 3, barrier = 2, empty = 1
/// RawBundleSize: binary length of the data
/// RawData: binary data
///
/// +--------------------+
/// | MagicNum=U32 |
/// +--------------------+
/// | BundleTs=U64 |
/// +--------------------+
/// | LastMessageId=U64 |
/// +--------------------+
/// | MessageListSize=U32|
/// +--------------------+
/// | BundleType=U32 |
/// +--------------------+
/// | RawBundleSize=U32 |
/// +--------------------+
/// | RawData=var(N*Msg) |
/// +--------------------+
/// Note that StreamingMessageBundle and StreamingMessageBundleMeta share almost the
/// same protocol except for the last two fields (RawBundleSize and RawData).
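/// For example, a bundle of 10 messages, each carrying a 100-byte payload, has a
/// raw bundle size of 10 * (kMessageHeaderSize + 100) = 10 * 116 = 1160 bytes and
/// serializes to kMessageBundleHeaderSize (28 + 4 = 32) + 1160 = 1192 bytes in total.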
class StreamingMessageBundle : public StreamingMessageBundleMeta {
private:
uint32_t raw_bundle_size_;
// Lazy serialization/deserialization.
std::list<StreamingMessagePtr> message_list_;
public:
explicit StreamingMessageBundle(std::list<StreamingMessagePtr> &&message_list,
uint64_t bundle_ts, uint64_t offset,
StreamingMessageBundleType bundle_type,
uint32_t raw_data_size = 0);
// The lvalue-reference overload makes a duplicated copy of the message list.
explicit StreamingMessageBundle(std::list<StreamingMessagePtr> &message_list,
uint64_t bundle_ts, uint64_t offset,
StreamingMessageBundleType bundle_type,
uint32_t raw_data_size = 0);
// Create an empty bundle from the last message id and a timestamp.
explicit StreamingMessageBundle(uint64_t, uint64_t);
explicit StreamingMessageBundle(StreamingMessageBundle &bundle);
virtual ~StreamingMessageBundle() = default;
inline uint32_t GetRawBundleSize() const { return raw_bundle_size_; }
bool operator==(StreamingMessageBundle &bundle) const;
bool operator==(StreamingMessageBundle *bundle_ptr) const;
void GetMessageList(std::list<StreamingMessagePtr> &message_list);
const std::list<StreamingMessagePtr> &GetMessageList() const { return message_list_; }
virtual void ToBytes(uint8_t *data);
static StreamingMessageBundlePtr FromBytes(const uint8_t *data,
bool verifer_check = true);
inline virtual uint32_t ClassBytesSize() {
return kMessageBundleHeaderSize + raw_bundle_size_;
};
static void GetMessageListFromRawData(const uint8_t *bytes, uint32_t bytes_size,
uint32_t message_list_size,
std::list<StreamingMessagePtr> &message_list);
static void ConvertMessageListToRawData(
const std::list<StreamingMessagePtr> &message_list, uint32_t raw_data_size,
uint8_t *raw_data);
};
} // namespace streaming
} // namespace ray
#endif // RAY_MESSAGE_BUNDLE_H

View file

@@ -0,0 +1,53 @@
#ifndef RAY_PRIORITY_QUEUE_H
#define RAY_PRIORITY_QUEUE_H
#include <algorithm>
#include <memory>
#include <vector>
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
template <class T, class C>
class PriorityQueue {
private:
std::vector<T> merge_vec_;
C comparator_;
public:
PriorityQueue(C &comparator) : comparator_(comparator){};
inline void push(T &&item) {
merge_vec_.push_back(std::forward<T>(item));
std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
}
inline void push(const T &item) {
merge_vec_.push_back(item);
std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
}
inline void pop() {
STREAMING_CHECK(!isEmpty());
std::pop_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
merge_vec_.pop_back();
}
inline void makeHeap() {
std::make_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
}
inline T &top() { return merge_vec_.front(); }
inline uint32_t size() { return merge_vec_.size(); }
inline bool isEmpty() { return merge_vec_.empty(); }
std::vector<T> &getRawVector() { return merge_vec_; }
};
} // namespace streaming
} // namespace ray
#endif // RAY_PRIORITY_QUEUE_H

View file

@@ -0,0 +1,23 @@
syntax = "proto3";
package ray.streaming.proto;
option java_package = "org.ray.streaming.runtime.generated";
enum OperatorType {
UNKNOWN = 0;
TRANSFORM = 1;
SOURCE = 2;
SINK = 3;
}
// All strings in this message are ASCII strings.
message StreamingConfig {
string job_name = 1;
string task_job_id = 2;
string worker_name = 3;
string op_name = 4;
OperatorType role = 5;
uint32 ring_buffer_capacity = 6;
uint32 empty_message_interval = 7;
}
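// An illustrative way to populate this config from C++ (assuming the standard
// protobuf C++ code generated from this file; the setter names below follow
// protobuf conventions and are not part of this commit):
//
//   ray::streaming::proto::StreamingConfig config;
//   config.set_job_name("example_job");
//   config.set_op_name("example_operator");
//   config.set_role(ray::streaming::proto::SOURCE);
//   config.set_ring_buffer_capacity(1024);
//   config.set_empty_message_interval(20);
//   std::string serialized;
//   config.SerializeToString(&serialized);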

View file

@@ -0,0 +1,70 @@
syntax = "proto3";
package ray.streaming.queue.protobuf;
enum StreamingQueueMessageType {
StreamingQueueDataMsgType = 0;
StreamingQueueCheckMsgType = 1;
StreamingQueueCheckRspMsgType = 2;
StreamingQueueNotificationMsgType = 3;
StreamingQueueTestInitMsgType = 4;
StreamingQueueTestCheckStatusRspMsgType = 5;
}
enum StreamingQueueError {
OK = 0;
QUEUE_NOT_EXIST = 1;
NO_VALID_DATA_TO_PULL = 2;
}
message StreamingQueueDataMsg {
bytes src_actor_id = 1;
bytes dst_actor_id = 2;
bytes queue_id = 3;
uint64 seq_id = 4;
uint64 length = 5;
bool raw = 6;
}
message StreamingQueueCheckMsg {
bytes src_actor_id = 1;
bytes dst_actor_id = 2;
bytes queue_id = 3;
}
message StreamingQueueCheckRspMsg {
bytes src_actor_id = 1;
bytes dst_actor_id = 2;
bytes queue_id = 3;
StreamingQueueError err_code = 4;
}
message StreamingQueueNotificationMsg {
bytes src_actor_id = 1;
bytes dst_actor_id = 2;
bytes queue_id = 3;
uint64 seq_id = 4;
}
// for test
enum StreamingQueueTestRole {
WRITER = 0;
READER = 1;
}
message StreamingQueueTestInitMsg {
StreamingQueueTestRole role = 1;
bytes src_actor_id = 2;
bytes dst_actor_id = 3;
bytes actor_handle = 4;
repeated bytes queue_ids = 5;
repeated bytes rescale_queue_ids = 6;
string test_suite_name = 7;
string test_name = 8;
uint64 param = 9;
}
message StreamingQueueTestCheckStatusRspMsg {
string test_name = 1;
bool status = 2;
}
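// A typical exchange between an upstream (writer) queue and a downstream (reader)
// queue, as described by the C++ message classes that wrap these protos (sketch
// only; the exact ordering depends on the runtime):
//   1. upstream   -> downstream: StreamingQueueCheckMsg        (is the peer ready?)
//   2. downstream -> upstream:   StreamingQueueCheckRspMsg     (ready / error code)
//   3. upstream   -> downstream: StreamingQueueDataMsg         (bundle payload)
//   4. downstream -> upstream:   StreamingQueueNotificationMsg (consumed offset)
// The *Test* messages are only used by the streaming queue test suites.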

View file

@@ -0,0 +1,240 @@
#include "message.h"
namespace ray {
namespace streaming {
const uint32_t Message::MagicNum = 0xBABA0510;
std::unique_ptr<LocalMemoryBuffer> Message::ToBytes() {
uint8_t *bytes = nullptr;
std::string pboutput;
ToProtobuf(&pboutput);
int64_t fbs_length = pboutput.length();
queue::protobuf::StreamingQueueMessageType type = Type();
size_t total_len =
sizeof(Message::MagicNum) + sizeof(type) + sizeof(fbs_length) + fbs_length;
if (buffer_ != nullptr) {
total_len += buffer_->Size();
}
bytes = new uint8_t[total_len];
STREAMING_CHECK(bytes != nullptr) << "failed to allocate bytes.";
uint8_t *p_cur = bytes;
memcpy(p_cur, &Message::MagicNum, sizeof(Message::MagicNum));
p_cur += sizeof(Message::MagicNum);
memcpy(p_cur, &type, sizeof(type));
p_cur += sizeof(type);
memcpy(p_cur, &fbs_length, sizeof(fbs_length));
p_cur += sizeof(fbs_length);
uint8_t *fbs_bytes = (uint8_t *)pboutput.data();
memcpy(p_cur, fbs_bytes, fbs_length);
p_cur += fbs_length;
if (buffer_ != nullptr) {
memcpy(p_cur, buffer_->Data(), buffer_->Size());
}
// The LocalMemoryBuffer below copies the bytes, so the temporary buffer can be
// released right away.
std::unique_ptr<LocalMemoryBuffer> buffer =
std::unique_ptr<LocalMemoryBuffer>(new LocalMemoryBuffer(bytes, total_len, true));
delete[] bytes;
return buffer;
}
void DataMessage::ToProtobuf(std::string *output) {
queue::protobuf::StreamingQueueDataMsg msg;
msg.set_src_actor_id(actor_id_.Binary());
msg.set_dst_actor_id(peer_actor_id_.Binary());
msg.set_queue_id(queue_id_.Binary());
msg.set_seq_id(seq_id_);
msg.set_length(buffer_->Size());
msg.set_raw(raw_);
msg.SerializeToString(output);
}
std::shared_ptr<DataMessage> DataMessage::FromBytes(uint8_t *bytes) {
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
uint64_t *fbs_length = (uint64_t *)bytes;
bytes += sizeof(uint64_t);
std::string inputpb(reinterpret_cast<char const *>(bytes), *fbs_length);
queue::protobuf::StreamingQueueDataMsg message;
message.ParseFromString(inputpb);
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
uint64_t seq_id = message.seq_id();
uint64_t length = message.length();
bool raw = message.raw();
bytes += *fbs_length;
/// Copy data and create a new buffer for streaming queue.
std::shared_ptr<LocalMemoryBuffer> buffer =
std::make_shared<LocalMemoryBuffer>(bytes, (size_t)length, true);
std::shared_ptr<DataMessage> data_msg = std::make_shared<DataMessage>(
src_actor_id, dst_actor_id, queue_id, seq_id, buffer, raw);
return data_msg;
}
void NotificationMessage::ToProtobuf(std::string *output) {
queue::protobuf::StreamingQueueNotificationMsg msg;
msg.set_src_actor_id(actor_id_.Binary());
msg.set_dst_actor_id(peer_actor_id_.Binary());
msg.set_queue_id(queue_id_.Binary());
msg.set_seq_id(seq_id_);
msg.SerializeToString(output);
}
std::shared_ptr<NotificationMessage> NotificationMessage::FromBytes(uint8_t *bytes) {
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
uint64_t *length = (uint64_t *)bytes;
bytes += sizeof(uint64_t);
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
queue::protobuf::StreamingQueueNotificationMsg message;
message.ParseFromString(inputpb);
STREAMING_LOG(INFO) << "message.src_actor_id: " << message.src_actor_id();
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
uint64_t seq_id = message.seq_id();
std::shared_ptr<NotificationMessage> notify_msg =
std::make_shared<NotificationMessage>(src_actor_id, dst_actor_id, queue_id, seq_id);
return notify_msg;
}
void CheckMessage::ToProtobuf(std::string *output) {
queue::protobuf::StreamingQueueCheckMsg msg;
msg.set_src_actor_id(actor_id_.Binary());
msg.set_dst_actor_id(peer_actor_id_.Binary());
msg.set_queue_id(queue_id_.Binary());
msg.SerializeToString(output);
}
std::shared_ptr<CheckMessage> CheckMessage::FromBytes(uint8_t *bytes) {
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
uint64_t *length = (uint64_t *)bytes;
bytes += sizeof(uint64_t);
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
queue::protobuf::StreamingQueueCheckMsg message;
message.ParseFromString(inputpb);
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
std::shared_ptr<CheckMessage> check_msg =
std::make_shared<CheckMessage>(src_actor_id, dst_actor_id, queue_id);
return check_msg;
}
void CheckRspMessage::ToProtobuf(std::string *output) {
queue::protobuf::StreamingQueueCheckRspMsg msg;
msg.set_src_actor_id(actor_id_.Binary());
msg.set_dst_actor_id(peer_actor_id_.Binary());
msg.set_queue_id(queue_id_.Binary());
msg.set_err_code(err_code_);
msg.SerializeToString(output);
}
std::shared_ptr<CheckRspMessage> CheckRspMessage::FromBytes(uint8_t *bytes) {
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
uint64_t *length = (uint64_t *)bytes;
bytes += sizeof(uint64_t);
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
queue::protobuf::StreamingQueueCheckRspMsg message;
message.ParseFromString(inputpb);
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
queue::protobuf::StreamingQueueError err_code = message.err_code();
std::shared_ptr<CheckRspMessage> check_rsp_msg =
std::make_shared<CheckRspMessage>(src_actor_id, dst_actor_id, queue_id, err_code);
return check_rsp_msg;
}
void TestInitMessage::ToProtobuf(std::string *output) {
queue::protobuf::StreamingQueueTestInitMsg msg;
msg.set_role(role_);
msg.set_src_actor_id(actor_id_.Binary());
msg.set_dst_actor_id(peer_actor_id_.Binary());
msg.set_actor_handle(actor_handle_serialized_);
for (auto &queue_id : queue_ids_) {
msg.add_queue_ids(queue_id.Binary());
}
for (auto &queue_id : rescale_queue_ids_) {
msg.add_rescale_queue_ids(queue_id.Binary());
}
msg.set_test_suite_name(test_suite_name_);
msg.set_test_name(test_name_);
msg.set_param(param_);
msg.SerializeToString(output);
}
std::shared_ptr<TestInitMessage> TestInitMessage::FromBytes(uint8_t *bytes) {
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
uint64_t *length = (uint64_t *)bytes;
bytes += sizeof(uint64_t);
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
queue::protobuf::StreamingQueueTestInitMsg message;
message.ParseFromString(inputpb);
queue::protobuf::StreamingQueueTestRole role = message.role();
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
std::string actor_handle_serialized = message.actor_handle();
std::vector<ObjectID> queue_ids;
for (int i = 0; i < message.queue_ids_size(); i++) {
queue_ids.push_back(ObjectID::FromBinary(message.queue_ids(i)));
}
std::vector<ObjectID> rescale_queue_ids;
for (int i = 0; i < message.rescale_queue_ids_size(); i++) {
rescale_queue_ids.push_back(ObjectID::FromBinary(message.rescale_queue_ids(i)));
}
std::string test_suite_name = message.test_suite_name();
std::string test_name = message.test_name();
uint64_t param = message.param();
std::shared_ptr<TestInitMessage> test_init_msg = std::make_shared<TestInitMessage>(
role, src_actor_id, dst_actor_id, actor_handle_serialized, queue_ids,
rescale_queue_ids, test_suite_name, test_name, param);
return test_init_msg;
}
void TestCheckStatusRspMsg::ToProtobuf(std::string *output) {
queue::protobuf::StreamingQueueTestCheckStatusRspMsg msg;
msg.set_test_name(test_name_);
msg.set_status(status_);
msg.SerializeToString(output);
}
std::shared_ptr<TestCheckStatusRspMsg> TestCheckStatusRspMsg::FromBytes(uint8_t *bytes) {
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
uint64_t *length = (uint64_t *)bytes;
bytes += sizeof(uint64_t);
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
queue::protobuf::StreamingQueueTestCheckStatusRspMsg message;
message.ParseFromString(inputpb);
std::string test_name = message.test_name();
bool status = message.status();
std::shared_ptr<TestCheckStatusRspMsg> test_check_msg =
std::make_shared<TestCheckStatusRspMsg>(test_name, status);
return test_check_msg;
}
} // namespace streaming
} // namespace ray

View file

@@ -0,0 +1,235 @@
#ifndef _STREAMING_QUEUE_MESSAGE_H_
#define _STREAMING_QUEUE_MESSAGE_H_
#include "protobuf/streaming_queue.pb.h"
#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
/// Base class of all message classes.
/// All payloads transferred through direct actor calls are packed into a unified
/// package, consisting of protobuf-formatted metadata plus optional raw data; this
/// covers both data and control messages. The message classes below wrap the
/// corresponding packages defined in protobuf/streaming_queue.proto.
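///
/// The package produced by ToBytes() is laid out as (see Message::ToBytes):
/// +--------------------+
/// | MagicNum=U32 |
/// +--------------------+
/// | MessageType=enum |
/// +--------------------+
/// | PbLength=I64 |
/// +--------------------+
/// | PbBytes=var |
/// +--------------------+
/// | Buffer=var(optional)|
/// +--------------------+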
class Message {
public:
/// Construct a Message instance.
/// \param[in] actor_id ActorID of message sender.
/// \param[in] peer_actor_id ActorID of message receiver.
/// \param[in] queue_id queue id to identify which queue the message is sent to.
/// \param[in] buffer an optional param, a chunk of data to send.
Message(const ActorID &actor_id, const ActorID &peer_actor_id, const ObjectID &queue_id,
std::shared_ptr<LocalMemoryBuffer> buffer = nullptr)
: actor_id_(actor_id),
peer_actor_id_(peer_actor_id),
queue_id_(queue_id),
buffer_(buffer) {}
Message() {}
virtual ~Message() {}
ActorID ActorId() { return actor_id_; }
ActorID PeerActorId() { return peer_actor_id_; }
ObjectID QueueId() { return queue_id_; }
std::shared_ptr<LocalMemoryBuffer> Buffer() { return buffer_; }
/// Serialize all metadata and data into a LocalMemoryBuffer, which can be sent through
/// a direct actor call.
/// \return the serialized buffer.
std::unique_ptr<LocalMemoryBuffer> ToBytes();
/// Get message type.
/// \return message type.
virtual queue::protobuf::StreamingQueueMessageType Type() = 0;
/// All subclasses should implement `ToProtobuf` to serialize its own protobuf data.
virtual void ToProtobuf(std::string *output) = 0;
protected:
ActorID actor_id_;
ActorID peer_actor_id_;
ObjectID queue_id_;
std::shared_ptr<LocalMemoryBuffer> buffer_;
public:
/// A magic number to identify a valid message.
static const uint32_t MagicNum;
};
/// Wrap StreamingQueueDataMsg in streaming_queue.proto.
/// DataMessage encapsulates the memory buffer of a QueueItem; there is a one-to-one
/// relationship between DataMessage and QueueItem.
class DataMessage : public Message {
public:
DataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id,
uint64_t seq_id, std::shared_ptr<LocalMemoryBuffer> buffer, bool raw)
: Message(actor_id, peer_actor_id, queue_id, buffer), seq_id_(seq_id), raw_(raw) {}
virtual ~DataMessage() {}
static std::shared_ptr<DataMessage> FromBytes(uint8_t *bytes);
virtual void ToProtobuf(std::string *output);
uint64_t SeqId() { return seq_id_; }
bool IsRaw() { return raw_; }
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
private:
uint64_t seq_id_;
bool raw_;
const queue::protobuf::StreamingQueueMessageType type_ =
queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType;
};
/// Wrap StreamingQueueNotificationMsg in streaming_queue.proto.
/// NotificationMessage is sent from downstream queues to upstream queues, allowing the
/// data reader to inform the data writer of the consumed offset.
class NotificationMessage : public Message {
public:
NotificationMessage(const ActorID &actor_id, const ActorID &peer_actor_id,
const ObjectID &queue_id, uint64_t seq_id)
: Message(actor_id, peer_actor_id, queue_id), seq_id_(seq_id) {}
virtual ~NotificationMessage() {}
static std::shared_ptr<NotificationMessage> FromBytes(uint8_t *bytes);
virtual void ToProtobuf(std::string *output);
uint64_t SeqId() { return seq_id_; }
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
private:
uint64_t seq_id_;
const queue::protobuf::StreamingQueueMessageType type_ =
queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType;
};
/// Wrap StreamingQueueCheckMsg in streaming_queue.proto.
/// CheckMessage is sent from upstream queues to downstream queues, allowing the data
/// writer to check whether the corresponding downstream queue is ready or not.
class CheckMessage : public Message {
public:
CheckMessage(const ActorID &actor_id, const ActorID &peer_actor_id,
const ObjectID &queue_id)
: Message(actor_id, peer_actor_id, queue_id) {}
virtual ~CheckMessage() {}
static std::shared_ptr<CheckMessage> FromBytes(uint8_t *bytes);
virtual void ToProtobuf(std::string *output);
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
private:
const queue::protobuf::StreamingQueueMessageType type_ =
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType;
};
/// Wrap StreamingQueueCheckRspMsg in streaming_queue.proto.
/// CheckRspMessage is sent from downstream queues to upstream queues as the response to
/// CheckMessage, indicating whether the downstream queue is ready or not.
class CheckRspMessage : public Message {
public:
CheckRspMessage(const ActorID &actor_id, const ActorID &peer_actor_id,
const ObjectID &queue_id, queue::protobuf::StreamingQueueError err_code)
: Message(actor_id, peer_actor_id, queue_id), err_code_(err_code) {}
virtual ~CheckRspMessage() {}
static std::shared_ptr<CheckRspMessage> FromBytes(uint8_t *bytes);
virtual void ToProtobuf(std::string *output);
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
queue::protobuf::StreamingQueueError Error() { return err_code_; }
private:
queue::protobuf::StreamingQueueError err_code_;
const queue::protobuf::StreamingQueueMessageType type_ =
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType;
};
/// Wrap StreamingQueueTestInitMsg in streaming_queue.proto.
/// TestInitMessage is used for tests; the driver sends it to test workers to initialize a
/// test suite.
class TestInitMessage : public Message {
public:
TestInitMessage(const queue::protobuf::StreamingQueueTestRole role,
const ActorID &actor_id, const ActorID &peer_actor_id,
const std::string actor_handle_serialized,
const std::vector<ObjectID> &queue_ids,
const std::vector<ObjectID> &rescale_queue_ids,
std::string test_suite_name, std::string test_name, uint64_t param)
: Message(actor_id, peer_actor_id, queue_ids[0]),
actor_handle_serialized_(actor_handle_serialized),
queue_ids_(queue_ids),
rescale_queue_ids_(rescale_queue_ids),
role_(role),
test_suite_name_(test_suite_name),
test_name_(test_name),
param_(param) {}
virtual ~TestInitMessage() {}
static std::shared_ptr<TestInitMessage> FromBytes(uint8_t *bytes);
virtual void ToProtobuf(std::string *output);
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
std::string ActorHandleSerialized() { return actor_handle_serialized_; }
queue::protobuf::StreamingQueueTestRole Role() { return role_; }
std::vector<ObjectID> QueueIds() { return queue_ids_; }
std::vector<ObjectID> RescaleQueueIds() { return rescale_queue_ids_; }
std::string TestSuiteName() { return test_suite_name_; }
std::string TestName() { return test_name_; }
uint64_t Param() { return param_; }
std::string ToString() {
std::ostringstream os;
os << "actor_handle_serialized: " << actor_handle_serialized_;
os << " actor_id: " << ActorId();
os << " peer_actor_id: " << PeerActorId();
os << " queue_ids:[";
for (auto &qid : queue_ids_) {
os << qid << ",";
}
os << "], rescale_queue_ids:[";
for (auto &qid : rescale_queue_ids_) {
os << qid << ",";
}
os << "],";
os << " role:" << queue::protobuf::StreamingQueueTestRole_Name(role_);
os << " suite_name: " << test_suite_name_;
os << " test_name: " << test_name_;
os << " param: " << param_;
return os.str();
}
private:
const queue::protobuf::StreamingQueueMessageType type_ =
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType;
std::string actor_handle_serialized_;
std::vector<ObjectID> queue_ids_;
std::vector<ObjectID> rescale_queue_ids_;
queue::protobuf::StreamingQueueTestRole role_;
std::string test_suite_name_;
std::string test_name_;
uint64_t param_;
};
/// Wrap StreamingQueueTestCheckStatusRspMsg in streaming_queue.proto.
/// TestCheckStatusRspMsg is used for tests; the driver sends it to test workers to check
/// whether a test has completed or failed.
class TestCheckStatusRspMsg : public Message {
public:
TestCheckStatusRspMsg(const std::string test_name, bool status)
: test_name_(test_name), status_(status) {}
virtual ~TestCheckStatusRspMsg() {}
static std::shared_ptr<TestCheckStatusRspMsg> FromBytes(uint8_t *bytes);
virtual void ToProtobuf(std::string *output);
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
std::string TestName() { return test_name_; }
bool Status() { return status_; }
private:
const queue::protobuf::StreamingQueueMessageType type_ =
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestCheckStatusRspMsgType;
std::string test_name_;
bool status_;
};
} // namespace streaming
} // namespace ray
#endif

View file

@@ -0,0 +1,211 @@
#include "queue.h"
#include <chrono>
#include <thread>
#include "queue_handler.h"
#include "util/streaming_util.h"
namespace ray {
namespace streaming {
bool Queue::Push(QueueItem item) {
std::unique_lock<std::mutex> lock(mutex_);
if (max_data_size_ < item.DataSize() + data_size_) return false;
buffer_queue_.push_back(item);
data_size_ += item.DataSize();
readable_cv_.notify_one();
return true;
}
QueueItem Queue::FrontProcessed() {
std::unique_lock<std::mutex> lock(mutex_);
STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail";
if (watershed_iter_ == buffer_queue_.begin()) {
return InvalidQueueItem();
}
QueueItem item = buffer_queue_.front();
return item;
}
QueueItem Queue::PopProcessed() {
std::unique_lock<std::mutex> lock(mutex_);
STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail";
if (watershed_iter_ == buffer_queue_.begin()) {
return InvalidQueueItem();
}
QueueItem item = buffer_queue_.front();
buffer_queue_.pop_front();
data_size_ -= item.DataSize();
data_size_sent_ -= item.DataSize();
return item;
}
QueueItem Queue::PopPending() {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::next(watershed_iter_);
QueueItem item = *it;
data_size_sent_ += it->DataSize();
buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it));
return item;
}
QueueItem Queue::PopPendingBlockTimeout(uint64_t timeout_us) {
std::unique_lock<std::mutex> lock(mutex_);
std::chrono::system_clock::time_point point =
std::chrono::system_clock::now() + std::chrono::microseconds(timeout_us);
if (readable_cv_.wait_until(lock, point, [this] {
return std::next(watershed_iter_) != buffer_queue_.end();
})) {
auto it = std::next(watershed_iter_);
QueueItem item = *it;
data_size_sent_ += it->DataSize();
buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it));
return item;
} else {
uint8_t data[1];
return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true);
}
}
QueueItem Queue::BackPending() {
std::unique_lock<std::mutex> lock(mutex_);
if (std::next(watershed_iter_) == buffer_queue_.end()) {
uint8_t data[1];
return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true);
}
return buffer_queue_.back();
}
bool Queue::IsPendingEmpty() {
std::unique_lock<std::mutex> lock(mutex_);
return std::next(watershed_iter_) == buffer_queue_.end();
}
bool Queue::IsPendingFull(uint64_t data_size) {
std::unique_lock<std::mutex> lock(mutex_);
return max_data_size_ < data_size + data_size_;
}
size_t Queue::ProcessedCount() {
std::unique_lock<std::mutex> lock(mutex_);
if (watershed_iter_ == buffer_queue_.begin()) return 0;
auto begin = buffer_queue_.begin();
auto end = std::prev(watershed_iter_);
return end->SeqId() + 1 - begin->SeqId();
}
size_t Queue::PendingCount() {
std::unique_lock<std::mutex> lock(mutex_);
if (std::next(watershed_iter_) == buffer_queue_.end()) return 0;
auto begin = std::next(watershed_iter_);
auto end = std::prev(buffer_queue_.end());
return end->SeqId() - begin->SeqId() + 1;
}
Status WriterQueue::Push(uint64_t seq_id, uint8_t *data, uint32_t data_size,
uint64_t timestamp, bool raw) {
if (IsPendingFull(data_size)) {
return Status::OutOfMemory("Queue Push OutOfMemory");
}
while (is_pulling_) {
STREAMING_LOG(INFO) << "This queue is sending pull data, wait.";
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
QueueItem item(seq_id, data, data_size, timestamp, raw);
Queue::Push(item);
return Status::OK();
}
void WriterQueue::Send() {
while (!IsPendingEmpty()) {
// FIXME: front -> send -> pop
QueueItem item = PopPending();
DataMessage msg(actor_id_, peer_actor_id_, queue_id_, item.SeqId(), item.Buffer(),
item.IsRaw());
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
STREAMING_CHECK(transport_ != nullptr);
transport_->Send(std::move(buffer),
DownstreamQueueMessageHandler::peer_async_function_);
}
}
Status WriterQueue::TryEvictItems() {
STREAMING_LOG(INFO) << "TryEvictItems";
QueueItem item = FrontProcessed();
uint64_t first_seq_id = item.SeqId();
STREAMING_LOG(INFO) << "TryEvictItems first_seq_id: " << first_seq_id
<< " min_consumed_id_: " << min_consumed_id_
<< " eviction_limit_: " << eviction_limit_;
if (min_consumed_id_ == QUEUE_INVALID_SEQ_ID || first_seq_id > min_consumed_id_) {
return Status::OutOfMemory("The queue is full and some reader doesn't consume");
}
if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || first_seq_id > eviction_limit_) {
return Status::OutOfMemory("The queue is full and eviction limit block evict");
}
uint64_t evict_target_seq_id = std::min(min_consumed_id_, eviction_limit_);
while (item.SeqId() <= evict_target_seq_id) {
PopProcessed();
STREAMING_LOG(INFO) << "TryEvictItems directly " << item.SeqId();
item = FrontProcessed();
}
return Status::OK();
}
void WriterQueue::OnNotify(std::shared_ptr<NotificationMessage> notify_msg) {
STREAMING_LOG(INFO) << "OnNotify target seq_id: " << notify_msg->SeqId();
min_consumed_id_ = notify_msg->SeqId();
}
void ReaderQueue::OnConsumed(uint64_t seq_id) {
STREAMING_LOG(INFO) << "OnConsumed: " << seq_id;
QueueItem item = FrontProcessed();
while (item.SeqId() <= seq_id) {
PopProcessed();
item = FrontProcessed();
}
Notify(seq_id);
}
void ReaderQueue::Notify(uint64_t seq_id) {
std::vector<TaskArg> task_args;
CreateNotifyTask(seq_id, task_args);
// SubmitActorTask
NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id);
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
transport_->Send(std::move(buffer), UpstreamQueueMessageHandler::peer_async_function_);
}
void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector<TaskArg> &task_args) {}
void ReaderQueue::OnData(QueueItem &item) {
if (item.SeqId() != expect_seq_id_) {
STREAMING_LOG(WARNING) << "OnData ignore seq_id: " << item.SeqId()
<< " expect_seq_id_: " << expect_seq_id_;
return;
}
last_recv_seq_id_ = item.SeqId();
STREAMING_LOG(DEBUG) << "ReaderQueue::OnData seq_id: " << last_recv_seq_id_;
Push(item);
expect_seq_id_++;
}
} // namespace streaming
} // namespace ray

213
streaming/src/queue/queue.h Normal file

@ -0,0 +1,213 @@
#ifndef _STREAMING_QUEUE_H_
#define _STREAMING_QUEUE_H_
#include <iterator>
#include <list>
#include <vector>
#include "ray/common/id.h"
#include "ray/util/util.h"
#include "queue_item.h"
#include "transport.h"
#include "util/streaming_logging.h"
#include "utils.h"
namespace ray {
namespace streaming {
using ray::ObjectID;
enum QueueType { UPSTREAM = 0, DOWNSTREAM };
/// A queue-like data structure, which does not delete its items after they are popped.
/// The lifecycle of each item is:
/// - Pending, an item has been pushed into the queue but not yet processed (sent out or
///   consumed).
/// - Processed, the item has been handled by the user, but should not be deleted yet.
/// - Evicted, the item is useless to the user and should be popped and destroyed.
/// At present, this data structure is implemented with a single std::list,
/// using a watershed iterator to divide the processed and pending parts.
class Queue {
public:
/// \param[in] queue_id the unique identification of a pair of queues (upstream and
/// downstream).
/// \param[in] size max size of the queue in bytes.
/// \param[in] transport transport to send items to the peer.
Queue(ObjectID queue_id, uint64_t size, std::shared_ptr<Transport> transport)
: queue_id_(queue_id), max_data_size_(size), data_size_(0), data_size_sent_(0) {
buffer_queue_.push_back(InvalidQueueItem());
watershed_iter_ = buffer_queue_.begin();
}
virtual ~Queue() {}
/// Push an item into the queue.
/// \param[in] item the QueueItem object to be send to peer.
/// \return false if the queue is full.
bool Push(QueueItem item);
/// Get the front item in processed state.
QueueItem FrontProcessed();
/// Pop the front item in processed state.
QueueItem PopProcessed();
/// Pop the front item in pending state; the item
/// is not evicted at this moment, its state turns to
/// processed.
QueueItem PopPending();
/// PopPending with timeout in microseconds.
QueueItem PopPendingBlockTimeout(uint64_t timeout_us);
/// Return the last item in pending state.
QueueItem BackPending();
bool IsPendingEmpty();
bool IsPendingFull(uint64_t data_size = 0);
/// Return the size in bytes of all items in queue.
uint64_t QueueSize() { return data_size_; }
/// Return the size in bytes of all items in pending state.
uint64_t PendingDataSize() { return data_size_ - data_size_sent_; }
/// Return the size in bytes of all items in processed state.
uint64_t ProcessedDataSize() { return data_size_sent_; }
/// Return item count of the queue.
size_t Count() { return buffer_queue_.size(); }
/// Return item count in pending state.
size_t PendingCount();
/// Return item count in processed state.
size_t ProcessedCount();
protected:
ObjectID queue_id_;
std::list<QueueItem> buffer_queue_;
std::list<QueueItem>::iterator watershed_iter_;
/// max data size in bytes
uint64_t max_data_size_;
uint64_t data_size_;
uint64_t data_size_sent_;
std::mutex mutex_;
std::condition_variable readable_cv_;
};
/// Queue in upstream.
class WriterQueue : public Queue {
public:
/// \param queue_id, the unique ObjectID to identify a queue
/// \param actor_id, the actor id of upstream worker
/// \param peer_actor_id, the actor id of downstream worker
/// \param size, max data size in bytes
/// \param transport, transport
WriterQueue(const ObjectID &queue_id, const ActorID &actor_id,
const ActorID &peer_actor_id, uint64_t size,
std::shared_ptr<Transport> transport)
: Queue(queue_id, size, transport),
actor_id_(actor_id),
peer_actor_id_(peer_actor_id),
eviction_limit_(QUEUE_INVALID_SEQ_ID),
min_consumed_id_(QUEUE_INVALID_SEQ_ID),
peer_last_msg_id_(0),
peer_last_seq_id_(QUEUE_INVALID_SEQ_ID),
transport_(transport),
is_pulling_(false) {}
/// Push a contiguous buffer into the queue.
/// NOTE: the buffer is copied internally.
Status Push(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp,
bool raw = false);
/// Callback function, will be called when downstream queue notifies
/// it has consumed some items.
/// NOTE: this callback function is called in queue thread.
void OnNotify(std::shared_ptr<NotificationMessage> notify_msg);
/// Send items through direct call.
void Send();
/// Called when the user pushes an item into the queue. The number of items that
/// can be evicted is determined by eviction_limit_ and min_consumed_id_.
Status TryEvictItems();
void SetQueueEvictionLimit(uint64_t eviction_limit) {
eviction_limit_ = eviction_limit;
}
uint64_t EvictionLimit() { return eviction_limit_; }
uint64_t GetMinConsumedSeqID() { return min_consumed_id_; }
void SetPeerLastIds(uint64_t msg_id, uint64_t seq_id) {
peer_last_msg_id_ = msg_id;
peer_last_seq_id_ = seq_id;
}
uint64_t GetPeerLastMsgId() { return peer_last_msg_id_; }
uint64_t GetPeerLastSeqId() { return peer_last_seq_id_; }
private:
ActorID actor_id_;
ActorID peer_actor_id_;
uint64_t eviction_limit_;
uint64_t min_consumed_id_;
uint64_t peer_last_msg_id_;
uint64_t peer_last_seq_id_;
std::shared_ptr<Transport> transport_;
std::atomic<bool> is_pulling_;
};
/// Queue in downstream.
class ReaderQueue : public Queue {
public:
/// \param queue_id, the unique ObjectID to identify a queue
/// \param actor_id, the actor id of upstream worker
/// \param peer_actor_id, the actor id of downstream worker
/// \param transport, transport
/// NOTE: we do not restrict queue size of ReaderQueue
ReaderQueue(const ObjectID &queue_id, const ActorID &actor_id,
const ActorID &peer_actor_id, std::shared_ptr<Transport> transport)
: Queue(queue_id, std::numeric_limits<uint64_t>::max(), transport),
actor_id_(actor_id),
peer_actor_id_(peer_actor_id),
min_consumed_id_(QUEUE_INVALID_SEQ_ID),
last_recv_seq_id_(QUEUE_INVALID_SEQ_ID),
expect_seq_id_(1),
transport_(transport) {}
/// Delete processed items whose seq id <= seq_id,
/// then notify upstream queue.
void OnConsumed(uint64_t seq_id);
void OnData(QueueItem &item);
uint64_t GetMinConsumedSeqID() { return min_consumed_id_; }
uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; }
void SetExpectSeqId(uint64_t expect) { expect_seq_id_ = expect; }
private:
void Notify(uint64_t seq_id);
void CreateNotifyTask(uint64_t seq_id, std::vector<TaskArg> &task_args);
private:
ActorID actor_id_;
ActorID peer_actor_id_;
uint64_t min_consumed_id_;
uint64_t last_recv_seq_id_;
uint64_t expect_seq_id_;
std::shared_ptr<PromiseWrapper> promise_for_pull_;
std::shared_ptr<Transport> transport_;
};
} // namespace streaming
} // namespace ray
#endif
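A minimal writer-side sketch of the pending -> processed -> evicted lifecycle described above. The queue id, actor ids, transport, and size are assumed to be provided by the caller, and the peer ray call functions are assumed to be registered already (normally via WriterClient); only APIs declared in this header are used.
void WriterQueueLifecycleSketch(const ObjectID &queue_id, const ActorID &actor_id,
                                const ActorID &peer_actor_id,
                                std::shared_ptr<Transport> transport) {
  WriterQueue queue(queue_id, actor_id, peer_actor_id, /*size=*/1024 * 1024, transport);
  uint8_t payload[4] = {1, 2, 3, 4};
  // The item enters the queue in pending state; the payload is copied internally.
  Status st = queue.Push(/*seq_id=*/1, payload, sizeof(payload), /*timestamp=*/0);
  if (!st.ok()) {
    return;  // queue is full
  }
  // Send() pops pending items, sends them to the peer and marks them processed.
  queue.Send();
  // Eviction only succeeds after OnNotify() has recorded the reader's progress.
  queue.SetQueueEvictionLimit(1);
  queue.TryEvictItems();
}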


@ -0,0 +1,25 @@
#include "queue_client.h"
namespace ray {
namespace streaming {
void WriterClient::OnWriterMessage(std::shared_ptr<LocalMemoryBuffer> buffer) {
upstream_handler_->DispatchMessageAsync(buffer);
}
std::shared_ptr<LocalMemoryBuffer> WriterClient::OnWriterMessageSync(
std::shared_ptr<LocalMemoryBuffer> buffer) {
return upstream_handler_->DispatchMessageSync(buffer);
}
void ReaderClient::OnReaderMessage(std::shared_ptr<LocalMemoryBuffer> buffer) {
downstream_handler_->DispatchMessageAsync(buffer);
}
std::shared_ptr<LocalMemoryBuffer> ReaderClient::OnReaderMessageSync(
std::shared_ptr<LocalMemoryBuffer> buffer) {
return downstream_handler_->DispatchMessageSync(buffer);
}
} // namespace streaming
} // namespace ray


@ -0,0 +1,62 @@
#ifndef _STREAMING_QUEUE_CLIENT_H_
#define _STREAMING_QUEUE_CLIENT_H_
#include "queue_handler.h"
#include "transport.h"
namespace ray {
namespace streaming {
/// The interface of the streaming queue for DataReader.
/// A ReaderClient should be created before the DataReader is created in Cython/JNI, and
/// is held by the JobWorker. When the DataReader receives a buffer from an upstream
/// DataWriter (the DataReader's ray call function is invoked), it calls `OnReaderMessage`
/// to pass the buffer to its own downstream queue, or `OnReaderMessageSync` to wait for
/// the handling result.
class ReaderClient {
public:
/// Construct a ReaderClient object.
/// \param[in] core_worker CoreWorker C++ pointer of current actor
/// \param[in] async_func DataReader's ray call function descriptor to be called by
/// DataWriter, asynchronous semantics.
/// \param[in] sync_func DataReader's ray call function descriptor to be called by
/// DataWriter, synchronous semantics.
ReaderClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
: core_worker_(core_worker) {
DownstreamQueueMessageHandler::peer_async_function_ = async_func;
DownstreamQueueMessageHandler::peer_sync_function_ = sync_func;
downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService(
core_worker_, core_worker_->GetWorkerContext().GetCurrentActorID());
}
/// Post buffer to downstream queue service, asynchronously.
void OnReaderMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
/// Post buffer to downstream queue service, synchronously.
/// \return handle result.
std::shared_ptr<LocalMemoryBuffer> OnReaderMessageSync(
std::shared_ptr<LocalMemoryBuffer> buffer);
private:
CoreWorker *core_worker_;
std::shared_ptr<DownstreamQueueMessageHandler> downstream_handler_;
};
/// Interface of streaming queue for DataWriter. Similar to ReaderClient.
class WriterClient {
public:
WriterClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
: core_worker_(core_worker) {
UpstreamQueueMessageHandler::peer_async_function_ = async_func;
UpstreamQueueMessageHandler::peer_sync_function_ = sync_func;
upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService(
core_worker, core_worker_->GetWorkerContext().GetCurrentActorID());
}
void OnWriterMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
std::shared_ptr<LocalMemoryBuffer> OnWriterMessageSync(
std::shared_ptr<LocalMemoryBuffer> buffer);
private:
CoreWorker *core_worker_;
std::shared_ptr<UpstreamQueueMessageHandler> upstream_handler_;
};
} // namespace streaming
} // namespace ray
#endif
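A rough sketch of the glue code the comments above describe, e.g. in the Cython/JNI layer. `reader_client` is assumed to have been created once for the worker; the function names are hypothetical stand-ins for the DataReader's ray call entry points.
// Asynchronous entry point: hand the buffer to the downstream queue service and return.
void OnStreamingTransferSketch(ReaderClient *reader_client,
                               std::shared_ptr<LocalMemoryBuffer> buffer) {
  reader_client->OnReaderMessage(buffer);
}

// Synchronous entry point: block until the queue service produces a reply buffer.
std::shared_ptr<LocalMemoryBuffer> OnStreamingTransferSyncSketch(
    ReaderClient *reader_client, std::shared_ptr<LocalMemoryBuffer> buffer) {
  return reader_client->OnReaderMessageSync(buffer);
}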


@ -0,0 +1,358 @@
#include "queue_handler.h"
#include "util/streaming_util.h"
#include "utils.h"
namespace ray {
namespace streaming {
constexpr uint64_t COMMON_SYNC_CALL_TIMEOUT_MS = 5 * 1000;
std::shared_ptr<UpstreamQueueMessageHandler>
UpstreamQueueMessageHandler::upstream_handler_ = nullptr;
std::shared_ptr<DownstreamQueueMessageHandler>
DownstreamQueueMessageHandler::downstream_handler_ = nullptr;
RayFunction UpstreamQueueMessageHandler::peer_sync_function_;
RayFunction UpstreamQueueMessageHandler::peer_async_function_;
RayFunction DownstreamQueueMessageHandler::peer_sync_function_;
RayFunction DownstreamQueueMessageHandler::peer_async_function_;
std::shared_ptr<Message> QueueMessageHandler::ParseMessage(
std::shared_ptr<LocalMemoryBuffer> buffer) {
uint8_t *bytes = buffer->Data();
uint8_t *p_cur = bytes;
uint32_t *magic_num = (uint32_t *)p_cur;
STREAMING_CHECK(*magic_num == Message::MagicNum)
<< *magic_num << " " << Message::MagicNum;
p_cur += sizeof(Message::MagicNum);
queue::protobuf::StreamingQueueMessageType *type =
(queue::protobuf::StreamingQueueMessageType *)p_cur;
std::shared_ptr<Message> message = nullptr;
switch (*type) {
case queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType:
message = NotificationMessage::FromBytes(bytes);
break;
case queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType:
message = DataMessage::FromBytes(bytes);
break;
case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType:
message = CheckMessage::FromBytes(bytes);
break;
case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType:
message = CheckRspMessage::FromBytes(bytes);
break;
default:
STREAMING_CHECK(false) << "nonsupport message type: "
<< queue::protobuf::StreamingQueueMessageType_Name(*type);
break;
}
return message;
}
void QueueMessageHandler::DispatchMessageAsync(
std::shared_ptr<LocalMemoryBuffer> buffer) {
queue_service_.post(
boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer, nullptr));
}
std::shared_ptr<LocalMemoryBuffer> QueueMessageHandler::DispatchMessageSync(
std::shared_ptr<LocalMemoryBuffer> buffer) {
std::shared_ptr<LocalMemoryBuffer> result = nullptr;
std::shared_ptr<PromiseWrapper> promise = std::make_shared<PromiseWrapper>();
queue_service_.post(
boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer,
[&promise, &result](std::shared_ptr<LocalMemoryBuffer> rst) {
result = rst;
promise->Notify(ray::Status::OK());
}));
Status st = promise->Wait();
STREAMING_CHECK(st.ok());
return result;
}
std::shared_ptr<Transport> QueueMessageHandler::GetOutTransport(
const ObjectID &queue_id) {
auto it = out_transports_.find(queue_id);
if (it == out_transports_.end()) return nullptr;
return it->second;
}
void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id,
const ActorID &actor_id) {
actors_.emplace(queue_id, actor_id);
out_transports_.emplace(
queue_id, std::make_shared<ray::streaming::Transport>(core_worker_, actor_id));
}
ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) {
auto it = actors_.find(queue_id);
STREAMING_CHECK(it != actors_.end());
return it->second;
}
void QueueMessageHandler::Release() {
actors_.clear();
out_transports_.clear();
}
void QueueMessageHandler::Start() {
queue_thread_ = std::thread(&QueueMessageHandler::QueueThreadCallback, this);
}
void QueueMessageHandler::Stop() {
STREAMING_LOG(INFO) << "QueueMessageHandler Stop.";
queue_service_.stop();
if (queue_thread_.joinable()) {
queue_thread_.join();
}
}
std::shared_ptr<UpstreamQueueMessageHandler> UpstreamQueueMessageHandler::CreateService(
CoreWorker *core_worker, const ActorID &actor_id) {
if (nullptr == upstream_handler_) {
upstream_handler_ =
std::make_shared<UpstreamQueueMessageHandler>(core_worker, actor_id);
}
return upstream_handler_;
}
std::shared_ptr<UpstreamQueueMessageHandler> UpstreamQueueMessageHandler::GetService() {
return upstream_handler_;
}
std::shared_ptr<WriterQueue> UpstreamQueueMessageHandler::CreateUpstreamQueue(
const ObjectID &queue_id, const ActorID &peer_actor_id, uint64_t size) {
STREAMING_LOG(INFO) << "CreateUpstreamQueue: " << queue_id << " " << actor_id_ << "->"
<< peer_actor_id;
std::shared_ptr<WriterQueue> queue = GetUpQueue(queue_id);
if (queue != nullptr) {
STREAMING_LOG(WARNING) << "Duplicate to create up queue." << queue_id;
return queue;
}
queue = std::unique_ptr<streaming::WriterQueue>(new streaming::WriterQueue(
queue_id, actor_id_, peer_actor_id, size, GetOutTransport(queue_id)));
upstream_queues_[queue_id] = queue;
return queue;
}
bool UpstreamQueueMessageHandler::UpstreamQueueExists(const ObjectID &queue_id) {
return nullptr != GetUpQueue(queue_id);
}
std::shared_ptr<streaming::WriterQueue> UpstreamQueueMessageHandler::GetUpQueue(
const ObjectID &queue_id) {
auto it = upstream_queues_.find(queue_id);
if (it == upstream_queues_.end()) return nullptr;
return it->second;
}
bool UpstreamQueueMessageHandler::CheckQueueSync(const ObjectID &queue_id) {
ActorID peer_actor_id = GetPeerActorID(queue_id);
STREAMING_LOG(INFO) << "CheckQueueSync queue_id: " << queue_id
<< " peer_actor_id: " << peer_actor_id;
CheckMessage msg(actor_id_, peer_actor_id, queue_id);
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
auto transport_it = GetOutTransport(queue_id);
STREAMING_CHECK(transport_it != nullptr);
std::shared_ptr<LocalMemoryBuffer> result_buffer = transport_it->SendForResultWithRetry(
std::move(buffer), DownstreamQueueMessageHandler::peer_sync_function_, 10,
COMMON_SYNC_CALL_TIMEOUT_MS);
if (result_buffer == nullptr) {
return false;
}
std::shared_ptr<Message> result_msg = ParseMessage(result_buffer);
STREAMING_CHECK(
result_msg->Type() ==
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType);
std::shared_ptr<CheckRspMessage> check_rsp_msg =
std::dynamic_pointer_cast<CheckRspMessage>(result_msg);
STREAMING_LOG(INFO) << "CheckQueueSync return queue_id: " << check_rsp_msg->QueueId();
STREAMING_CHECK(check_rsp_msg->PeerActorId() == actor_id_);
return queue::protobuf::StreamingQueueError::OK == check_rsp_msg->Error();
}
void UpstreamQueueMessageHandler::WaitQueues(const std::vector<ObjectID> &queue_ids,
int64_t timeout_ms,
std::vector<ObjectID> &failed_queues) {
failed_queues.insert(failed_queues.begin(), queue_ids.begin(), queue_ids.end());
uint64_t start_time_ms = current_time_ms();
uint64_t current_time = start_time_ms;
while (!failed_queues.empty() && current_time < start_time_ms + static_cast<uint64_t>(timeout_ms)) {
for (auto it = failed_queues.begin(); it != failed_queues.end();) {
if (CheckQueueSync(*it)) {
STREAMING_LOG(INFO) << "Check queue: " << *it << " return, ready.";
it = failed_queues.erase(it);
} else {
STREAMING_LOG(INFO) << "Check queue: " << *it << " return, not ready.";
std::this_thread::sleep_for(std::chrono::milliseconds(50));
it++;
}
}
current_time = current_time_ms();
}
}
void UpstreamQueueMessageHandler::DispatchMessageInternal(
std::shared_ptr<LocalMemoryBuffer> buffer,
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) {
std::shared_ptr<Message> msg = ParseMessage(buffer);
STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: "
<< " qid: " << msg->QueueId() << " actorid " << msg->ActorId()
<< " peer actorid: " << msg->PeerActorId() << " type: "
<< queue::protobuf::StreamingQueueMessageType_Name(msg->Type());
if (msg->Type() ==
queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType) {
OnNotify(std::dynamic_pointer_cast<NotificationMessage>(msg));
} else if (msg->Type() ==
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType) {
STREAMING_CHECK(false) << "Should not receive StreamingQueueCheckRspMsg";
} else {
STREAMING_CHECK(false) << "message type should be added: "
<< queue::protobuf::StreamingQueueMessageType_Name(
msg->Type());
}
}
void UpstreamQueueMessageHandler::OnNotify(
std::shared_ptr<NotificationMessage> notify_msg) {
auto queue = GetUpQueue(notify_msg->QueueId());
if (queue == nullptr) {
STREAMING_LOG(WARNING) << "Can not find queue for "
<< queue::protobuf::StreamingQueueMessageType_Name(
notify_msg->Type())
<< ", maybe queue has been destroyed, ignore it."
<< " seq id: " << notify_msg->SeqId();
return;
}
queue->OnNotify(notify_msg);
}
void UpstreamQueueMessageHandler::ReleaseAllUpQueues() {
STREAMING_LOG(INFO) << "ReleaseAllUpQueues";
upstream_queues_.clear();
Release();
}
std::shared_ptr<DownstreamQueueMessageHandler>
DownstreamQueueMessageHandler::CreateService(CoreWorker *core_worker,
const ActorID &actor_id) {
if (nullptr == downstream_handler_) {
downstream_handler_ =
std::make_shared<DownstreamQueueMessageHandler>(core_worker, actor_id);
}
return downstream_handler_;
}
std::shared_ptr<DownstreamQueueMessageHandler>
DownstreamQueueMessageHandler::GetService() {
return downstream_handler_;
}
bool DownstreamQueueMessageHandler::DownstreamQueueExists(const ObjectID &queue_id) {
return nullptr != GetDownQueue(queue_id);
}
std::shared_ptr<ReaderQueue> DownstreamQueueMessageHandler::CreateDownstreamQueue(
const ObjectID &queue_id, const ActorID &peer_actor_id) {
STREAMING_LOG(INFO) << "CreateDownstreamQueue: " << queue_id << " " << peer_actor_id
<< "->" << actor_id_;
auto it = downstream_queues_.find(queue_id);
if (it != downstream_queues_.end()) {
STREAMING_LOG(WARNING) << "Duplicate to create down queue!!!! " << queue_id;
return it->second;
}
std::shared_ptr<streaming::ReaderQueue> queue =
std::unique_ptr<streaming::ReaderQueue>(new streaming::ReaderQueue(
queue_id, actor_id_, peer_actor_id, GetOutTransport(queue_id)));
downstream_queues_[queue_id] = queue;
return queue;
}
std::shared_ptr<streaming::ReaderQueue> DownstreamQueueMessageHandler::GetDownQueue(
const ObjectID &queue_id) {
auto it = downstream_queues_.find(queue_id);
if (it == downstream_queues_.end()) return nullptr;
return it->second;
}
std::shared_ptr<LocalMemoryBuffer> DownstreamQueueMessageHandler::OnCheckQueue(
std::shared_ptr<CheckMessage> check_msg) {
queue::protobuf::StreamingQueueError err_code =
queue::protobuf::StreamingQueueError::OK;
auto down_queue = downstream_queues_.find(check_msg->QueueId());
if (down_queue == downstream_queues_.end()) {
STREAMING_LOG(WARNING) << "OnCheckQueue " << check_msg->QueueId() << " not found.";
err_code = queue::protobuf::StreamingQueueError::QUEUE_NOT_EXIST;
}
CheckRspMessage msg(check_msg->PeerActorId(), check_msg->ActorId(),
check_msg->QueueId(), err_code);
std::shared_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
return buffer;
}
void DownstreamQueueMessageHandler::ReleaseAllDownQueues() {
STREAMING_LOG(INFO) << "ReleaseAllDownQueues size: " << downstream_queues_.size();
downstream_queues_.clear();
Release();
}
void DownstreamQueueMessageHandler::DispatchMessageInternal(
std::shared_ptr<LocalMemoryBuffer> buffer,
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) {
std::shared_ptr<Message> msg = ParseMessage(buffer);
STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: "
<< " qid: " << msg->QueueId() << " actorid " << msg->ActorId()
<< " peer actorid: " << msg->PeerActorId() << " type: "
<< queue::protobuf::StreamingQueueMessageType_Name(msg->Type());
if (msg->Type() ==
queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType) {
OnData(std::dynamic_pointer_cast<DataMessage>(msg));
} else if (msg->Type() ==
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType) {
std::shared_ptr<LocalMemoryBuffer> check_result =
this->OnCheckQueue(std::dynamic_pointer_cast<CheckMessage>(msg));
if (callback != nullptr) {
callback(check_result);
}
} else {
STREAMING_CHECK(false) << "message type should be added: "
<< queue::protobuf::StreamingQueueMessageType_Name(
msg->Type());
}
}
void DownstreamQueueMessageHandler::OnData(std::shared_ptr<DataMessage> msg) {
auto queue = GetDownQueue(msg->QueueId());
if (queue == nullptr) {
STREAMING_LOG(WARNING) << "Can not find queue for "
<< queue::protobuf::StreamingQueueMessageType_Name(msg->Type())
<< ", maybe queue has been destroyed, ignore it."
<< " seq id: " << msg->SeqId();
return;
}
QueueItem item(msg);
queue->OnData(item);
}
} // namespace streaming
} // namespace ray


@ -0,0 +1,194 @@
#ifndef _QUEUE_SERVICE_H_
#define _QUEUE_SERVICE_H_
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/thread.hpp>
#include <thread>
#include "queue.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
/// Base class of UpstreamQueueMessageHandler and DownstreamQueueMessageHandler.
/// A queue service manages a group of queues, upstream queues or downstream queues of
/// the current actor. Each queue service holds a boost.asio io_service, to handle
/// messages asynchronously. When a message is received by the Writer/Reader in the ray
/// call thread, it is delivered to the
/// UpstreamQueueMessageHandler/DownstreamQueueMessageHandler, and the ray call thread
/// returns immediately. The queue service parses meta information from the message,
/// including queue_id, actor_id, etc., and dispatches the message to the corresponding
/// queue according to queue_id.
class QueueMessageHandler {
public:
/// Construct a QueueMessageHandler instance.
/// \param[in] core_worker CoreWorker C++ pointer of current actor, used to call Core
/// Worker's api.
/// For a Python worker, the pointer can be obtained from
/// ray.worker.global_worker.core_worker; for a Java worker, it is obtained from the
/// RayNativeRuntime object through Java reflection.
/// \param[in] actor_id actor id of current actor.
QueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
: core_worker_(core_worker),
actor_id_(actor_id),
queue_dummy_work_(queue_service_) {
Start();
}
virtual ~QueueMessageHandler() { Stop(); }
/// Dispatch message buffer to asio service.
/// \param[in] buffer serialized message received from peer actor.
void DispatchMessageAsync(std::shared_ptr<LocalMemoryBuffer> buffer);
/// Dispatch message buffer to asio service synchronously, and wait for handle result.
/// \param[in] buffer serialized message received from peer actor.
/// \return handle result.
std::shared_ptr<LocalMemoryBuffer> DispatchMessageSync(
std::shared_ptr<LocalMemoryBuffer> buffer);
/// Get the transport to the peer actor of the queue specified by queue_id.
/// \param[in] queue_id queue id of the queue.
/// \return transport
std::shared_ptr<Transport> GetOutTransport(const ObjectID &queue_id);
/// The actual function where message being dispatched, called by DispatchMessageAsync
/// and DispatchMessageSync.
/// \param[in] buffer serialized message received from peer actor.
/// \param[in] callback the callback function used by DispatchMessageSync, called
/// after message processed complete. The std::shared_ptr<LocalMemoryBuffer>
/// parameter is the return value.
virtual void DispatchMessageInternal(
std::shared_ptr<LocalMemoryBuffer> buffer,
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) = 0;
/// Save the actor_id of the peer actor specified by queue_id. For an upstream queue, the
/// peer actor refers to the actor in the current ray cluster that has a
/// downstream queue with the same queue_id, and vice versa.
/// \param[in] queue_id queue id of the current queue.
/// \param[in] actor_id actor id of the corresponding peer actor.
void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id);
/// Obtain the actor id of the peer actor specified by queue_id.
/// \return actor id
ActorID GetPeerActorID(const ObjectID &queue_id);
/// Release all queues in current queue service.
void Release();
private:
/// Start asio service
void Start();
/// Stop asio service
void Stop();
/// The callback function of internal thread.
void QueueThreadCallback() { queue_service_.run(); }
protected:
/// CoreWorker C++ pointer of current actor
CoreWorker *core_worker_;
/// Actor id of the current actor.
ActorID actor_id_;
/// Helper function, parse message buffer to Message object.
std::shared_ptr<Message> ParseMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
private:
/// Map from queue id to the actor id of the queue's peer actor.
std::unordered_map<ObjectID, ActorID> actors_;
/// Map from queue id to the transport of the queue's peer actor.
std::unordered_map<ObjectID, std::shared_ptr<Transport>> out_transports_;
/// The internal thread which asio service run with.
std::thread queue_thread_;
/// The internal asio service.
boost::asio::io_service queue_service_;
/// The asio work which keeps queue_service_ alive.
boost::asio::io_service::work queue_dummy_work_;
};
/// UpstreamQueueMessageHandler holds and manages all upstream queues of current actor.
class UpstreamQueueMessageHandler : public QueueMessageHandler {
public:
/// Construct an UpstreamQueueMessageHandler instance.
UpstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
: QueueMessageHandler(core_worker, actor_id) {}
/// Create an upstream queue.
/// \param[in] queue_id queue id of the queue to be created.
/// \param[in] peer_actor_id actor id of peer actor.
/// \param[in] size the max memory size of the queue.
std::shared_ptr<WriterQueue> CreateUpstreamQueue(const ObjectID &queue_id,
const ActorID &peer_actor_id,
uint64_t size);
/// Check whether the upstream queue specified by queue_id exists or not.
bool UpstreamQueueExists(const ObjectID &queue_id);
/// Wait for all queues in the queue_ids vector to become ready, until timeout.
/// \param[in] queue_ids a group of queues.
/// \param[in] timeout_ms max time to wait for all queues.
/// \param[out] failed_queues the queues that are still not ready when the timeout expires.
void WaitQueues(const std::vector<ObjectID> &queue_ids, int64_t timeout_ms,
std::vector<ObjectID> &failed_queues);
/// Handle the notify message from the corresponding downstream queue.
void OnNotify(std::shared_ptr<NotificationMessage> notify_msg);
/// Obtain upstream queue specified by queue_id.
std::shared_ptr<streaming::WriterQueue> GetUpQueue(const ObjectID &queue_id);
/// Release all upstream queues
void ReleaseAllUpQueues();
virtual void DispatchMessageInternal(
std::shared_ptr<LocalMemoryBuffer> buffer,
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) override;
static std::shared_ptr<UpstreamQueueMessageHandler> CreateService(
CoreWorker *core_worker, const ActorID &actor_id);
static std::shared_ptr<UpstreamQueueMessageHandler> GetService();
static RayFunction peer_sync_function_;
static RayFunction peer_async_function_;
private:
bool CheckQueueSync(const ObjectID &queue_id);
private:
std::unordered_map<ObjectID, std::shared_ptr<streaming::WriterQueue>> upstream_queues_;
static std::shared_ptr<UpstreamQueueMessageHandler> upstream_handler_;
};
/// DownstreamQueueMessageHandler holds and manages all downstream queues of the current actor.
class DownstreamQueueMessageHandler : public QueueMessageHandler {
public:
DownstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
: QueueMessageHandler(core_worker, actor_id) {}
std::shared_ptr<ReaderQueue> CreateDownstreamQueue(const ObjectID &queue_id,
const ActorID &peer_actor_id);
bool DownstreamQueueExists(const ObjectID &queue_id);
void UpdateDownActor(const ObjectID &queue_id, const ActorID &actor_id);
std::shared_ptr<LocalMemoryBuffer> OnCheckQueue(
std::shared_ptr<CheckMessage> check_msg);
std::shared_ptr<streaming::ReaderQueue> GetDownQueue(const ObjectID &queue_id);
void ReleaseAllDownQueues();
void OnData(std::shared_ptr<DataMessage> msg);
virtual void DispatchMessageInternal(
std::shared_ptr<LocalMemoryBuffer> buffer,
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback);
static std::shared_ptr<DownstreamQueueMessageHandler> CreateService(
CoreWorker *core_worker, const ActorID &actor_id);
static std::shared_ptr<DownstreamQueueMessageHandler> GetService();
static RayFunction peer_sync_function_;
static RayFunction peer_async_function_;
private:
std::unordered_map<ObjectID, std::shared_ptr<streaming::ReaderQueue>>
downstream_queues_;
static std::shared_ptr<DownstreamQueueMessageHandler> downstream_handler_;
};
} // namespace streaming
} // namespace ray
#endif
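A minimal writer-side setup sketch implied by the comments above: create the singleton upstream handler, record the peer actor of each queue, create the queue, then wait until the peer side is reachable. The core worker pointer, ids, and size are assumed to be supplied by the caller, and the peer sync/async RayFunctions are assumed to be registered already (normally via WriterClient).
void SetUpWriterSketch(CoreWorker *core_worker, const ActorID &actor_id,
                       const ObjectID &queue_id, const ActorID &peer_actor_id) {
  auto handler = UpstreamQueueMessageHandler::CreateService(core_worker, actor_id);
  // Record which actor owns the downstream end, so a Transport can be created for it.
  handler->SetPeerActorID(queue_id, peer_actor_id);
  handler->CreateUpstreamQueue(queue_id, peer_actor_id, /*size=*/100 * 1024 * 1024);
  // Block (up to 5s) until the downstream queue answers the check handshake.
  std::vector<ObjectID> failed_queues;
  handler->WaitQueues({queue_id}, /*timeout_ms=*/5000, failed_queues);
  if (!failed_queues.empty()) {
    STREAMING_LOG(WARNING) << "Peer queue is not ready: " << failed_queues[0];
  }
}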


@ -0,0 +1,109 @@
#ifndef _STREAMING_QUEUE_ITEM_H_
#define _STREAMING_QUEUE_ITEM_H_
#include <iterator>
#include <list>
#include <thread>
#include <vector>
#include "ray/common/id.h"
#include "message.h"
#include "message/message_bundle.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
using ray::ObjectID;
const uint64_t QUEUE_INVALID_SEQ_ID = std::numeric_limits<uint64_t>::max();
/// QueueItem is the element stored in `Queue`. Actually, when DataWriter pushes a message
/// bundle into a queue, the bundle is packed into one QueueItem, so a one-to-one
/// relationship exists between message bundle and QueueItem. Meanwhile, the QueueItem is
/// also the minimum unit to send through direct actor call. Each QueueItem holds a
/// LocalMemoryBuffer shared_ptr, which will be sent out by Transport.
class QueueItem {
public:
/// Construct a QueueItem object.
/// \param[in] seq_id the sequential id assigned by DataWriter for a message bundle and
/// QueueItem.
/// \param[in] data the data buffer to be stored in this QueueItem.
/// \param[in] data_size the data size in bytes.
/// \param[in] timestamp the time when this QueueItem created.
/// \param[in] raw whether the data content is raw bytes, only used in some tests.
QueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp,
bool raw = false)
: seq_id_(seq_id),
timestamp_(timestamp),
raw_(raw),
/*COPY*/ buffer_(std::make_shared<LocalMemoryBuffer>(data, data_size, true)) {}
QueueItem(uint64_t seq_id, std::shared_ptr<LocalMemoryBuffer> buffer,
uint64_t timestamp, bool raw = false)
: seq_id_(seq_id), timestamp_(timestamp), raw_(raw), buffer_(buffer) {}
QueueItem(std::shared_ptr<DataMessage> data_msg)
    : seq_id_(data_msg->SeqId()),
      timestamp_(0),  // DataMessage does not carry a timestamp; default to 0
      raw_(data_msg->IsRaw()),
      buffer_(data_msg->Buffer()) {}
QueueItem(const QueueItem &&item) {
buffer_ = item.buffer_;
seq_id_ = item.seq_id_;
timestamp_ = item.timestamp_;
raw_ = item.raw_;
}
QueueItem(const QueueItem &item) {
buffer_ = item.buffer_;
seq_id_ = item.seq_id_;
timestamp_ = item.timestamp_;
raw_ = item.raw_;
}
QueueItem &operator=(const QueueItem &item) {
buffer_ = item.buffer_;
seq_id_ = item.seq_id_;
timestamp_ = item.timestamp_;
raw_ = item.raw_;
return *this;
}
virtual ~QueueItem() = default;
uint64_t SeqId() { return seq_id_; }
bool IsRaw() { return raw_; }
uint64_t TimeStamp() { return timestamp_; }
size_t DataSize() { return buffer_->Size(); }
std::shared_ptr<LocalMemoryBuffer> Buffer() { return buffer_; }
/// Get max message id in this item.
/// \return max message id.
uint64_t MaxMsgId() {
if (raw_) {
return 0;
}
auto message_bundle = StreamingMessageBundleMeta::FromBytes(buffer_->Data());
return message_bundle->GetLastMessageId();
}
protected:
uint64_t seq_id_;
uint64_t timestamp_;
bool raw_;
std::shared_ptr<LocalMemoryBuffer> buffer_;
};
class InvalidQueueItem : public QueueItem {
public:
InvalidQueueItem() : QueueItem(QUEUE_INVALID_SEQ_ID, data_, 1, 0) {}
private:
uint8_t data_[1];
};
typedef std::shared_ptr<QueueItem> QueueItemPtr;
} // namespace streaming
} // namespace ray
#endif
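A minimal sketch of constructing a QueueItem from a flat buffer, the way DataWriter would after serializing a message bundle. The buffer contents are placeholder bytes; for raw items MaxMsgId() returns 0 as defined above, since there is no bundle meta to parse.
void QueueItemSketch() {
  uint8_t serialized_bundle[8] = {0};
  // The constructor copies the bytes into an internal LocalMemoryBuffer.
  QueueItem item(/*seq_id=*/1, serialized_bundle, sizeof(serialized_bundle),
                 /*timestamp=*/0, /*raw=*/true);
  STREAMING_CHECK(item.DataSize() == sizeof(serialized_bundle));
  STREAMING_CHECK(item.MaxMsgId() == 0);  // raw items carry no bundle meta
}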


@ -0,0 +1,94 @@
#include "transport.h"
#include "utils.h"
namespace ray {
namespace streaming {
static constexpr int TASK_OPTION_RETURN_NUM_0 = 0;
static constexpr int TASK_OPTION_RETURN_NUM_1 = 1;
void Transport::SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
RayFunction &function, int return_num,
std::vector<ObjectID> &return_ids) {
std::unordered_map<std::string, double> resources;
TaskOptions options{return_num, true, resources};
char meta_data[3] = {'R', 'A', 'W'};
std::shared_ptr<LocalMemoryBuffer> meta =
std::make_shared<LocalMemoryBuffer>((uint8_t *)meta_data, 3, true);
std::vector<TaskArg> args;
if (function.GetLanguage() == Language::PYTHON) {
auto dummy = "__RAY_DUMMY__";
std::shared_ptr<LocalMemoryBuffer> dummyBuffer =
std::make_shared<LocalMemoryBuffer>((uint8_t *)dummy, 13, true);
args.emplace_back(TaskArg::PassByValue(
std::make_shared<RayObject>(std::move(dummyBuffer), meta, true)));
}
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(std::move(buffer), meta, true)));
STREAMING_CHECK(core_worker_ != nullptr);
std::vector<std::shared_ptr<RayObject>> results;
ray::Status st =
core_worker_->SubmitActorTask(peer_actor_id_, function, args, options, &return_ids);
if (!st.ok()) {
STREAMING_LOG(ERROR) << "SubmitActorTask failed. " << st;
}
}
void Transport::Send(std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function) {
STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size();
std::vector<ObjectID> return_ids;
SendInternal(std::move(buffer), function, TASK_OPTION_RETURN_NUM_0, return_ids);
}
std::shared_ptr<LocalMemoryBuffer> Transport::SendForResult(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function,
int64_t timeout_ms) {
std::vector<ObjectID> return_ids;
SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids);
std::vector<std::shared_ptr<RayObject>> results;
Status get_st = core_worker_->Get(return_ids, timeout_ms, &results);
if (!get_st.ok()) {
STREAMING_LOG(ERROR) << "Get fail.";
return nullptr;
}
STREAMING_CHECK(results.size() >= 1);
if (results[0]->IsException()) {
STREAMING_LOG(ERROR) << "peer actor may has exceptions, should retry.";
return nullptr;
}
STREAMING_CHECK(results[0]->HasData());
if (results[0]->GetData()->Size() == 4) {
STREAMING_LOG(WARNING) << "peer actor may not ready yet, should retry.";
return nullptr;
}
std::shared_ptr<Buffer> result_buffer = results[0]->GetData();
std::shared_ptr<LocalMemoryBuffer> return_buffer = std::make_shared<LocalMemoryBuffer>(
result_buffer->Data(), result_buffer->Size(), true);
return return_buffer;
}
std::shared_ptr<LocalMemoryBuffer> Transport::SendForResultWithRetry(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function, int retry_cnt,
int64_t timeout_ms) {
STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt
<< " timeout_ms: " << timeout_ms
<< " function: " << function.GetFunctionDescriptor()[0];
std::shared_ptr<LocalMemoryBuffer> buffer_shared = std::move(buffer);
for (int cnt = 0; cnt < retry_cnt; cnt++) {
auto result = SendForResult(buffer_shared, function, timeout_ms);
if (result != nullptr) {
return result;
}
}
STREAMING_LOG(WARNING) << "SendForResultWithRetry fail after retry.";
return nullptr;
}
} // namespace streaming
} // namespace ray


@ -0,0 +1,63 @@
#ifndef _STREAMING_QUEUE_TRANSPORT_H_
#define _STREAMING_QUEUE_TRANSPORT_H_
#include "ray/common/id.h"
#include "ray/core_worker/core_worker.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
/// Transport is the transfer endpoint to a specific actor; buffers can be sent to the
/// peer through direct actor calls.
class Transport {
public:
/// Construct a Transport object.
/// \param[in] core_worker CoreWorker C++ pointer of the current actor, through which we
/// call the direct actor call interface.
/// \param[in] peer_actor_id actor id of peer actor.
Transport(CoreWorker *core_worker, const ActorID &peer_actor_id)
: core_worker_(core_worker), peer_actor_id_(peer_actor_id) {}
virtual ~Transport() = default;
/// Send buffer asynchronously, peer's `function` will be called.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
virtual void Send(std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function);
/// Send buffer synchronously, peer's `function` will be called, and return the peer
/// function's return value.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
/// \param[in] timeout_ms max time to wait for result.
/// \return peer function's result.
virtual std::shared_ptr<LocalMemoryBuffer> SendForResult(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function,
int64_t timeout_ms);
/// Send buffer and get the result, with retry.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
/// \param[in] retry_cnt max retry count.
/// \param[in] timeout_ms max time to wait for result.
/// \return peer function's result.
std::shared_ptr<LocalMemoryBuffer> SendForResultWithRetry(
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function, int retry_cnt,
int64_t timeout_ms);
private:
/// Send buffer internal
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
/// \param[in] return_num return value number of the call.
/// \param[out] return_ids return ids from SubmitActorTask.
virtual void SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
RayFunction &function, int return_num,
std::vector<ObjectID> &return_ids);
private:
CoreWorker *core_worker_;
ActorID peer_actor_id_;
};
} // namespace streaming
} // namespace ray
#endif
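A minimal sketch of how the queue layer drives Transport, mirroring the check handshake in queue_handler.cc: serialize a control message elsewhere, then invoke the peer's synchronous ray function with retries. The buffer and RayFunction are assumed to be prepared by the caller; a nullptr result means the peer never answered.
std::shared_ptr<LocalMemoryBuffer> HandshakeSketch(
    Transport &transport, std::shared_ptr<LocalMemoryBuffer> buffer,
    RayFunction &peer_sync_function) {
  // Up to 10 attempts, waiting 5 seconds for each reply.
  return transport.SendForResultWithRetry(buffer, peer_sync_function,
                                          /*retry_cnt=*/10, /*timeout_ms=*/5000);
}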


@ -0,0 +1,50 @@
#ifndef _STREAMING_QUEUE_UTILS_H_
#define _STREAMING_QUEUE_UTILS_H_
#include <chrono>
#include <future>
#include <thread>
#include "ray/util/util.h"
namespace ray {
namespace streaming {
/// Helper class encapsulating std::promise/std::future to support multithreaded asynchronous waiting.
class PromiseWrapper {
public:
Status Wait() {
std::future<bool> fut = promise_.get_future();
fut.get();
return status_;
}
Status WaitFor(uint64_t timeout_ms) {
std::future<bool> fut = promise_.get_future();
std::future_status status;
do {
status = fut.wait_for(std::chrono::milliseconds(timeout_ms));
if (status == std::future_status::deferred) {
} else if (status == std::future_status::timeout) {
return Status::Invalid("timeout");
} else if (status == std::future_status::ready) {
return status_;
}
} while (status == std::future_status::deferred);
return status_;
}
void Notify(Status status) {
status_ = status;
promise_.set_value(true);
}
Status GetResultStatus() { return status_; }
private:
std::promise<bool> promise_;
Status status_;
};
} // namespace streaming
} // namespace ray
#endif
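A minimal sketch of the wait/notify pattern PromiseWrapper is written for, as used by DispatchMessageSync: one thread blocks on the promise while another thread notifies when its work is done. The std::thread here only keeps the sketch self-contained; in the queue code the notifying side is the asio handler.
void PromiseWrapperSketch() {
  auto promise = std::make_shared<PromiseWrapper>();
  std::thread worker([promise] {
    // ... handle the message on the queue thread ...
    promise->Notify(ray::Status::OK());
  });
  Status st = promise->Wait();  // blocks until Notify() is called
  worker.join();
  (void)st;  // st holds the status passed to Notify()
}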


@ -0,0 +1,82 @@
#include "ring_buffer.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
StreamingRingBuffer::StreamingRingBuffer(size_t buf_size,
StreamingRingBufferType buffer_type) {
switch (buffer_type) {
case StreamingRingBufferType::SPSC:
message_buffer_ =
std::make_shared<RingBufferImplLockFree<StreamingMessagePtr>>(buf_size);
break;
case StreamingRingBufferType::SPSC_LOCK:
default:
message_buffer_ =
std::make_shared<RingBufferImplThreadSafe<StreamingMessagePtr>>(buf_size);
}
}
bool StreamingRingBuffer::Push(const StreamingMessagePtr &msg) {
message_buffer_->Push(msg);
return true;
}
bool StreamingRingBuffer::Push(StreamingMessagePtr &&msg) {
message_buffer_->Push(std::forward<StreamingMessagePtr>(msg));
return true;
}
StreamingMessagePtr &StreamingRingBuffer::Front() {
STREAMING_CHECK(!message_buffer_->Empty());
return message_buffer_->Front();
}
void StreamingRingBuffer::Pop() {
STREAMING_CHECK(!message_buffer_->Empty());
message_buffer_->Pop();
}
bool StreamingRingBuffer::IsFull() { return message_buffer_->Full(); }
bool StreamingRingBuffer::IsEmpty() { return message_buffer_->Empty(); }
size_t StreamingRingBuffer::Size() { return message_buffer_->Size(); };
size_t StreamingRingBuffer::Capacity() const { return message_buffer_->Capacity(); }
size_t StreamingRingBuffer::GetTransientBufferSize() {
return transient_buffer_.GetTransientBufferSize();
};
void StreamingRingBuffer::SetTransientBufferSize(uint32_t new_transient_buffer_size) {
return transient_buffer_.SetTransientBufferSize(new_transient_buffer_size);
}
size_t StreamingRingBuffer::GetMaxTransientBufferSize() const {
return transient_buffer_.GetMaxTransientBufferSize();
}
const uint8_t *StreamingRingBuffer::GetTransientBuffer() const {
return transient_buffer_.GetTransientBuffer();
}
uint8_t *StreamingRingBuffer::GetTransientBufferMutable() const {
return transient_buffer_.GetTransientBufferMutable();
}
void StreamingRingBuffer::ReallocTransientBuffer(uint32_t size) {
transient_buffer_.ReallocTransientBuffer(size);
}
bool StreamingRingBuffer::IsTransientAvaliable() {
return transient_buffer_.IsTransientAvaliable();
}
void StreamingRingBuffer::FreeTransientBuffer(bool is_force) {
transient_buffer_.FreeTransientBuffer(is_force);
}
} // namespace streaming
} // namespace ray

233
streaming/src/ring_buffer.h Normal file

@ -0,0 +1,233 @@
#ifndef RAY_RING_BUFFER_H
#define RAY_RING_BUFFER_H
#include <atomic>
#include <boost/circular_buffer.hpp>
#include <boost/thread/locks.hpp>
#include <boost/thread/shared_mutex.hpp>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include "message/message.h"
#include "ray/common/status.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
/// Because data cannot always be written to the channel successfully, and we do not want
/// to serialize a message repeatedly, we designed a temporary buffer area: when the
/// downstream is backpressured or the channel is blocked due to memory limitations, the
/// serialized data can be cached here and reused on the next attempt.
class StreamingTransientBuffer {
private:
std::shared_ptr<uint8_t> transient_buffer_;
// transient_buffer_size_ is the length of the most recently serialized data.
uint32_t transient_buffer_size_ = 0;
uint32_t max_transient_buffer_size_ = 0;
bool transient_flag_ = false;
public:
inline size_t GetTransientBufferSize() const { return transient_buffer_size_; }
inline void SetTransientBufferSize(uint32_t new_transient_buffer_size) {
transient_buffer_size_ = new_transient_buffer_size;
}
inline size_t GetMaxTransientBufferSize() const { return max_transient_buffer_size_; }
inline const uint8_t *GetTransientBuffer() const { return transient_buffer_.get(); }
inline uint8_t *GetTransientBufferMutable() const { return transient_buffer_.get(); }
/// To reuse the transient buffer, the buffer memory is reallocated only if the size of
/// the needed message bundle raw data is greater than the current buffer size.
/// \param size buffer size
inline void ReallocTransientBuffer(uint32_t size) {
transient_buffer_size_ = size;
transient_flag_ = true;
if (max_transient_buffer_size_ > size) {
return;
}
max_transient_buffer_size_ = size;
transient_buffer_.reset(new uint8_t[size], std::default_delete<uint8_t[]>());
}
inline bool IsTransientAvaliable() { return transient_flag_; }
inline void FreeTransientBuffer(bool is_force = false) {
transient_buffer_size_ = 0;
transient_flag_ = false;
// The transient buffer always holds the largest buffer seen so far, which can be
// wasteful. An expiration time would be a reasonable way to release a large buffer
// that this pointer has been holding for a long time.
if (is_force) {
max_transient_buffer_size_ = 0;
transient_buffer_.reset();
}
}
virtual ~StreamingTransientBuffer() = default;
};
template <class T>
class AbstractRingBufferImpl {
public:
virtual void Push(T &&) = 0;
virtual void Push(const T &) = 0;
virtual void Pop() = 0;
virtual T &Front() = 0;
virtual bool Empty() = 0;
virtual bool Full() = 0;
virtual size_t Size() = 0;
virtual size_t Capacity() = 0;
};
template <class T>
class RingBufferImplThreadSafe : public AbstractRingBufferImpl<T> {
private:
boost::shared_mutex ring_buffer_mutex_;
boost::circular_buffer<T> buffer_;
public:
RingBufferImplThreadSafe(size_t size) : buffer_(size) {}
virtual ~RingBufferImplThreadSafe() = default;
void Push(T &&t) {
boost::unique_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
buffer_.push_back(t);
}
void Push(const T &t) {
boost::unique_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
buffer_.push_back(t);
}
void Pop() {
boost::unique_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
buffer_.pop_front();
}
T &Front() {
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
return buffer_.front();
}
bool Empty() {
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
return buffer_.empty();
}
bool Full() {
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
return buffer_.full();
}
size_t Size() {
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
return buffer_.size();
}
size_t Capacity() { return buffer_.capacity(); }
};
template <class T>
class RingBufferImplLockFree : public AbstractRingBufferImpl<T> {
private:
std::vector<T> buffer_;
std::atomic<size_t> capacity_;
std::atomic<size_t> read_index_;
std::atomic<size_t> write_index_;
public:
RingBufferImplLockFree(size_t size)
: buffer_(size, nullptr), capacity_(size), read_index_(0), write_index_(0) {}
virtual ~RingBufferImplLockFree() = default;
void Push(T &&t) {
STREAMING_CHECK(!Full());
buffer_[write_index_] = t;
write_index_ = IncreaseIndex(write_index_);
}
void Push(const T &t) {
STREAMING_CHECK(!Full());
buffer_[write_index_] = t;
write_index_ = IncreaseIndex(write_index_);
}
void Pop() {
STREAMING_CHECK(!Empty());
read_index_ = IncreaseIndex(read_index_);
}
T &Front() {
STREAMING_CHECK(!Empty());
return buffer_[read_index_];
}
bool Empty() { return write_index_ == read_index_; }
bool Full() { return IncreaseIndex(write_index_) == read_index_; }
size_t Size() { return (write_index_ + capacity_ - read_index_) % capacity_; }
size_t Capacity() { return capacity_; }
private:
size_t IncreaseIndex(size_t index) const { return (index + 1) % capacity_; }
};
enum class StreamingRingBufferType : uint8_t { SPSC_LOCK, SPSC };
/// StreamingRingBuffer is a factory that generates two different buffer implementations.
/// In the data writer we use a lock-free single-producer single-consumer (SPSC) ring
/// buffer to hold messages from the user thread, because it performs much better than
/// the lock-based variant. The SPSC_LOCK variant is still useful for our event-driven
/// model (we will use it to optimize the threading model in the future), so it cannot
/// be removed currently.
class StreamingRingBuffer {
private:
std::shared_ptr<AbstractRingBufferImpl<StreamingMessagePtr>> message_buffer_;
StreamingTransientBuffer transient_buffer_;
public:
explicit StreamingRingBuffer(size_t buf_size, StreamingRingBufferType buffer_type =
StreamingRingBufferType::SPSC_LOCK);
bool Push(StreamingMessagePtr &&msg);
bool Push(const StreamingMessagePtr &msg);
StreamingMessagePtr &Front();
void Pop();
bool IsFull();
bool IsEmpty();
size_t Size();
size_t Capacity() const;
size_t GetTransientBufferSize();
void SetTransientBufferSize(uint32_t new_transient_buffer_size);
size_t GetMaxTransientBufferSize() const;
const uint8_t *GetTransientBuffer() const;
uint8_t *GetTransientBufferMutable() const;
void ReallocTransientBuffer(uint32_t size);
bool IsTransientAvaliable();
void FreeTransientBuffer(bool is_force = false);
};
typedef std::shared_ptr<StreamingRingBuffer> StreamingRingBufferPtr;
} // namespace streaming
} // namespace ray
#endif // RAY_RING_BUFFER_H
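A minimal sketch of the single-producer/single-consumer pattern described above: the user thread pushes StreamingMessage pointers while the writer loop drains them. Message construction is assumed to happen elsewhere; only calls declared in this header are used.
void RingBufferSketch(const StreamingMessagePtr &msg) {
  // The lock-free SPSC variant, as used by the data writer.
  StreamingRingBuffer ring_buffer(/*buf_size=*/100, StreamingRingBufferType::SPSC);
  // Producer side (user thread): back off or drop when the buffer is full.
  if (!ring_buffer.IsFull()) {
    ring_buffer.Push(msg);
  }
  // Consumer side (writer loop): drain whatever is available.
  while (!ring_buffer.IsEmpty()) {
    StreamingMessagePtr &front = ring_buffer.Front();
    (void)front;  // ... serialize `front` into a message bundle here ...
    ring_buffer.Pop();
  }
}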


@ -0,0 +1,32 @@
#include "ray/common/id.h"
#include "ray/protobuf/common.pb.h"
#include "ray/util/util.h"
#include "runtime_context.h"
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
void RuntimeContext::SetConfig(const StreamingConfig &streaming_config) {
STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init)
<< "set config must be at beginning";
config_ = streaming_config;
}
void RuntimeContext::SetConfig(const uint8_t *data, uint32_t size) {
STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init)
<< "set config must be at beginning";
if (!data) {
STREAMING_LOG(WARNING) << "buffer pointer is null, but len is => " << size;
return;
}
config_.FromProto(data, size);
}
RuntimeContext::~RuntimeContext() {}
RuntimeContext::RuntimeContext() : runtime_status_(RuntimeStatus::Init) {}
} // namespace streaming
} // namespace ray


@ -0,0 +1,42 @@
#ifndef RAY_STREAMING_H
#define RAY_STREAMING_H
#include <string>
#include "config/streaming_config.h"
#include "status.h"
namespace ray {
namespace streaming {
enum class RuntimeStatus : uint8_t { Init = 0, Running = 1, Interrupted = 2 };
#define RETURN_IF_NOT_OK(STATUS_EXP) \
{ \
StreamingStatus state = STATUS_EXP; \
if (StreamingStatus::OK != state) { \
return state; \
} \
}
class RuntimeContext {
public:
RuntimeContext();
virtual ~RuntimeContext();
inline const StreamingConfig &GetConfig() const { return config_; };
void SetConfig(const StreamingConfig &config);
void SetConfig(const uint8_t *data, uint32_t buffer_len);
inline RuntimeStatus GetRuntimeStatus() { return runtime_status_; }
inline void SetRuntimeStatus(RuntimeStatus status) { runtime_status_ = status; }
inline void MarkMockTest() { is_mock_test_ = true; }
inline bool IsMockTest() { return is_mock_test_; }
private:
StreamingConfig config_;
RuntimeStatus runtime_status_;
bool is_mock_test_ = false;
};
} // namespace streaming
} // namespace ray
#endif // RAY_STREAMING_H
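A minimal sketch of the call order implied by the checks in runtime_context.cc: configuration is only allowed while the status is still Init. The StreamingConfig value is assumed to be built elsewhere.
void RuntimeContextSketch(const StreamingConfig &streaming_config) {
  RuntimeContext ctx;                            // starts in RuntimeStatus::Init
  ctx.SetConfig(streaming_config);               // only legal before running
  ctx.SetRuntimeStatus(RuntimeStatus::Running);
  // Calling SetConfig() from here on would trip the STREAMING_CHECK in runtime_context.cc.
}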

47
streaming/src/status.h Normal file

@ -0,0 +1,47 @@
#ifndef RAY_STREAMING_STATUS_H
#define RAY_STREAMING_STATUS_H
#include <ostream>
#include <sstream>
#include <string>
namespace ray {
namespace streaming {
enum class StreamingStatus : uint32_t {
OK = 0,
ReconstructTimeOut = 1,
QueueIdNotFound = 3,
ResubscribeFailed = 4,
EmptyRingBuffer = 5,
FullChannel = 6,
NoSuchItem = 7,
InitQueueFailed = 8,
GetBundleTimeOut = 9,
SkipSendEmptyMessage = 10,
Interrupted = 11,
WaitQueueTimeOut = 12,
OutOfMemory = 13,
Invalid = 14,
UnknownError = 15,
TailStatus = 999,
MIN = OK,
MAX = TailStatus
};
static inline std::ostream &operator<<(std::ostream &os, const StreamingStatus &status) {
os << static_cast<std::underlying_type<StreamingStatus>::type>(status);
return os;
}
#define RETURN_IF_NOT_OK(STATUS_EXP) \
{ \
StreamingStatus state = STATUS_EXP; \
if (StreamingStatus::OK != state) { \
return state; \
} \
}
} // namespace streaming
} // namespace ray
#endif // RAY_STREAMING_STATUS_H
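A hedged sketch of how RETURN_IF_NOT_OK is meant to be used to propagate the first non-OK status up the call chain; ReadOnce and ReadTwice are hypothetical helpers used only for illustration and are not part of this commit:

#include "status.h"

using namespace ray::streaming;

// Hypothetical helper, for illustration only.
StreamingStatus ReadOnce(bool has_data) {
  return has_data ? StreamingStatus::OK : StreamingStatus::EmptyRingBuffer;
}

StreamingStatus ReadTwice() {
  RETURN_IF_NOT_OK(ReadOnce(true))   // OK, execution continues.
  RETURN_IF_NOT_OK(ReadOnce(false))  // EmptyRingBuffer is returned to the caller.
  return StreamingStatus::OK;
}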

View file

@ -0,0 +1,176 @@
#include <cstring>
#include <string>
#include "gtest/gtest.h"
#include "message/message.h"
#include "message/message_bundle.h"
using namespace ray;
using namespace ray::streaming;
TEST(StreamingSerializationTest, streaming_message_serialization_test) {
uint8_t data[] = {9, 1, 3};
StreamingMessagePtr message =
std::make_shared<StreamingMessage>(data, 3, 7, StreamingMessageType::Message);
uint32_t message_length = message->ClassBytesSize();
uint8_t *bytes = new uint8_t[message_length];
message->ToBytes(bytes);
StreamingMessagePtr new_message = StreamingMessage::FromBytes(bytes);
EXPECT_EQ(std::memcmp(new_message->RawData(), data, 3), 0);
delete[] bytes;
}
TEST(StreamingSerializationTest, streaming_message_empty_bundle_serialization_test) {
for (int i = 0; i < 10; ++i) {
StreamingMessageBundle bundle(i, i);
uint64_t bundle_size = bundle.ClassBytesSize();
uint8_t *bundle_bytes = new uint8_t[bundle_size];
bundle.ToBytes(bundle_bytes);
StreamingMessageBundlePtr bundle_ptr =
StreamingMessageBundle::FromBytes(bundle_bytes);
EXPECT_EQ(bundle.ClassBytesSize(), bundle_ptr->ClassBytesSize());
EXPECT_EQ(bundle.GetMessageListSize(), bundle_ptr->GetMessageListSize());
EXPECT_EQ(bundle.GetBundleType(), bundle_ptr->GetBundleType());
EXPECT_EQ(bundle.GetLastMessageId(), bundle_ptr->GetLastMessageId());
std::list<StreamingMessagePtr> s_message_list;
bundle_ptr->GetMessageList(s_message_list);
std::list<StreamingMessagePtr> b_message_list;
bundle.GetMessageList(b_message_list);
EXPECT_EQ(b_message_list.size(), 0);
EXPECT_EQ(s_message_list.size(), 0);
delete[] bundle_bytes;
}
}
TEST(StreamingSerializationTest, streaming_message_barrier_bundle_serialization_test) {
for (int i = 0; i < 10; ++i) {
uint8_t data[] = {1, 2, 3, 4};
uint32_t data_size = 4;
uint32_t head_size = sizeof(uint64_t);
uint64_t checkpoint_id = 777;
std::shared_ptr<uint8_t> ptr(new uint8_t[data_size + head_size],
std::default_delete<uint8_t[]>());
// Move checkpoint_id to the head of the barrier data.
std::memcpy(ptr.get(), &checkpoint_id, head_size);
std::memcpy(ptr.get() + head_size, data, data_size);
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
data, head_size + data_size, i, StreamingMessageType::Barrier);
std::list<StreamingMessagePtr> message_list;
message_list.push_back(message);
// The message list will be moved into the bundle, so pass a copy and keep the original for comparison.
std::list<StreamingMessagePtr> message_list_cpy(message_list);
StreamingMessageBundle bundle(message_list_cpy, i, i,
StreamingMessageBundleType::Barrier);
uint64_t bundle_size = bundle.ClassBytesSize();
uint8_t *bundle_bytes = new uint8_t[bundle_size];
bundle.ToBytes(bundle_bytes);
StreamingMessageBundlePtr bundle_ptr =
StreamingMessageBundle::FromBytes(bundle_bytes);
EXPECT_TRUE(bundle.ClassBytesSize() == bundle_ptr->ClassBytesSize());
EXPECT_TRUE(bundle.GetMessageListSize() == bundle_ptr->GetMessageListSize());
EXPECT_TRUE(bundle.GetBundleType() == bundle_ptr->GetBundleType());
EXPECT_TRUE(bundle.GetLastMessageId() == bundle_ptr->GetLastMessageId());
std::list<StreamingMessagePtr> s_message_list;
bundle_ptr->GetMessageList(s_message_list);
EXPECT_TRUE(s_message_list.size() == message_list.size());
auto m_item = message_list.back();
auto s_item = s_message_list.back();
EXPECT_TRUE(s_item->ClassBytesSize() == m_item->ClassBytesSize());
EXPECT_TRUE(s_item->GetMessageType() == m_item->GetMessageType());
EXPECT_TRUE(s_item->GetMessageSeqId() == m_item->GetMessageSeqId());
EXPECT_TRUE(s_item->GetDataSize() == m_item->GetDataSize());
EXPECT_TRUE(
std::memcmp(s_item->RawData(), m_item->RawData(), m_item->GetDataSize()) == 0);
EXPECT_TRUE(*(s_item.get()) == (*(m_item.get())));
delete[] bundle_bytes;
}
}
TEST(StreamingSerializationTest, streaming_message_bundle_serialization_test) {
for (int k = 0; k <= 1000; k++) {
std::list<StreamingMessagePtr> message_list;
for (int i = 0; i < 100; ++i) {
uint8_t *data = new uint8_t[i + 1];
data[0] = i;
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
data, i + 1, i + 1, StreamingMessageType::Message);
message_list.push_back(message);
delete[] data;
}
StreamingMessageBundle messageBundle(message_list, 0, 1,
StreamingMessageBundleType::Bundle);
size_t message_length = messageBundle.ClassBytesSize();
uint8_t *bytes = new uint8_t[message_length];
messageBundle.ToBytes(bytes);
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes);
EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length);
std::list<StreamingMessagePtr> s_message_list;
bundle_ptr->GetMessageList(s_message_list);
EXPECT_TRUE(bundle_ptr->operator==(messageBundle));
StreamingMessageBundleMetaPtr bundle_meta_ptr =
StreamingMessageBundleMeta::FromBytes(bytes);
EXPECT_EQ(bundle_meta_ptr->GetBundleType(), bundle_ptr->GetBundleType());
EXPECT_EQ(bundle_meta_ptr->GetLastMessageId(), bundle_ptr->GetLastMessageId());
EXPECT_EQ(bundle_meta_ptr->GetMessageBundleTs(), bundle_ptr->GetMessageBundleTs());
EXPECT_EQ(bundle_meta_ptr->GetMessageListSize(), bundle_ptr->GetMessageListSize());
delete[] bytes;
}
}
TEST(StreamingSerializationTest, streaming_message_bundle_equal_test) {
std::list<StreamingMessagePtr> message_list;
std::list<StreamingMessagePtr> message_list_same;
std::list<StreamingMessagePtr> message_list_cpy;
for (int i = 0; i < 100; ++i) {
uint8_t *data = new uint8_t[i + 1];
for (int j = 0; j < i + 1; ++j) {
data[j] = i;
}
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
data, i + 1, i + 1, StreamingMessageType::Message);
message_list.push_back(message);
message_list_cpy.push_front(message);
delete[] data;
}
for (int i = 0; i < 100; ++i) {
uint8_t *data = new uint8_t[i + 1];
for (int j = 0; j < i + 1; ++j) {
data[j] = i;
}
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
data, i + 1, i + 1, StreamingMessageType::Message);
message_list_same.push_back(message);
delete[] data;
}
StreamingMessageBundle message_bundle(message_list, 0, 1,
StreamingMessageBundleType::Bundle);
StreamingMessageBundle message_bundle_same(message_list_same, 0, 1,
StreamingMessageBundleType::Bundle);
StreamingMessageBundle message_bundle_reverse(message_list_cpy, 0, 1,
StreamingMessageBundleType::Bundle);
EXPECT_TRUE(message_bundle_same == message_bundle);
EXPECT_FALSE(message_bundle_reverse == message_bundle);
size_t message_length = message_bundle.ClassBytesSize();
uint8_t *bytes = new uint8_t[message_length];
message_bundle.ToBytes(bytes);
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes);
EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length);
std::list<StreamingMessagePtr> s_message_list;
bundle_ptr->GetMessageList(s_message_list);
EXPECT_TRUE(bundle_ptr->operator==(message_bundle));
delete[] bytes;
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View file

@ -0,0 +1,439 @@
#define BOOST_BIND_NO_PLACEHOLDERS
#include "ray/core_worker/context.h"
#include "ray/core_worker/core_worker.h"
#include "src/ray/util/test_util.h"
#include "data_reader.h"
#include "data_writer.h"
#include "message/message.h"
#include "message/message_bundle.h"
#include "queue/queue_client.h"
#include "ring_buffer.h"
#include "status.h"
#include "gtest/gtest.h"
using namespace std::placeholders;
const uint32_t MESSAGE_BOUND_SIZE = 10000;
const uint32_t DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE = 1000;
namespace ray {
namespace streaming {
class StreamingQueueTestSuite {
public:
StreamingQueueTestSuite(std::shared_ptr<CoreWorker> core_worker, ActorID &peer_actor_id,
std::vector<ObjectID> queue_ids,
std::vector<ObjectID> rescale_queue_ids)
: core_worker_(core_worker),
peer_actor_id_(peer_actor_id),
queue_ids_(queue_ids),
rescale_queue_ids_(rescale_queue_ids) {}
virtual void ExecuteTest(std::string test_name) {
auto it = test_func_map_.find(test_name);
STREAMING_CHECK(it != test_func_map_.end());
current_test_ = test_name;
status_ = false;
auto func = it->second;
executor_thread_ = std::make_shared<std::thread>(func);
executor_thread_->detach();
}
virtual std::shared_ptr<LocalMemoryBuffer> CheckCurTestStatus() {
TestCheckStatusRspMsg msg(current_test_, status_);
return msg.ToBytes();
}
virtual bool TestDone() { return status_; }
virtual ~StreamingQueueTestSuite() {}
protected:
std::unordered_map<std::string, std::function<void()>> test_func_map_;
std::string current_test_;
bool status_;
std::shared_ptr<std::thread> executor_thread_;
std::shared_ptr<CoreWorker> core_worker_;
ActorID peer_actor_id_;
std::vector<ObjectID> queue_ids_;
std::vector<ObjectID> rescale_queue_ids_;
};
class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
public:
StreamingQueueWriterTestSuite(std::shared_ptr<CoreWorker> core_worker,
ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
std::vector<ObjectID> rescale_queue_ids)
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
rescale_queue_ids) {
test_func_map_ = {
{"streaming_writer_exactly_once_test",
std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest,
this)}};
}
private:
void TestWriteMessageToBufferRing(std::shared_ptr<DataWriter> writer_client,
std::vector<ray::ObjectID> &q_list) {
// const uint8_t temp_data[] = {1, 2, 4, 5};
uint32_t i = 1;
while (i <= MESSAGE_BOUND_SIZE) {
for (auto &q_id : q_list) {
uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE);
uint8_t *data = new uint8_t[buffer_len];
for (uint32_t j = 0; j < buffer_len; ++j) {
data[j] = j % 128;
}
writer_client->WriteMessageToBufferRing(q_id, data, buffer_len,
StreamingMessageType::Message);
}
++i;
}
// Wait a while
std::this_thread::sleep_for(std::chrono::milliseconds(5000));
}
void StreamingWriterStrategyTest(StreamingConfig &config) {
for (auto &queue_id : queue_ids_) {
STREAMING_LOG(INFO) << "queue_id: " << queue_id;
}
std::vector<ActorID> actor_ids(queue_ids_.size(), peer_actor_id_);
STREAMING_LOG(INFO) << "writer actor_ids size: " << actor_ids.size()
<< " actor_id: " << peer_actor_id_;
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
runtime_context->SetConfig(config);
std::shared_ptr<DataWriter> streaming_writer_client(new DataWriter(runtime_context));
uint64_t queue_size = 10 * 1000 * 1000;
std::vector<uint64_t> channel_seq_id_vec(queue_ids_.size(), 0);
streaming_writer_client->Init(queue_ids_, actor_ids, channel_seq_id_vec,
std::vector<uint64_t>(queue_ids_.size(), queue_size));
STREAMING_LOG(INFO) << "streaming_writer_client Init done";
streaming_writer_client->Run();
std::thread test_loop_thread(
&StreamingQueueWriterTestSuite::TestWriteMessageToBufferRing, this,
streaming_writer_client, std::ref(queue_ids_));
// test_loop_thread.detach();
if (test_loop_thread.joinable()) {
test_loop_thread.join();
}
}
void StreamingWriterExactlyOnceTest() {
StreamingConfig config;
StreamingWriterStrategyTest(config);
STREAMING_LOG(INFO)
<< "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest";
status_ = true;
}
};
class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
public:
StreamingQueueReaderTestSuite(std::shared_ptr<CoreWorker> core_worker,
ActorID peer_actor_id, std::vector<ObjectID> queue_ids,
std::vector<ObjectID> rescale_queue_ids)
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
rescale_queue_ids) {
test_func_map_ = {
{"streaming_writer_exactly_once_test",
std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest,
this)}};
}
private:
void ReaderLoopForward(std::shared_ptr<DataReader> reader_client,
std::shared_ptr<DataWriter> writer_client,
std::vector<ray::ObjectID> &queue_id_vec) {
uint64_t received_message_cnt = 0;
std::unordered_map<ray::ObjectID, uint64_t> queue_last_cp_id;
for (auto &q_id : queue_id_vec) {
queue_last_cp_id[q_id] = 0;
}
STREAMING_LOG(INFO) << "Start read message bundle";
while (true) {
std::shared_ptr<DataBundle> msg;
StreamingStatus st = reader_client->GetBundle(100, msg);
if (st != StreamingStatus::OK || !msg->data) {
STREAMING_LOG(DEBUG) << "read bundle timeout, status = " << (int)st;
continue;
}
STREAMING_CHECK(msg.get() && msg->meta.get())
<< "read null pointer message, queue id => " << msg->from.Hex();
if (msg->meta->GetBundleType() == StreamingMessageBundleType::Barrier) {
STREAMING_LOG(DEBUG) << "barrier message recevied => "
<< msg->meta->GetMessageBundleTs();
std::unordered_map<ray::ObjectID, ConsumerChannelInfo> *offset_map;
reader_client->GetOffsetInfo(offset_map);
for (auto &q_id : queue_id_vec) {
reader_client->NotifyConsumedItem((*offset_map)[q_id],
(*offset_map)[q_id].current_seq_id);
}
// writer_client->ClearCheckpoint(msg->last_barrier_id);
continue;
} else if (msg->meta->GetBundleType() == StreamingMessageBundleType::Empty) {
STREAMING_LOG(DEBUG) << "empty message recevied => "
<< msg->meta->GetMessageBundleTs();
continue;
}
StreamingMessageBundlePtr bundlePtr;
bundlePtr = StreamingMessageBundle::FromBytes(msg->data);
std::list<StreamingMessagePtr> message_list;
bundlePtr->GetMessageList(message_list);
STREAMING_LOG(INFO) << "message size => " << message_list.size()
<< " from queue id => " << msg->from.Hex()
<< " last message id => " << msg->meta->GetLastMessageId();
received_message_cnt += message_list.size();
for (auto &item : message_list) {
uint64_t i = item->GetMessageSeqId();
uint32_t buff_len = i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE;
if (i > MESSAGE_BOUND_SIZE) break;
EXPECT_EQ(buff_len, item->GetDataSize());
uint8_t *compared_data = new uint8_t[buff_len];
for (uint32_t j = 0; j < item->GetDataSize(); ++j) {
compared_data[j] = j % 128;
}
EXPECT_EQ(std::memcmp(compared_data, item->RawData(), item->GetDataSize()), 0);
delete[] compared_data;
}
STREAMING_LOG(DEBUG) << "Received message count => " << recevied_message_cnt;
if (recevied_message_cnt == queue_id_vec.size() * MESSAGE_BOUND_SIZE) {
STREAMING_LOG(INFO) << "recevied message count => " << recevied_message_cnt
<< ", break";
break;
}
}
}
void StreamingReaderStrategyTest(StreamingConfig &config) {
std::vector<ActorID> actor_ids(queue_ids_.size(), peer_actor_id_);
STREAMING_LOG(INFO) << "reader actor_ids size: " << actor_ids.size()
<< " actor_id: " << peer_actor_id_;
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
runtime_context->SetConfig(config);
std::shared_ptr<DataReader> reader(new DataReader(runtime_context));
reader->Init(queue_ids_, actor_ids, -1);
ReaderLoopForward(reader, nullptr, queue_ids_);
STREAMING_LOG(INFO) << "Reader exit";
}
void StreamingWriterExactlyOnceTest() {
STREAMING_LOG(INFO)
<< "StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest";
StreamingConfig config;
StreamingReaderStrategyTest(config);
status_ = true;
}
};
class TestSuiteFactory {
public:
static std::shared_ptr<StreamingQueueTestSuite> CreateTestSuite(
std::shared_ptr<CoreWorker> worker, std::shared_ptr<TestInitMessage> message) {
std::shared_ptr<StreamingQueueTestSuite> test_suite = nullptr;
std::string suite_name = message->TestSuiteName();
queue::protobuf::StreamingQueueTestRole role = message->Role();
const std::vector<ObjectID> &queue_ids = message->QueueIds();
const std::vector<ObjectID> &rescale_queue_ids = message->RescaleQueueIds();
ActorID peer_actor_id = message->PeerActorId();
if (role == queue::protobuf::StreamingQueueTestRole::WRITER) {
if (suite_name == "StreamingWriterTest") {
test_suite = std::make_shared<StreamingQueueWriterTestSuite>(
worker, peer_actor_id, queue_ids, rescale_queue_ids);
} else {
STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name;
}
} else {
if (suite_name == "StreamingWriterTest") {
test_suite = std::make_shared<StreamingQueueReaderTestSuite>(
worker, peer_actor_id, queue_ids, rescale_queue_ids);
} else {
STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name;
}
}
return test_suite;
}
};
class StreamingWorker {
public:
StreamingWorker(const std::string &store_socket, const std::string &raylet_socket,
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
: test_suite_(nullptr), peer_actor_handle_(nullptr) {
worker_ = std::make_shared<CoreWorker>(
WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port,
std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7));
RayFunction reader_async_call_func{ray::Language::PYTHON, {"reader_async_call_func"}};
RayFunction reader_sync_call_func{ray::Language::PYTHON, {"reader_sync_call_func"}};
RayFunction writer_async_call_func{ray::Language::PYTHON, {"writer_async_call_func"}};
RayFunction writer_sync_call_func{ray::Language::PYTHON, {"writer_sync_call_func"}};
reader_client_ = std::make_shared<ReaderClient>(worker_.get(), reader_async_call_func,
reader_sync_call_func);
writer_client_ = std::make_shared<WriterClient>(worker_.get(), writer_async_call_func,
writer_sync_call_func);
STREAMING_LOG(INFO) << "StreamingWorker constructor";
}
void StartExecutingTasks() {
// Start executing tasks.
worker_->StartExecutingTasks();
}
private:
Status ExecuteTask(TaskType task_type, const RayFunction &ray_function,
const std::unordered_map<std::string, double> &required_resources,
const std::vector<std::shared_ptr<RayObject>> &args,
const std::vector<ObjectID> &arg_reference_ids,
const std::vector<ObjectID> &return_ids,
std::vector<std::shared_ptr<RayObject>> *results) {
// Only one arg param used in streaming.
STREAMING_CHECK(args.size() >= 1) << "args.size() = " << args.size();
std::vector<std::string> function_descriptor = ray_function.GetFunctionDescriptor();
STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " << function_descriptor[0];
std::string func_name = function_descriptor[0];
if (func_name == "init") {
std::shared_ptr<LocalMemoryBuffer> local_buffer =
std::make_shared<LocalMemoryBuffer>(args[0]->GetData()->Data(),
args[0]->GetData()->Size(), true);
HandleInitTask(local_buffer);
} else if (func_name == "execute_test") {
STREAMING_LOG(INFO) << "Test name: " << function_descriptor[1];
test_suite_->ExecuteTest(function_descriptor[1]);
} else if (func_name == "check_current_test_status") {
results->push_back(
std::make_shared<RayObject>(test_suite_->CheckCurTestStatus(), nullptr));
} else if (func_name == "reader_sync_call_func") {
if (test_suite_->TestDone()) {
STREAMING_LOG(WARNING) << "Test has done!!";
return Status::OK();
}
std::shared_ptr<LocalMemoryBuffer> local_buffer =
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
args[1]->GetData()->Size(), true);
auto result_buffer = reader_client_->OnReaderMessageSync(local_buffer);
results->push_back(std::make_shared<RayObject>(result_buffer, nullptr));
} else if (func_name == "reader_async_call_func") {
if (test_suite_->TestDone()) {
STREAMING_LOG(WARNING) << "Test has done!!";
return Status::OK();
}
std::shared_ptr<LocalMemoryBuffer> local_buffer =
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
args[1]->GetData()->Size(), true);
reader_client_->OnReaderMessage(local_buffer);
} else if (func_name == "writer_sync_call_func") {
if (test_suite_->TestDone()) {
STREAMING_LOG(WARNING) << "Test has done!!";
return Status::OK();
}
std::shared_ptr<LocalMemoryBuffer> local_buffer =
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
args[1]->GetData()->Size(), true);
auto result_buffer = writer_client_->OnWriterMessageSync(local_buffer);
results->push_back(std::make_shared<RayObject>(result_buffer, nullptr));
} else if (func_name == "writer_async_call_func") {
if (test_suite_->TestDone()) {
STREAMING_LOG(WARNING) << "Test has done!!";
return Status::OK();
}
std::shared_ptr<LocalMemoryBuffer> local_buffer =
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
args[1]->GetData()->Size(), true);
writer_client_->OnWriterMessage(local_buffer);
} else {
STREAMING_LOG(WARNING) << "Invalid function name " << func_name;
}
return Status::OK();
}
private:
void HandleInitTask(std::shared_ptr<LocalMemoryBuffer> buffer) {
uint8_t *bytes = buffer->Data();
uint8_t *p_cur = bytes;
uint32_t *magic_num = (uint32_t *)p_cur;
STREAMING_CHECK(*magic_num == Message::MagicNum);
p_cur += sizeof(Message::MagicNum);
queue::protobuf::StreamingQueueMessageType *type =
(queue::protobuf::StreamingQueueMessageType *)p_cur;
STREAMING_CHECK(
*type ==
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType);
std::shared_ptr<TestInitMessage> message = TestInitMessage::FromBytes(bytes);
STREAMING_LOG(INFO) << "Init message: " << message->ToString();
std::string actor_handle_serialized = message->ActorHandleSerialized();
worker_->DeserializeAndRegisterActorHandle(actor_handle_serialized);
std::shared_ptr<ActorHandle> actor_handle(new ActorHandle(actor_handle_serialized));
STREAMING_CHECK(actor_handle != nullptr);
STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID();
// STREAMING_LOG(INFO) << "actor_handle_serialized: " << actor_handle_serialized;
// peer_actor_handle_ =
// std::make_shared<ActorHandle>(actor_handle_serialized);
STREAMING_LOG(INFO) << "HandleInitTask queues:";
for (auto qid : message->QueueIds()) {
STREAMING_LOG(INFO) << "queue: " << qid;
}
for (auto qid : message->RescaleQueueIds()) {
STREAMING_LOG(INFO) << "rescale queue: " << qid;
}
test_suite_ = TestSuiteFactory::CreateTestSuite(worker_, message);
STREAMING_CHECK(test_suite_ != nullptr);
}
private:
std::shared_ptr<CoreWorker> worker_;
std::shared_ptr<ReaderClient> reader_client_;
std::shared_ptr<WriterClient> writer_client_;
std::shared_ptr<std::thread> test_thread_;
std::shared_ptr<StreamingQueueTestSuite> test_suite_;
std::shared_ptr<ActorHandle> peer_actor_handle_;
};
} // namespace streaming
} // namespace ray
int main(int argc, char **argv) {
RAY_CHECK(argc == 4);
auto store_socket = std::string(argv[1]);
auto raylet_socket = std::string(argv[2]);
auto node_manager_port = std::stoi(std::string(argv[3]));
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port,
gcs_options);
worker.StartExecutingTasks();
return 0;
}

View file

@ -0,0 +1,136 @@
#include "data_reader.h"
#include "data_writer.h"
#include "gtest/gtest.h"
using namespace ray;
using namespace ray::streaming;
TEST(StreamingMockTransfer, mock_produce_consume) {
std::shared_ptr<Config> transfer_config;
ObjectID channel_id = ObjectID::FromRandom();
ProducerChannelInfo producer_channel_info;
producer_channel_info.channel_id = channel_id;
producer_channel_info.current_seq_id = 0;
MockProducer producer(transfer_config, producer_channel_info);
ConsumerChannelInfo consumer_channel_info;
consumer_channel_info.channel_id = channel_id;
MockConsumer consumer(transfer_config, consumer_channel_info);
producer.CreateTransferChannel();
uint8_t data[3] = {1, 2, 3};
producer.ProduceItemToChannel(data, 3);
uint8_t *data_consumed;
uint32_t data_size_consumed;
uint64_t data_seq_id;
consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1);
EXPECT_EQ(data_size_consumed, 3);
EXPECT_EQ(data_seq_id, 1);
EXPECT_EQ(std::memcmp(data_consumed, data, 3), 0);
consumer.NotifyChannelConsumed(1);
auto status =
consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1);
EXPECT_EQ(status, StreamingStatus::NoSuchItem);
}
class StreamingTransferTest : public ::testing::Test {
public:
StreamingTransferTest() {
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
runtime_context->MarkMockTest();
writer = std::make_shared<DataWriter>(runtime_context);
reader = std::make_shared<DataReader>(runtime_context);
}
virtual ~StreamingTransferTest() = default;
void InitTransfer(int channel_num = 1) {
for (int i = 0; i < channel_num; ++i) {
queue_vec.push_back(ObjectID::FromRandom());
}
std::vector<uint64_t> channel_id_vec(queue_vec.size(), 0);
std::vector<uint64_t> queue_size_vec(queue_vec.size(), 10000);
// actor ids are not used in this test, so we can just use Nil.
std::vector<ActorID> actor_id_vec(queue_vec.size(),
ActorID::NilFromJob(JobID::FromInt(0)));
writer->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec);
reader->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec, -1);
}
void DestroyTransfer() {
writer.reset();
reader.reset();
}
protected:
std::shared_ptr<DataWriter> writer;
std::shared_ptr<DataReader> reader;
std::vector<ObjectID> queue_vec;
};
TEST_F(StreamingTransferTest, exchange_single_channel_test) {
InitTransfer();
writer->Run();
uint8_t data[4] = {1, 2, 3, 0xff};
uint32_t data_size = 4;
writer->WriteMessageToBufferRing(queue_vec[0], data, data_size);
std::shared_ptr<DataBundle> msg;
reader->GetBundle(5000, msg);
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data);
auto &message_list = bundle_ptr->GetMessageList();
auto &message = message_list.front();
EXPECT_EQ(std::memcmp(message->RawData(), data, data_size), 0);
}
TEST_F(StreamingTransferTest, exchange_multichannel_test) {
int channel_num = 4;
InitTransfer(4);
writer->Run();
for (int i = 0; i < channel_num; ++i) {
uint8_t data[4] = {1, 2, 3, (uint8_t)i};
uint32_t data_size = 4;
writer->WriteMessageToBufferRing(queue_vec[i], data, data_size);
std::shared_ptr<DataBundle> msg;
reader->GetBundle(5000, msg);
EXPECT_EQ(msg->from, queue_vec[i]);
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data);
auto &message_list = bundle_ptr->GetMessageList();
auto &message = message_list.front();
EXPECT_EQ(std::memcmp(message->RawData(), data, data_size), 0);
}
}
TEST_F(StreamingTransferTest, exchange_consumed_test) {
InitTransfer();
writer->Run();
uint32_t data_size = 8196;
std::shared_ptr<uint8_t> data(new uint8_t[data_size]);
auto func = [data, data_size](int index) { std::fill_n(data.get(), data_size, index); };
int num = 10000;
std::thread write_thread([this, data, data_size, &func, num]() {
for (uint32_t i = 0; i < num; ++i) {
func(i);
writer->WriteMessageToBufferRing(queue_vec[0], data.get(), data_size);
}
});
std::list<StreamingMessagePtr> read_message_list;
while (read_message_list.size() < num) {
std::shared_ptr<DataBundle> msg;
reader->GetBundle(5000, msg);
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data);
auto &message_list = bundle_ptr->GetMessageList();
std::copy(message_list.begin(), message_list.end(),
std::back_inserter(read_message_list));
}
int index = 0;
for (auto &message : read_message_list) {
func(index++);
EXPECT_EQ(std::memcmp(message->RawData(), data.get(), data_size), 0);
}
write_thread.join();
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View file

@ -0,0 +1,313 @@
namespace ray {
namespace streaming {
ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); }
static void flushall_redis(void) {
redisContext *context = redisConnect("127.0.0.1", 6379);
freeReplyObject(redisCommand(context, "FLUSHALL"));
freeReplyObject(redisCommand(context, "SET NumRedisShards 1"));
freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380"));
redisFree(context);
}
/// Base class for real-world tests with streaming queue
class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
public:
StreamingQueueTestBase(int num_nodes, std::string raylet_exe, std::string store_exe,
int port, std::string actor_exe)
: gcs_options_("127.0.0.1", 6379, ""),
raylet_executable_(raylet_exe),
store_executable_(store_exe),
actor_executable_(actor_exe),
node_manager_port_(port) {
// flush redis first.
flushall_redis();
RAY_CHECK(num_nodes >= 0);
if (num_nodes > 0) {
raylet_socket_names_.resize(num_nodes);
raylet_store_socket_names_.resize(num_nodes);
}
// start plasma store.
for (auto &store_socket : raylet_store_socket_names_) {
store_socket = StartStore();
}
// Start a raylet on each node. Assign different resources to each node so that
// a task can be scheduled to the desired node.
for (int i = 0; i < num_nodes; i++) {
raylet_socket_names_[i] =
StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", node_manager_port_ + i,
"127.0.0.1", "\"CPU,4.0,resource" + std::to_string(i) + ",10\"");
}
}
~StreamingQueueTestBase() {
STREAMING_LOG(INFO) << "Stop raylet store and actors";
for (const auto &raylet_socket : raylet_socket_names_) {
StopRaylet(raylet_socket);
}
for (const auto &store_socket : raylet_store_socket_names_) {
StopStore(store_socket);
}
}
JobID NextJobId() const {
static uint32_t job_counter = 1;
return JobID::FromInt(job_counter++);
}
std::string StartStore() {
std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex();
std::string store_pid = store_socket_name + ".pid";
std::string plasma_command = store_executable_ + " -m 10000000 -s " +
store_socket_name +
" 1> /dev/null 2> /dev/null & echo $! > " + store_pid;
RAY_LOG(DEBUG) << plasma_command;
RAY_CHECK(system(plasma_command.c_str()) == 0);
usleep(200 * 1000);
return store_socket_name;
}
void StopStore(std::string store_socket_name) {
std::string store_pid = store_socket_name + ".pid";
std::string kill_9 = "kill -9 `cat " + store_pid + "`";
RAY_LOG(DEBUG) << kill_9;
ASSERT_EQ(system(kill_9.c_str()), 0);
ASSERT_EQ(system(("rm -rf " + store_socket_name).c_str()), 0);
ASSERT_EQ(system(("rm -rf " + store_socket_name + ".pid").c_str()), 0);
}
std::string StartRaylet(std::string store_socket_name, std::string node_ip_address,
int port, std::string redis_address, std::string resource) {
std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex();
std::string ray_start_cmd = raylet_executable_;
ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name)
.append(" --store_socket_name=" + store_socket_name)
.append(" --object_manager_port=0 --node_manager_port=" + std::to_string(port))
.append(" --node_ip_address=" + node_ip_address)
.append(" --redis_address=" + redis_address)
.append(" --redis_port=6379")
.append(" --num_initial_workers=1")
.append(" --maximum_startup_concurrency=10")
.append(" --static_resource_list=" + resource)
.append(" --python_worker_command=\"" + actor_executable_ + " " +
store_socket_name + " " + raylet_socket_name + " " +
std::to_string(port) + "\"")
.append(" --config_list=initial_reconstruction_timeout_milliseconds,2000")
.append(" & echo $! > " + raylet_socket_name + ".pid");
RAY_LOG(DEBUG) << "Ray Start command: " << ray_start_cmd;
RAY_CHECK(system(ray_start_cmd.c_str()) == 0);
usleep(200 * 1000);
return raylet_socket_name;
}
void StopRaylet(std::string raylet_socket_name) {
std::string raylet_pid = raylet_socket_name + ".pid";
std::string kill_9 = "kill -9 `cat " + raylet_pid + "`";
RAY_LOG(DEBUG) << kill_9;
ASSERT_TRUE(system(kill_9.c_str()) == 0);
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0);
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0);
}
void InitWorker(CoreWorker &driver, ActorID &self_actor_id, ActorID &peer_actor_id,
const queue::protobuf::StreamingQueueTestRole role,
const std::vector<ObjectID> &queue_ids,
const std::vector<ObjectID> &rescale_queue_ids, std::string suite_name,
std::string test_name, uint64_t param) {
std::string forked_serialized_str;
Status st = driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str);
STREAMING_CHECK(st.ok());
STREAMING_LOG(INFO) << "forked_serialized_str: " << forked_serialized_str;
TestInitMessage msg(role, self_actor_id, peer_actor_id, forked_serialized_str,
queue_ids, rescale_queue_ids, suite_name, test_name, param);
std::vector<TaskArg> args;
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(msg.ToBytes(), nullptr, true)));
std::unordered_map<std::string, double> resources;
TaskOptions options{0, true, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::Language::PYTHON, {"init"}};
RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids));
}
void SubmitTestToActor(CoreWorker &driver, ActorID &actor_id, const std::string test) {
uint8_t data[8];
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
std::vector<TaskArg> args;
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr, true)));
std::unordered_map<std::string, double> resources;
TaskOptions options{0, true, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::Language::PYTHON, {"execute_test", test}};
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
}
bool CheckCurTest(CoreWorker &driver, ActorID &actor_id, const std::string test_name) {
uint8_t data[8];
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
std::vector<TaskArg> args;
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr, true)));
std::unordered_map<std::string, double> resources;
TaskOptions options{1, true, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::Language::PYTHON, {"check_current_test_status"}};
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
std::vector<bool> wait_results;
std::vector<std::shared_ptr<RayObject>> results;
Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results);
if (!wait_st.ok()) {
STREAMING_LOG(ERROR) << "Wait fail.";
return false;
}
STREAMING_CHECK(wait_results.size() >= 1);
if (!wait_results[0]) {
STREAMING_LOG(WARNING) << "Wait direct call fail.";
return false;
}
Status get_st = driver.Get(return_ids, -1, &results);
if (!get_st.ok()) {
STREAMING_LOG(ERROR) << "Get fail.";
return false;
}
STREAMING_CHECK(results.size() >= 1);
if (results[0]->IsException()) {
STREAMING_LOG(INFO) << "peer actor may has exceptions.";
return false;
}
STREAMING_CHECK(results[0]->HasData());
STREAMING_LOG(DEBUG) << "SendForResult result[0] DataSize: " << results[0]->GetSize();
const std::shared_ptr<ray::Buffer> result_buffer = results[0]->GetData();
std::shared_ptr<LocalMemoryBuffer> return_buffer =
std::make_shared<LocalMemoryBuffer>(result_buffer->Data(), result_buffer->Size(),
true);
uint8_t *bytes = result_buffer->Data();
uint8_t *p_cur = bytes;
uint32_t *magic_num = (uint32_t *)p_cur;
STREAMING_CHECK(*magic_num == Message::MagicNum);
p_cur += sizeof(Message::MagicNum);
queue::protobuf::StreamingQueueMessageType *type =
(queue::protobuf::StreamingQueueMessageType *)p_cur;
STREAMING_CHECK(*type == queue::protobuf::StreamingQueueMessageType::
StreamingQueueTestCheckStatusRspMsgType);
std::shared_ptr<TestCheckStatusRspMsg> message =
TestCheckStatusRspMsg::FromBytes(bytes);
STREAMING_CHECK(message->TestName() == test_name);
return message->Status();
}
ActorID CreateActorHelper(CoreWorker &worker,
const std::unordered_map<std::string, double> &resources,
bool is_direct_call, uint64_t max_reconstructions) {
std::unique_ptr<ActorHandle> actor_handle;
// Test creating actor.
uint8_t array[] = {1, 2, 3};
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
RayFunction func{ray::Language::PYTHON, {"actor creation task"}};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
ActorCreationOptions actor_options{
max_reconstructions, is_direct_call,
/*max_concurrency*/ 1, resources, resources, {},
/*is_detached*/ false, /*is_asyncio*/ false};
// Create an actor.
ActorID actor_id;
RAY_CHECK_OK(worker.CreateActor(func, args, actor_options, &actor_id));
return actor_id;
}
void SubmitTest(uint32_t queue_num, std::string suite_name, std::string test_name,
uint64_t timeout_ms) {
std::vector<ray::ObjectID> queue_id_vec;
std::vector<ray::ObjectID> rescale_queue_id_vec;
for (uint32_t i = 0; i < queue_num; ++i) {
ObjectID queue_id = ray::ObjectID::FromRandom();
queue_id_vec.emplace_back(queue_id);
}
// One rescale queue id.
ObjectID rescale_queue_id = ray::ObjectID::FromRandom();
rescale_queue_id_vec.emplace_back(rescale_queue_id);
std::vector<uint64_t> channel_seq_id_vec(queue_num, 0);
for (size_t i = 0; i < queue_id_vec.size(); ++i) {
STREAMING_LOG(INFO) << " qid hex => " << queue_id_vec[i].Hex();
}
for (auto &qid : rescale_queue_id_vec) {
STREAMING_LOG(INFO) << " rescale qid hex => " << qid.Hex();
}
STREAMING_LOG(INFO) << "Sub process: writer.";
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "",
node_manager_port_, nullptr);
// Create writer and reader actors
std::unordered_map<std::string, double> resources;
auto actor_id_writer = CreateActorHelper(driver, resources, true, 0);
auto actor_id_reader = CreateActorHelper(driver, resources, true, 0);
InitWorker(driver, actor_id_writer, actor_id_reader,
queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec,
rescale_queue_id_vec, suite_name, test_name, GetParam());
InitWorker(driver, actor_id_reader, actor_id_writer,
queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec,
rescale_queue_id_vec, suite_name, test_name, GetParam());
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
SubmitTestToActor(driver, actor_id_writer, test_name);
SubmitTestToActor(driver, actor_id_reader, test_name);
uint64_t slept_time_ms = 0;
while (slept_time_ms < timeout_ms) {
std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000));
STREAMING_LOG(INFO) << "Check test status.";
if (CheckCurTest(driver, actor_id_writer, test_name) &&
CheckCurTest(driver, actor_id_reader, test_name)) {
STREAMING_LOG(INFO) << "Test Success, Exit.";
return;
}
slept_time_ms += 5 * 1000;
}
EXPECT_TRUE(false);
STREAMING_LOG(INFO) << "Test Timeout, Exit.";
}
void SetUp() {}
void TearDown() {}
protected:
std::vector<std::string> raylet_socket_names_;
std::vector<std::string> raylet_store_socket_names_;
gcs::GcsClientOptions gcs_options_;
std::string raylet_executable_;
std::string store_executable_;
std::string actor_executable_;
int node_manager_port_;
};
} // namespace streaming
} // namespace ray

View file

@ -0,0 +1,93 @@
#include "gtest/gtest.h"
#include "ray/util/logging.h"
#include <unistd.h>
#include <iostream>
#include <set>
#include <thread>
#include "message/message.h"
#include "ring_buffer.h"
using namespace ray;
using namespace ray::streaming;
size_t data_n = 1000000;
TEST(StreamingRingBufferTest, streaming_message_ring_buffer_test) {
for (int k = 0; k < 10000; ++k) {
StreamingRingBuffer ring_buffer(3, StreamingRingBufferType::SPSC_LOCK);
for (int i = 0; i < 5; ++i) {
uint8_t data[] = {1, 1, 3};
data[0] = i;
StreamingMessagePtr message =
std::make_shared<StreamingMessage>(data, 3, i, StreamingMessageType::Message);
EXPECT_EQ(ring_buffer.Push(message), true);
size_t ith = i >= 3 ? 3 : (i + 1);
EXPECT_EQ(ring_buffer.Size(), ith);
}
int th = 2;
while (!ring_buffer.IsEmpty()) {
StreamingMessagePtr message_ptr = ring_buffer.Front();
ring_buffer.Pop();
EXPECT_EQ(message_ptr->GetDataSize(), 3);
EXPECT_EQ(*(message_ptr->RawData()), th++);
}
}
}
TEST(StreamingRingBufferTest, spsc_test) {
size_t m_num = 1000;
StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC);
std::thread thread([&ring_buffer]() {
for (size_t j = 0; j < data_n; ++j) {
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
reinterpret_cast<uint8_t *>(&j), sizeof(size_t), j,
StreamingMessageType::Message);
while (ring_buffer.IsFull()) {
}
ring_buffer.Push(message);
}
});
size_t count = 0;
while (count < data_n) {
while (ring_buffer.IsEmpty()) {
}
auto &msg = ring_buffer.Front();
EXPECT_EQ(std::memcmp(msg->RawData(), &count, sizeof(size_t)), 0);
ring_buffer.Pop();
count++;
}
thread.join();
EXPECT_EQ(count, data_n);
}
TEST(StreamingRingBufferTest, mutex_test) {
size_t m_num = data_n;
StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC_LOCK);
std::thread thread([&ring_buffer]() {
for (size_t j = 0; j < data_n; ++j) {
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
reinterpret_cast<uint8_t *>(&j), sizeof(size_t), j,
StreamingMessageType::Message);
while (ring_buffer.IsFull()) {
}
ring_buffer.Push(message);
}
});
size_t count = 0;
while (count < data_n) {
while (ring_buffer.IsEmpty()) {
}
auto msg = ring_buffer.Front();
EXPECT_EQ(std::memcmp(msg->RawData(), &count, sizeof(size_t)), 0);
ring_buffer.Pop();
count++;
}
thread.join();
EXPECT_EQ(count, data_n);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View file

@ -0,0 +1,69 @@
#!/usr/bin/env bash
# Run all streaming c++ tests using streaming queue, instead of plasma queue
# This needs to be run in the root directory.
# Try to find an unused port for raylet to use.
PORTS="2000 2001 2002 2003 2004 2005 2006 2007 2008 2009"
RAYLET_PORT=0
for port in $PORTS; do
nc -z localhost $port
if [[ $? != 0 ]]; then
RAYLET_PORT=$port
break
fi
done
if [[ $RAYLET_PORT == 0 ]]; then
echo "WARNING: Could not find unused port for raylet to use. Exiting without running tests."
exit
fi
# Cause the script to exit if a single command fails.
set -e
set -x
export STREAMING_METRICS_MODE=DEV
# Get the directory in which this script is executing.
SCRIPT_DIR="`dirname \"$0\"`"
RAY_ROOT="$SCRIPT_DIR/../../.."
# Makes $RAY_ROOT an absolute path.
RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`"
if [ -z "$RAY_ROOT" ] ; then
exit 1
fi
bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server"
bazel build //streaming:streaming_test_worker
bazel build //streaming:streaming_queue_tests
# Ensure we're in the right directory.
if [ ! -d "$RAY_ROOT/python" ]; then
echo "Unable to find root Ray directory. Has this script moved?"
exit 1
fi
REDIS_MODULE="./bazel-bin/libray_redis_module.so"
LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}"
STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server"
RAYLET_EXEC="./bazel-bin/raylet"
STREAMING_TEST_WORKER_EXEC="./bazel-bin/streaming/streaming_test_worker"
# Allow cleanup commands to fail.
bazel run //:redis-cli -- -p 6379 shutdown || true
sleep 1s
bazel run //:redis-cli -- -p 6380 shutdown || true
sleep 1s
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6379 &
sleep 2s
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 &
sleep 2s
# Run tests.
./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC
sleep 1s
bazel run //:redis-cli -- -p 6379 shutdown
bazel run //:redis-cli -- -p 6380 shutdown
sleep 1s

View file

@ -0,0 +1,65 @@
#define BOOST_BIND_NO_PLACEHOLDERS
#include <unistd.h>
#include "gtest/gtest.h"
#include "queue/queue_client.h"
#include "ray/core_worker/core_worker.h"
#include "data_reader.h"
#include "data_writer.h"
#include "message/message.h"
#include "message/message_bundle.h"
#include "ring_buffer.h"
#include "queue_tests_base.h"
using namespace std::placeholders;
namespace ray {
namespace streaming {
static std::string store_executable;
static std::string raylet_executable;
static std::string actor_executable;
static int node_manager_port;
class StreamingWriterTest : public StreamingQueueTestBase {
public:
StreamingWriterTest()
: StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port,
actor_executable) {}
};
class StreamingExactlySameTest : public StreamingQueueTestBase {
public:
StreamingExactlySameTest()
: StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port,
actor_executable) {}
};
TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) {
STREAMING_LOG(INFO) << "StreamingWriterTest.streaming_writer_exactly_once_test";
uint32_t queue_num = 1;
STREAMING_LOG(INFO) << "Streaming Strategy => EXACTLY ONCE";
SubmitTest(queue_num, "StreamingWriterTest", "streaming_writer_exactly_once_test",
60 * 1000);
}
INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingWriterTest, testing::Values(0));
INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest,
testing::Values(0, 1, 5, 9));
} // namespace streaming
} // namespace ray
int main(int argc, char **argv) {
// set_streaming_log_config("streaming_writer_test", StreamingLogLevel::INFO, 0);
::testing::InitGoogleTest(&argc, argv);
RAY_CHECK(argc == 5);
ray::streaming::store_executable = std::string(argv[1]);
ray::streaming::raylet_executable = std::string(argv[2]);
ray::streaming::node_manager_port = std::stoi(std::string(argv[3]));
ray::streaming::actor_executable = std::string(argv[4]);
return RUN_ALL_TESTS();
}

View file

@ -0,0 +1,24 @@
#include "gtest/gtest.h"
#include "util/streaming_util.h"
using namespace ray;
using namespace ray::streaming;
TEST(StreamingUtilTest, test_Byte2hex) {
const uint8_t data[2] = {0x11, 0x07};
EXPECT_TRUE(Util::Byte2hex(data, 2) == "1107");
EXPECT_TRUE(Util::Byte2hex(data, 2) != "1108");
}
TEST(StreamingUtilTest, test_Hex2str) {
const uint8_t data[2] = {0x11, 0x07};
EXPECT_TRUE(std::memcmp(Util::Hexqid2str("1107").c_str(), data, 2) == 0);
const uint8_t data2[2] = {0x10, 0x0f};
EXPECT_TRUE(std::memcmp(Util::Hexqid2str("100f").c_str(), data2, 2) == 0);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View file

@ -0,0 +1,12 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include "glog/log_severity.h"
#include "glog/logging.h"
#include "streaming_logging.h"
namespace ray {
namespace streaming {} // namespace streaming
} // namespace ray

View file

@ -0,0 +1,11 @@
#ifndef RAY_STREAMING_LOGGING_H
#define RAY_STREAMING_LOGGING_H
#include "ray/util/logging.h"
#define STREAMING_LOG RAY_LOG
#define STREAMING_CHECK RAY_CHECK
namespace ray {
namespace streaming {} // namespace streaming
} // namespace ray
#endif // RAY_STREAMING_LOGGING_H

View file

@ -0,0 +1,42 @@
#include <unordered_set>
#include "streaming_util.h"
namespace ray {
namespace streaming {
boost::any &Config::Get(ConfigEnum key) const {
auto item = config_map_.find(key);
STREAMING_CHECK(item != config_map_.end());
return item->second;
}
boost::any Config::Get(ConfigEnum key, boost::any default_value) const {
auto item = config_map_.find(key);
if (item == config_map_.end()) {
return default_value;
}
return item->second;
}
std::string Util::Byte2hex(const uint8_t *data, uint32_t data_size) {
constexpr char hex[] = "0123456789abcdef";
std::string result;
for (uint32_t i = 0; i < data_size; i++) {
unsigned short val = data[i];
result.push_back(hex[val >> 4]);
result.push_back(hex[val & 0xf]);
}
return result;
}
std::string Util::Hexqid2str(const std::string &q_id_hex) {
std::string result;
for (uint32_t i = 0; i < q_id_hex.size(); i += 2) {
std::string byte = q_id_hex.substr(i, 2);
char chr = static_cast<char>(std::strtol(byte.c_str(), nullptr, 16));
result.push_back(chr);
}
return result;
}
} // namespace streaming
} // namespace ray

View file

@ -0,0 +1,99 @@
#ifndef RAY_STREAMING_UTIL_H
#define RAY_STREAMING_UTIL_H
#include <boost/any.hpp>
#include <string>
#include <unordered_map>
#include "util/streaming_logging.h"
namespace ray {
namespace streaming {
enum class ConfigEnum : uint32_t {
QUEUE_ID_VECTOR = 0,
RECONSTRUCT_RETRY_TIMES,
RECONSTRUCT_TIMEOUT_PER_MB,
CURRENT_DRIVER_ID,
/// For direct call
CORE_WORKER,
SYNC_FUNCTION,
ASYNC_FUNCTION,
TRANSFER_MIN = QUEUE_ID_VECTOR,
TRANSFER_MAX = ASYNC_FUNCTION
};
} // namespace streaming
} // namespace ray
namespace std {
template <>
struct hash<::ray::streaming::ConfigEnum> {
size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const {
return static_cast<uint32_t>(config_enum_key);
}
};
template <>
struct hash<const ::ray::streaming::ConfigEnum> {
size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const {
return static_cast<uint32_t>(config_enum_key);
}
};
} // namespace std
namespace ray {
namespace streaming {
class Config {
public:
template <typename ValueType>
inline void Set(ConfigEnum key, const ValueType &any) {
config_map_.emplace(key, any);
}
template <typename ValueType>
inline void Set(ConfigEnum key, ValueType &&any) {
config_map_.emplace(key, any);
}
template <typename ValueType>
inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) {
auto item = config_map_.find(key);
if (item != config_map_.end()) {
return item->second;
}
Set(key, any);
return any;
}
boost::any &Get(ConfigEnum key) const;
boost::any Get(ConfigEnum key, boost::any default_value) const;
inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast<uint32_t>(Get(key)); }
inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast<uint64_t>(Get(key)); }
inline double GetDouble(ConfigEnum key) { return boost::any_cast<double>(Get(key)); }
inline bool GetBool(ConfigEnum key) { return boost::any_cast<bool>(Get(key)); }
inline std::string GetString(ConfigEnum key) {
return boost::any_cast<std::string>(Get(key));
}
virtual ~Config() = default;
protected:
mutable std::unordered_map<ConfigEnum, boost::any> config_map_;
};
class Util {
public:
static std::string Byte2hex(const uint8_t *data, uint32_t data_size);
static std::string Hexqid2str(const std::string &q_id_hex);
};
} // namespace streaming
} // namespace ray
#endif // RAY_STREAMING_UTIL_H
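A minimal sketch of the boost::any-backed Config map (an editorial illustration, not part of the commit; it assumes the util/streaming_util.h path used by this change, and note that the stored type must match the typed getter exactly):

#include <string>
#include "util/streaming_util.h"

using namespace ray::streaming;

int main() {
  Config config;
  config.Set(ConfigEnum::RECONSTRUCT_RETRY_TIMES, static_cast<uint32_t>(3));
  config.Set(ConfigEnum::CURRENT_DRIVER_ID, std::string("driver-0"));
  uint32_t retries = config.GetInt32(ConfigEnum::RECONSTRUCT_RETRY_TIMES);
  std::string driver = config.GetString(ConfigEnum::CURRENT_DRIVER_ID);
  // The two-argument Get returns the default instead of failing a check
  // when the key is missing.
  boost::any timeout = config.Get(ConfigEnum::RECONSTRUCT_TIMEOUT_PER_MB,
                                  static_cast<uint64_t>(1000));
  (void)retries;
  (void)driver;
  (void)timeout;
  return 0;
}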

View file

@ -1,30 +1,49 @@
diff --git bazel/cython_library.bzl bazel/cython_library.bzl
index 48b41d74e8..6084734f59 100644
index 48b41d74e8..a9bc168e5d 100644
--- bazel/cython_library.bzl
+++ bazel/cython_library.bzl
@@ -7,7 +7,7 @@
@@ -7,18 +7,20 @@
# been written at cython/cython and tensorflow/tensorflow. We branch from
# Tensorflow's version as it is more actively maintained and works for gRPC
# Python's needs.
-def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
+def pyx_library(name, deps=[], copts=[], py_deps=[], srcs=[], **kwargs):
+def pyx_library(name, deps=[], copts=[], cc_kwargs={}, py_deps=[], srcs=[], **kwargs):
"""Compiles a group of .pyx / .pxd / .py files.
First runs Cython to create .cpp files for each input .pyx or .py + .pxd
@@ -19,6 +19,7 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
- pair. Then builds a shared object for each, passing "deps" to each cc_binary
- rule (includes Python headers by default). Finally, creates a py_library rule
- with the shared objects and any pure Python "srcs", with py_deps as its
- dependencies; the shared objects can be imported like normal Python files.
+ pair. Then builds a shared object for each, passing "deps" and `**cc_kwargs`
+ to each cc_binary rule (includes Python headers by default). Finally, creates
+ a py_library rule with the shared objects and any pure Python "srcs", with py_deps
+ as its dependencies; the shared objects can be imported like normal Python files.
Args:
name: Name for the rule.
deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
+ copts: C/C++ compiler options for Cython
+ cc_kwargs: cc_binary extra arguments such as linkstatic, linkopts, features
py_deps: Pure Python dependencies of the final library.
srcs: .py, .pyx, or .pxd files to either compile or pass through.
**kwargs: Extra keyword arguments passed to the py_library.
@@ -58,6 +59,7 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
@@ -57,9 +59,11 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
shared_object_name = stem + ".so"
native.cc_binary(
name=shared_object_name,
srcs=[stem + ".cpp"],
- srcs=[stem + ".cpp"],
+ srcs=[stem + ".cpp"] + cc_kwargs.pop("srcs", []),
+ copts=copts,
deps=deps + ["@local_config_python//:python_headers"],
linkshared=1,
+ **cc_kwargs
)
shared_objects.append(shared_object_name)
@@ -72,3 +76,4 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
data=shared_objects,
**kwargs)
+
--