[Streaming] Streaming data transfer and python integration (#6185)
parent c1d4ab8bb4
commit 6272907a57

93 changed files with 8434 additions and 1480 deletions

.gitignore (vendored): 4 lines changed
@@ -130,6 +130,7 @@ scripts/nodes.txt
# Pytest Cache
**/.pytest_cache
**/.cache
.benchmarks

# Vscode

@@ -145,6 +146,9 @@ java/**/.classpath
java/**/.project
java/runtime/native_dependencies/

# streaming/python
streaming/python/generated/

# python virtual env
venv
.travis.yml: 17 lines changed

@@ -34,6 +34,21 @@ matrix:
- if [ $RAY_CI_JAVA_AFFECTED != "1" ]; then exit; fi
- ./java/test.sh

- os: linux
env: BAZEL_PYTHON_VERSION=PY3 PYTHON=3.5 PYTHONWARNINGS=ignore TESTSUITE=streaming
install:
- python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py
- eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py`
- if [ $RAY_CI_STREAMING_PYTHON_AFFECTED != "1" ]; then exit; fi
- ./ci/suppress_output ./ci/travis/install-bazel.sh
- ./ci/suppress_output ./ci/travis/install-dependencies.sh
- export PATH="$HOME/miniconda/bin:$PATH"
- ./ci/suppress_output ./ci/travis/install-ray.sh
script:
# Streaming cpp test.
- if [ $RAY_CI_STREAMING_CPP_AFFECTED == "1" ]; then ./ci/suppress_output bash streaming/src/test/run_streaming_queue_test.sh; fi
- if [ $RAY_CI_STREAMING_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=5 --timeout=300 python/ray/streaming/tests/; fi

- os: linux
env: LINT=1 PYTHONWARNINGS=ignore
before_install:

@@ -51,7 +66,7 @@ matrix:
- sphinx-build -W -b html -d _build/doctrees source _build/html
- cd ..
# Run Python linting, ignore dict vs {} (C408), others are defaults
- flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
- flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
- ./ci/travis/format.sh --all
# Make sure that the README is formatted properly.
- cd python
BUILD.bazel: 67 lines changed

@@ -7,6 +7,28 @@ load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library")
load("@rules_proto_grpc//python:defs.bzl", "python_grpc_compile")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("//bazel:ray.bzl", "if_linux_x86_64")

config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)

config_setting(
name = "macos",
values = {
"apple_platform_type": "macos",
"cpu": "darwin",
},
visibility = ["//visibility:public"],
)

config_setting(
name = "linux_x86_64",
values = {"cpu": "k8"},
visibility = ["//visibility:public"],
)

# TODO(mehrdadn): (How to) support dynamic linking?
PROPAGATED_WINDOWS_DEFINES = ["RAY_STATIC"]

@@ -219,6 +241,7 @@ cc_library(
includes = [
"@boost//:asio",
],
visibility = ["//visibility:public"],
deps = [
":common_cc_proto",
":gcs_cc_proto",

@@ -327,6 +350,7 @@ cc_library(
"-lpthread",
],
}),
visibility = ["//streaming:__subpackages__"],
deps = [
":common_cc_proto",
":gcs",

@@ -373,6 +397,7 @@ cc_library(
"src/ray/core_worker/transport/*.h",
]),
copts = COPTS,
visibility = ["//visibility:public"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",

@@ -659,6 +684,7 @@ cc_library(
includes = [
"src",
],
visibility = ["//visibility:public"],
deps = [
":sha256",
"@com_github_google_glog//:glog",

@@ -782,15 +808,51 @@ pyx_library(
name = "_raylet",
srcs = glob([
"python/ray/__init__.py",
"python/ray/_raylet.pxd",
"python/ray/_raylet.pyx",
"python/ray/includes/*.pxd",
"python/ray/includes/*.pxi",
]),
copts = COPTS,
# Export ray ABI symbols, which can then be used by _streaming.so.
# We need to dlopen this lib with RTLD_GLOBAL to use ABI in this
# shared lib, see python/ray/__init__.py.
cc_kwargs = {
"linkstatic": 1,
# see https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/lite/BUILD#L444
"linkopts": select({
"//:macos": [
"-Wl,-exported_symbols_list,$(location //:src/ray/ray_exported_symbols.lds)",
],
"//:windows": [],
"//conditions:default": [
"-Wl,--version-script,$(location //:src/ray/ray_version_script.lds)",
],
}),
},
copts = COPTS + if_linux_x86_64(["-fno-gnu-unique"]),
deps = [
"//:core_worker_lib",
"//:raylet_lib",
"//:serialization_cc_proto",
"//:src/ray/ray_exported_symbols.lds",
"//:src/ray/ray_version_script.lds",
],
)

pyx_library(
name = "_streaming",
srcs = glob([
"python/ray/streaming/_streaming.pyx",
"python/ray/__init__.py",
"python/ray/_raylet.pxd",
"python/ray/includes/*.pxd",
"python/ray/includes/*.pxi",
"python/ray/streaming/__init__.pxd",
"python/ray/streaming/includes/*.pxd",
"python/ray/streaming/includes/*.pxi",
]),
deps = [
"//streaming:streaming_lib",
],
)

@@ -922,6 +984,7 @@ genrule(
name = "ray_pkg",
srcs = [
"python/ray/_raylet.so",
"python/ray/streaming/_streaming.so",
"//:python_sources",
"//:all_py_proto",
"//:redis-server",

@@ -930,12 +993,14 @@ genrule(
"//:raylet",
"//:raylet_monitor",
"@plasma//:plasma_store_server",
"//streaming:copy_streaming_py_proto",
],
outs = ["ray_pkg.out"],
cmd = """
set -x &&
WORK_DIR=$$(pwd) &&
cp -f $(location python/ray/_raylet.so) "$$WORK_DIR/python/ray" &&
cp -f $(location python/ray/streaming/_streaming.so) $$WORK_DIR/python/ray/streaming &&
mkdir -p "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" &&
cp -f $(location //:redis-server) "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" &&
cp -f $(location //:redis-cli) "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" &&
@@ -64,3 +64,9 @@ def define_java_module(
"{auto_gen_header}": "<!-- This file is auto-generated by Bazel from pom_template.xml, do not modify it. -->",
},
)

def if_linux_x86_64(a):
return select({
"//:linux_x86_64": a,
"//conditions:default": [],
})
@@ -44,6 +44,7 @@ while [[ $# > 0 ]]; do
done

pushd $ROOT_DIR/../..
BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel WORKSPACE"
BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel
streaming/BUILD.bazel WORKSPACE"
buildifier -mode=$RUN_TYPE -diff_command="diff -u" $BAZEL_FILES
popd
@@ -38,6 +38,8 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 0
RAY_CI_LINUX_WHEELS_AFFECTED = 0
RAY_CI_MACOS_WHEELS_AFFECTED = 0
RAY_CI_STREAMING_CPP_AFFECTED = 0
RAY_CI_STREAMING_PYTHON_AFFECTED = 0

if os.environ["TRAVIS_EVENT_TYPE"] == "pull_request":

@@ -71,6 +73,7 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
elif changed_file.startswith("java/"):
RAY_CI_JAVA_AFFECTED = 1
elif any(

@@ -86,6 +89,13 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_CPP_AFFECTED = 1
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
elif changed_file.startswith("streaming/src"):
RAY_CI_STREAMING_CPP_AFFECTED = 1
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
elif changed_file.startswith("streaming/python"):
RAY_CI_STREAMING_PYTHON_AFFECTED = 1
else:
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1

@@ -94,6 +104,7 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_CPP_AFFECTED = 1
else:
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1

@@ -102,6 +113,7 @@ if __name__ == "__main__":
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
RAY_CI_STREAMING_CPP_AFFECTED = 1

# Log the modified environment variables visible in console.
for output_stream in [sys.stdout, sys.stderr]:

@@ -116,3 +128,7 @@ if __name__ == "__main__":
.format(RAY_CI_LINUX_WHEELS_AFFECTED))
_print("export RAY_CI_MACOS_WHEELS_AFFECTED={}"
.format(RAY_CI_MACOS_WHEELS_AFFECTED))
_print("export RAY_CI_STREAMING_CPP_AFFECTED={}"
.format(RAY_CI_STREAMING_CPP_AFFECTED))
_print("export RAY_CI_STREAMING_PYTHON_AFFECTED={}"
.format(RAY_CI_STREAMING_PYTHON_AFFECTED))
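The only interface of this script is the "export NAME=value" lines it prints; the new Travis job above consumes them with eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py`. A rough sketch of the same consumption pattern in Python (the helper function below is hypothetical; only the variable name and the export format come from the diff):

import subprocess

def streaming_python_tests_needed(script="ci/travis/determine_tests_to_run.py"):
    # The script prints lines such as "export RAY_CI_STREAMING_PYTHON_AFFECTED=1";
    # CI evaluates them in the shell, here we parse them instead.
    out = subprocess.check_output(["python", script], text=True)
    flags = dict(
        line.replace("export ", "", 1).split("=", 1)
        for line in out.splitlines()
        if line.startswith("export "))
    return flags.get("RAY_CI_STREAMING_PYTHON_AFFECTED") == "1"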
@@ -79,14 +79,14 @@ format_changed() {
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
if which flake8 >/dev/null; then
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
fi
fi

if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
if which flake8 >/dev/null; then
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605
fi
fi
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function

import os
from os.path import dirname
import sys

# MUST add pickle5 to the import path because it will be imported by some

@@ -19,6 +20,14 @@ pickle5_path = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "pickle5_files")
sys.path.insert(0, pickle5_path)

# Expose ray ABI symbols, which may be depended on by other shared
# libraries such as _streaming.so. See BUILD.bazel:_raylet
so_path = os.path.join(dirname(__file__), "_raylet.so")
if os.path.exists(so_path):
import ctypes
from ctypes import CDLL
CDLL(so_path, ctypes.RTLD_GLOBAL)

# MUST import ray._raylet before pyarrow to initialize some global variables.
# It seems the library related to memory allocation in pyarrow will destroy the
# initialization of grpc if we import pyarrow first.
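The comment above is the reason _raylet.so is loaded with RTLD_GLOBAL before anything else imports it: the symbols it exports become visible process-wide, so _streaming.so can bind to the same core-worker ABI instead of carrying its own copy. A minimal sketch of the pattern (paths and the helper name are hypothetical; only ctypes.CDLL with RTLD_GLOBAL is taken from the diff):

import ctypes
import importlib
import os

def load_with_shared_abi(core_so_path, dependent_module):
    if os.path.exists(core_so_path):
        # RTLD_GLOBAL publishes the exported symbols of this extension
        # to every shared library loaded afterwards.
        ctypes.CDLL(core_so_path, ctypes.RTLD_GLOBAL)
    # Extensions imported after this point can resolve those symbols.
    return importlib.import_module(dependent_module)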
python/ray/_raylet.pxd (new file): 70 lines
|
@ -0,0 +1,70 @@
|
|||
# cython: profile=False
|
||||
# distutils: language = c++
|
||||
# cython: embedsignature = True
|
||||
# cython: language_level = 3
|
||||
|
||||
from libcpp cimport bool as c_bool
|
||||
from libcpp.string cimport string as c_string
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
from libcpp.memory cimport (
|
||||
shared_ptr,
|
||||
unique_ptr
|
||||
)
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CBuffer,
|
||||
CRayObject
|
||||
)
|
||||
from ray.includes.libcoreworker cimport CCoreWorker
|
||||
from ray.includes.unique_ids cimport (
|
||||
CObjectID,
|
||||
CActorID
|
||||
)
|
||||
|
||||
cdef class Buffer:
|
||||
cdef:
|
||||
shared_ptr[CBuffer] buffer
|
||||
Py_ssize_t shape
|
||||
Py_ssize_t strides
|
||||
|
||||
@staticmethod
|
||||
cdef make(const shared_ptr[CBuffer]& buffer)
|
||||
|
||||
cdef class BaseID:
|
||||
# To avoid the error of "Python int too large to convert to C ssize_t",
|
||||
# here `cdef size_t` is required.
|
||||
cdef size_t hash(self)
|
||||
|
||||
cdef class ObjectID(BaseID):
|
||||
cdef:
|
||||
CObjectID data
|
||||
object buffer_ref
|
||||
# Flag indicating whether or not this object ID was added to the set
|
||||
# of active IDs in the core worker so we know whether we should clean
|
||||
# it up.
|
||||
c_bool in_core_worker
|
||||
|
||||
cdef CObjectID native(self)
|
||||
|
||||
cdef class ActorID(BaseID):
|
||||
cdef CActorID data
|
||||
|
||||
cdef CActorID native(self)
|
||||
|
||||
cdef size_t hash(self)
|
||||
|
||||
cdef class CoreWorker:
|
||||
cdef:
|
||||
unique_ptr[CCoreWorker] core_worker
|
||||
object async_thread
|
||||
object async_event_loop
|
||||
|
||||
cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
|
||||
size_t data_size, ObjectID object_id,
|
||||
CObjectID *c_object_id, shared_ptr[CBuffer] *data)
|
||||
# TODO: handle noreturn better
|
||||
cdef store_task_outputs(
|
||||
self, worker, outputs, const c_vector[CObjectID] return_ids,
|
||||
c_vector[shared_ptr[CRayObject]] *returns)
|
||||
|
||||
cdef c_vector[c_string] string_vector_from_list(list string_list)
|
|
@ -41,6 +41,7 @@ from libcpp.vector cimport vector as c_vector
|
|||
from cython.operator import dereference, postincrement
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CBuffer,
|
||||
CAddress,
|
||||
CLanguage,
|
||||
CRayObject,
|
||||
|
@ -346,13 +347,29 @@ cdef c_vector[c_string] string_vector_from_list(list string_list):
|
|||
return out
|
||||
|
||||
|
||||
cdef:
|
||||
c_string pickle_metadata_str = PICKLE_BUFFER_METADATA
|
||||
shared_ptr[CBuffer] pickle_metadata = dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(pickle_metadata_str.data()),
|
||||
pickle_metadata_str.size(), True))
|
||||
c_string raw_meta_str = RAW_BUFFER_METADATA
|
||||
shared_ptr[CBuffer] raw_metadata = dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(raw_meta_str.data()),
|
||||
raw_meta_str.size(), True))
|
||||
|
||||
cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
|
||||
cdef:
|
||||
c_string pickled_str
|
||||
c_string metadata_str = PICKLE_BUFFER_METADATA
|
||||
const unsigned char[:] buffer
|
||||
size_t size
|
||||
shared_ptr[CBuffer] arg_data
|
||||
shared_ptr[CBuffer] arg_metadata
|
||||
|
||||
# TODO be consistent with store_task_outputs
|
||||
for arg in args:
|
||||
if isinstance(arg, ObjectID):
|
||||
args_vector.push_back(
|
||||
|
@ -360,23 +377,25 @@ cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
|
|||
elif not ray._raylet.check_simple_value(arg):
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByReference((<ObjectID>ray.put(arg)).native()))
|
||||
else:
|
||||
pickled_str = pickle.dumps(
|
||||
arg, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
# TODO(edoakes): This makes a copy that could be avoided.
|
||||
elif type(arg) is bytes:
|
||||
buffer = arg
|
||||
size = buffer.nbytes
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(pickled_str.data()),
|
||||
pickled_str.size(),
|
||||
True))
|
||||
arg_metadata = dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(
|
||||
metadata_str.data()), metadata_str.size(), True))
|
||||
<uint8_t*>(&buffer[0]), size, True))
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(
|
||||
make_shared[CRayObject](arg_data, arg_metadata)))
|
||||
make_shared[CRayObject](arg_data, raw_metadata)))
|
||||
else:
|
||||
buffer = pickle.dumps(
|
||||
arg, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
size = buffer.nbytes
|
||||
arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(&buffer[0]), size, True))
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(
|
||||
make_shared[CRayObject](arg_data, pickle_metadata)))
|
||||
|
||||
|
||||
cdef class RayletClient:
|
||||
|
@ -738,10 +757,6 @@ cdef write_serialized_object(
|
|||
|
||||
|
||||
cdef class CoreWorker:
|
||||
cdef:
|
||||
unique_ptr[CCoreWorker] core_worker
|
||||
object async_thread
|
||||
object async_event_loop
|
||||
|
||||
def __cinit__(self, is_driver, store_socket, raylet_socket,
|
||||
JobID job_id, GcsClientOptions gcs_options, log_dir,
|
||||
|
@ -1085,7 +1100,6 @@ cdef class CoreWorker:
|
|||
c_vector[shared_ptr[CRayObject]] *returns):
|
||||
cdef:
|
||||
c_vector[size_t] data_sizes
|
||||
c_string metadata_str
|
||||
c_vector[shared_ptr[CBuffer]] metadatas
|
||||
|
||||
if return_ids.size() == 0:
|
||||
|
|
|
@ -1,216 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.experimental import internal_kv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("INFO")
|
||||
|
||||
|
||||
def plasma_prefetch(object_id):
|
||||
"""Tells plasma to prefetch the given object_id."""
|
||||
local_sched_client = ray.worker.global_worker.raylet_client
|
||||
ray_obj_id = ray.ObjectID(object_id)
|
||||
local_sched_client.fetch_or_reconstruct([ray_obj_id], True)
|
||||
|
||||
|
||||
# TODO: doing the timer in Python land is a bit slow
|
||||
class FlushThread(threading.Thread):
|
||||
"""A thread that flushes periodically to plasma.
|
||||
|
||||
Attributes:
|
||||
interval: The flush timeout per batch.
|
||||
flush_fn: The flush function.
|
||||
"""
|
||||
|
||||
def __init__(self, interval, flush_fn):
|
||||
threading.Thread.__init__(self)
|
||||
self.interval = interval # Interval is the max_batch_time
|
||||
self.flush_fn = flush_fn
|
||||
self.daemon = True
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
time.sleep(self.interval) # Flushing period
|
||||
self.flush_fn()
|
||||
|
||||
|
||||
class BatchedQueue(object):
|
||||
"""A batched queue for actor to actor communication.
|
||||
|
||||
Attributes:
|
||||
max_size (int): The maximum size of the queue in number of batches
|
||||
(if exceeded, backpressure kicks in)
|
||||
max_batch_size (int): The size of each batch in number of records.
|
||||
max_batch_time (float): The flush timeout per batch.
|
||||
prefetch_depth (int): The number of batches to prefetch from plasma.
|
||||
background_flush (bool): Denotes whether a daemon flush thread should
|
||||
be used (True) to flush batches to plasma.
|
||||
base (ndarray): A unique signature for the queue.
|
||||
read_ack_key (bytes): The signature of the queue in bytes.
|
||||
prefetch_batch_offset (int): The number of the last read prefetched
|
||||
batch.
|
||||
read_batch_offset (int): The number of the last read batch.
|
||||
read_item_offset (int): The number of the last read record inside a
|
||||
batch.
|
||||
write_batch_offset (int): The number of the last written batch.
|
||||
write_item_offset (int): The number of the last written item inside a
|
||||
batch.
|
||||
write_buffer (list): The write buffer, i.e. an in-memory batch.
|
||||
last_flush_time (float): The time the last flushing to plasma took
|
||||
place.
|
||||
cached_remote_offset (int): The number of the last read batch as
|
||||
recorded by the writer after the previous flush.
|
||||
flush_lock (RLock): A python lock used for flushing batches to plasma.
|
||||
flush_thread (threading.Thread): The Python thread used for flushing batches
|
||||
to plasma.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_size=999999,
|
||||
max_batch_size=99999,
|
||||
max_batch_time=0.01,
|
||||
prefetch_depth=10,
|
||||
background_flush=True):
|
||||
self.max_size = max_size
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_batch_time = max_batch_time
|
||||
self.prefetch_depth = prefetch_depth
|
||||
self.background_flush = background_flush
|
||||
|
||||
# Common queue metadata -- This serves as the unique id of the queue
|
||||
self.base = np.random.randint(0, 2**32 - 1, size=5, dtype="uint32")
|
||||
self.base[-2] = 0
|
||||
self.base[-1] = 0
|
||||
self.read_ack_key = np.ndarray.tobytes(self.base)
|
||||
|
||||
# Reader state
|
||||
self.prefetch_batch_offset = 0
|
||||
self.read_item_offset = 0
|
||||
self.read_batch_offset = 0
|
||||
self.read_buffer = []
|
||||
|
||||
# Writer state
|
||||
self.write_item_offset = 0
|
||||
self.write_batch_offset = 0
|
||||
self.write_buffer = []
|
||||
self.last_flush_time = 0.0
|
||||
self.cached_remote_offset = 0
|
||||
|
||||
self.flush_lock = threading.RLock()
|
||||
self.flush_thread = FlushThread(self.max_batch_time,
|
||||
self._flush_writes)
|
||||
|
||||
def __getstate__(self):
|
||||
state = dict(self.__dict__)
|
||||
del state["flush_lock"]
|
||||
del state["flush_thread"]
|
||||
del state["write_buffer"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
# This is to enable writing functionality in
|
||||
# case the queue is not created by the writer
|
||||
# The reason is that python locks cannot be serialized
|
||||
def enable_writes(self):
|
||||
"""Restores the state of the batched queue for writing."""
|
||||
self.write_buffer = []
|
||||
self.flush_lock = threading.RLock()
|
||||
self.flush_thread = FlushThread(self.max_batch_time,
|
||||
self._flush_writes)
|
||||
|
||||
# Batch ids consist of a unique queue id used as prefix along with
|
||||
# two numbers generated using the batch offset in the queue
|
||||
def _batch_id(self, batch_offset):
|
||||
oid = self.base.copy()
|
||||
oid[-2] = batch_offset // 2**32
|
||||
oid[-1] = batch_offset % 2**32
|
||||
return np.ndarray.tobytes(oid)
|
||||
|
||||
def _flush_writes(self):
|
||||
with self.flush_lock:
|
||||
if not self.write_buffer:
|
||||
return
|
||||
batch_id = self._batch_id(self.write_batch_offset)
|
||||
ray.worker.global_worker.put_object(self.write_buffer,
|
||||
ray.ObjectID(batch_id))
|
||||
logger.debug("[writer] Flush batch {} offset {} size {}".format(
|
||||
self.write_batch_offset, self.write_item_offset,
|
||||
len(self.write_buffer)))
|
||||
self.write_buffer = []
|
||||
self.write_batch_offset += 1
|
||||
self._wait_for_reader()
|
||||
self.last_flush_time = time.time()
|
||||
|
||||
def _wait_for_reader(self):
|
||||
"""Checks for backpressure by the downstream reader."""
|
||||
if self.max_size <= 0: # Unlimited queue
|
||||
return
|
||||
if self.write_item_offset - self.cached_remote_offset <= self.max_size:
|
||||
return # Hasn't reached max size
|
||||
remote_offset = internal_kv._internal_kv_get(self.read_ack_key)
|
||||
if remote_offset is None:
|
||||
# logger.debug("[writer] Waiting for reader to start...")
|
||||
while remote_offset is None:
|
||||
time.sleep(0.01)
|
||||
remote_offset = internal_kv._internal_kv_get(self.read_ack_key)
|
||||
remote_offset = int(remote_offset)
|
||||
if self.write_item_offset - remote_offset > self.max_size:
|
||||
logger.debug(
|
||||
"[writer] Waiting for reader to catch up {} to {} - {}".format(
|
||||
remote_offset, self.write_item_offset, self.max_size))
|
||||
while self.write_item_offset - remote_offset > self.max_size:
|
||||
time.sleep(0.01)
|
||||
remote_offset = int(
|
||||
internal_kv._internal_kv_get(self.read_ack_key))
|
||||
self.cached_remote_offset = remote_offset
|
||||
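The backpressure rule in _wait_for_reader() above boils down to one comparison: the writer stalls once it is more than max_size items ahead of the reader's last acknowledged offset. A small self-contained illustration (the helper name is made up for this sketch):

def writer_must_wait(write_item_offset, reader_acked_offset, max_size):
    # Mirrors the check above: max_size <= 0 means an unlimited queue.
    if max_size <= 0:
        return False
    return write_item_offset - reader_acked_offset > max_size

assert writer_must_wait(250, 100, 100) is True    # 150 unread items, throttle
assert writer_must_wait(150, 100, 100) is False   # 50 unread items, keep writing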
|
||||
def _read_next_batch(self):
|
||||
while (self.prefetch_batch_offset <
|
||||
self.read_batch_offset + self.prefetch_depth):
|
||||
plasma_prefetch(self._batch_id(self.prefetch_batch_offset))
|
||||
self.prefetch_batch_offset += 1
|
||||
self.read_buffer = ray.get(
|
||||
ray.ObjectID(self._batch_id(self.read_batch_offset)))
|
||||
self.read_batch_offset += 1
|
||||
logger.debug("[reader] Fetched batch {} offset {} size {}".format(
|
||||
self.read_batch_offset, self.read_item_offset,
|
||||
len(self.read_buffer)))
|
||||
self._ack_reads(self.read_item_offset + len(self.read_buffer))
|
||||
|
||||
# Reader acks the key it reads so that writer knows reader's offset.
|
||||
# This is to cap queue size and simulate backpressure
|
||||
def _ack_reads(self, offset):
|
||||
if self.max_size > 0:
|
||||
internal_kv._internal_kv_put(
|
||||
self.read_ack_key, offset, overwrite=True)
|
||||
|
||||
def put_next(self, item):
|
||||
with self.flush_lock:
|
||||
if self.background_flush and not self.flush_thread.is_alive():
|
||||
logger.debug("[writer] Starting batch flush thread")
|
||||
self.flush_thread.start()
|
||||
self.write_buffer.append(item)
|
||||
self.write_item_offset += 1
|
||||
if not self.last_flush_time:
|
||||
self.last_flush_time = time.time()
|
||||
delay = time.time() - self.last_flush_time
|
||||
if (len(self.write_buffer) > self.max_batch_size
|
||||
or delay > self.max_batch_time):
|
||||
self._flush_writes()
|
||||
|
||||
def read_next(self):
|
||||
if not self.read_buffer:
|
||||
self._read_next_batch()
|
||||
assert self.read_buffer
|
||||
self.read_item_offset += 1
|
||||
return self.read_buffer.pop(0)
|
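The class above (part of the ray.experimental.streaming package that this commit removes) is meant to be used as follows: the writer calls put_next() and relies on size- or time-triggered flushes, while the reader calls read_next() and acknowledges its offset so that backpressure can kick in. A hedged single-process sketch of that API, valid only for Ray versions before this commit:

import ray
from ray.experimental.streaming.batched_queue import BatchedQueue

ray.init()
queue = BatchedQueue(
    max_size=100,         # backpressure threshold (see _wait_for_reader)
    max_batch_size=10,    # flush every 10 records...
    max_batch_time=0.01,  # ...or every 10 ms, whichever comes first
    prefetch_depth=10,
    background_flush=False)

for i in range(25):
    queue.put_next(i)     # buffered locally, flushed to plasma in batches
queue._flush_writes()     # push out the final partial batch

# In a real dataflow the reader is another actor that received the queue by
# value; read_next() fetches and unpacks batches from plasma transparently.
assert queue.read_next() == 0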
|
@ -1,182 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.experimental.streaming.batched_queue import BatchedQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--rounds", default=10, help="the number of experiment rounds")
|
||||
parser.add_argument(
|
||||
"--num-queues", default=1, help="the number of queues in the chain")
|
||||
parser.add_argument(
|
||||
"--queue-size", default=10000, help="the queue size in number of batches")
|
||||
parser.add_argument(
|
||||
"--batch-size", default=1000, help="the batch size in number of elements")
|
||||
parser.add_argument(
|
||||
"--flush-timeout", default=0.001, help="the timeout to flush a batch")
|
||||
parser.add_argument(
|
||||
"--prefetch-depth",
|
||||
default=10,
|
||||
help="the number of batches to prefetch from plasma")
|
||||
parser.add_argument(
|
||||
"--background-flush",
|
||||
default=False,
|
||||
help="whether to flush in the backrgound or not")
|
||||
parser.add_argument(
|
||||
"--max-throughput",
|
||||
default="inf",
|
||||
help="maximum read throughput (elements/s)")
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Node(object):
|
||||
"""An actor that reads from an input queue and writes to an output queue.
|
||||
|
||||
Attributes:
|
||||
id (int): The id of the actor.
|
||||
queue (BatchedQueue): The input queue.
|
||||
out_queue (BatchedQueue): The output queue.
|
||||
max_reads_per_second (int): The max read throughput (default: inf).
|
||||
num_reads (int): Number of elements read.
|
||||
num_writes (int): Number of elements written.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
id,
|
||||
in_queue,
|
||||
out_queue,
|
||||
max_reads_per_second=float("inf")):
|
||||
self.id = id
|
||||
self.queue = in_queue
|
||||
self.out_queue = out_queue
|
||||
self.max_reads_per_second = max_reads_per_second
|
||||
self.num_reads = 0
|
||||
self.num_writes = 0
|
||||
self.start = time.time()
|
||||
|
||||
def read_write_forever(self):
|
||||
debug_log = "[actor {}] Reads throttled to {} reads/s"
|
||||
log = ""
|
||||
if self.out_queue is not None:
|
||||
self.out_queue.enable_writes()
|
||||
log += "[actor {}] Reads/Writes per second {}"
|
||||
else: # It's just a reader
|
||||
log += "[actor {}] Reads per second {}"
|
||||
# Start spinning
|
||||
expected_value = 0
|
||||
while True:
|
||||
start = time.time()
|
||||
N = 100000
|
||||
for _ in range(N):
|
||||
x = self.queue.read_next()
|
||||
assert x == expected_value, (x, expected_value)
|
||||
expected_value += 1
|
||||
self.num_reads += 1
|
||||
if self.out_queue is not None:
|
||||
self.out_queue.put_next(x)
|
||||
self.num_writes += 1
|
||||
while (self.num_reads / (time.time() - self.start) >
|
||||
self.max_reads_per_second):
|
||||
logger.debug(
|
||||
debug_log.format(self.id, self.max_reads_per_second))
|
||||
time.sleep(0.1)
|
||||
logger.info(log.format(self.id, N / (time.time() - start)))
|
||||
# Flush any remaining elements
|
||||
if self.out_queue is not None:
|
||||
self.out_queue._flush_writes()
|
||||
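The read_write_forever() loop above throttles itself by spinning whenever the observed read rate exceeds max_reads_per_second. The same idea in isolation (hypothetical helper, with a small guard against a zero elapsed time):

import time

def throttle(num_reads, start_time, max_reads_per_second):
    # Sleep until the average read rate drops back under the cap.
    elapsed = max(time.time() - start_time, 1e-9)
    while num_reads / elapsed > max_reads_per_second:
        time.sleep(0.1)
        elapsed = max(time.time() - start_time, 1e-9)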
|
||||
|
||||
def test_max_throughput(rounds,
|
||||
max_queue_size,
|
||||
max_batch_size,
|
||||
batch_timeout,
|
||||
prefetch_depth,
|
||||
background_flush,
|
||||
num_queues,
|
||||
max_reads_per_second=float("inf")):
|
||||
assert num_queues >= 1
|
||||
first_queue = BatchedQueue(
|
||||
max_size=max_queue_size,
|
||||
max_batch_size=max_batch_size,
|
||||
max_batch_time=batch_timeout,
|
||||
prefetch_depth=prefetch_depth,
|
||||
background_flush=background_flush)
|
||||
previous_queue = first_queue
|
||||
for i in range(num_queues):
|
||||
# Construct the batched queue
|
||||
in_queue = previous_queue
|
||||
out_queue = None
|
||||
if i < num_queues - 1:
|
||||
out_queue = BatchedQueue(
|
||||
max_size=max_queue_size,
|
||||
max_batch_size=max_batch_size,
|
||||
max_batch_time=batch_timeout,
|
||||
prefetch_depth=prefetch_depth,
|
||||
background_flush=background_flush)
|
||||
|
||||
node = Node.remote(i, in_queue, out_queue, max_reads_per_second)
|
||||
node.read_write_forever.remote()
|
||||
previous_queue = out_queue
|
||||
|
||||
value = 0
|
||||
# Feed the chain
|
||||
for round in range(rounds):
|
||||
logger.info("Round {}".format(round))
|
||||
N = 100000
|
||||
start = time.time()
|
||||
for i in range(N):
|
||||
first_queue.put_next(value)
|
||||
value += 1
|
||||
log = "[writer] Puts per second {}"
|
||||
logger.info(log.format(N / (time.time() - start)))
|
||||
first_queue._flush_writes()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
rounds = int(args.rounds)
|
||||
max_queue_size = int(args.queue_size)
|
||||
max_batch_size = int(args.batch_size)
|
||||
batch_timeout = float(args.flush_timeout)
|
||||
prefetch_depth = int(args.prefetch_depth)
|
||||
background_flush = bool(args.background_flush)
|
||||
num_queues = int(args.num_queues)
|
||||
max_reads_per_second = float(args.max_throughput)
|
||||
|
||||
logger.info("== Parameters ==")
|
||||
logger.info("Rounds: {}".format(rounds))
|
||||
logger.info("Max queue size: {}".format(max_queue_size))
|
||||
logger.info("Max batch size: {}".format(max_batch_size))
|
||||
logger.info("Batch timeout: {}".format(batch_timeout))
|
||||
logger.info("Prefetch depth: {}".format(prefetch_depth))
|
||||
logger.info("Background flush: {}".format(background_flush))
|
||||
logger.info("Max read throughput: {}".format(max_reads_per_second))
|
||||
|
||||
# Estimate the ideal throughput
|
||||
value = 0
|
||||
start = time.time()
|
||||
for round in range(rounds):
|
||||
N = 100000
|
||||
for _ in range(N):
|
||||
value += 1
|
||||
logger.info("Ideal throughput: {}".format(value / (time.time() - start)))
|
||||
|
||||
logger.info("== Testing max throughput ==")
|
||||
start = time.time()
|
||||
test_max_throughput(rounds, max_queue_size, max_batch_size, batch_timeout,
|
||||
prefetch_depth, background_flush, num_queues,
|
||||
max_reads_per_second)
|
||||
logger.info("Elapsed time: {}".format(time.time() - start))
|
|
@ -1,359 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from ray.experimental.streaming.operator import PStrategy
|
||||
from ray.experimental.streaming.batched_queue import BatchedQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Forward and broadcast stream partitioning strategies
|
||||
forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast]
|
||||
|
||||
|
||||
# Used to choose output channel in case of hash-based shuffling
|
||||
def _hash(value):
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
try:
|
||||
return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16)
|
||||
except AttributeError:
|
||||
return int(hashlib.sha1(value).hexdigest(), 16)
|
||||
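As the comment above says, _hash() is what hash-based shuffling uses to pick an output channel: integers map to themselves, strings and bytes go through SHA-1. Two small checks that follow directly from the code, plus the routing rule applied later in DataOutput._push:

assert _hash(7) == 7                     # ints are used as-is
assert _hash("key") == _hash(b"key")     # str falls back to its utf-8 bytes
# A record is routed to channels[_hash(key) % len(channels)] by the output gate.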
|
||||
|
||||
# A data channel is a batched queue between two
|
||||
# operator instances in a streaming environment
|
||||
class DataChannel(object):
|
||||
"""A data channel for actor-to-actor communication.
|
||||
|
||||
Attributes:
|
||||
env (Environment): The environment the channel belongs to.
|
||||
src_operator_id (UUID): The id of the source operator of the channel.
|
||||
dst_operator_id (UUID): The id of the destination operator of the
|
||||
channel.
|
||||
src_instance_id (int): The id of the source instance.
|
||||
dst_instance_id (int): The id of the destination instance.
|
||||
queue (BatchedQueue): The batched queue used for data movement.
|
||||
"""
|
||||
|
||||
def __init__(self, env, src_operator_id, dst_operator_id, src_instance_id,
|
||||
dst_instance_id):
|
||||
self.env = env
|
||||
self.src_operator_id = src_operator_id
|
||||
self.dst_operator_id = dst_operator_id
|
||||
self.src_instance_id = src_instance_id
|
||||
self.dst_instance_id = dst_instance_id
|
||||
self.queue = BatchedQueue(
|
||||
max_size=self.env.config.queue_config.max_size,
|
||||
max_batch_size=self.env.config.queue_config.max_batch_size,
|
||||
max_batch_time=self.env.config.queue_config.max_batch_time,
|
||||
prefetch_depth=self.env.config.queue_config.prefetch_depth,
|
||||
background_flush=self.env.config.queue_config.background_flush)
|
||||
|
||||
def __repr__(self):
|
||||
return "({},{},{},{})".format(
|
||||
self.src_operator_id, self.dst_operator_id, self.src_instance_id,
|
||||
self.dst_instance_id)
|
||||
|
||||
|
||||
# Pulls and merges data from multiple input channels
|
||||
class DataInput(object):
|
||||
"""An input gate of an operator instance.
|
||||
|
||||
The input gate pulls records from all input channels in a round-robin
|
||||
fashion.
|
||||
|
||||
Attributes:
|
||||
input_channels (list): The list of input channels.
|
||||
channel_index (int): The index of the next channel to pull from.
|
||||
max_index (int): The number of input channels.
|
||||
closed (list): A list of flags indicating whether an input channel
|
||||
has been marked as 'closed'.
|
||||
all_closed (bool): Denotes whether all input channels have been
|
||||
closed (True) or not (False).
|
||||
"""
|
||||
|
||||
def __init__(self, channels):
|
||||
self.input_channels = channels
|
||||
self.channel_index = 0
|
||||
self.max_index = len(channels)
|
||||
self.closed = [False] * len(
|
||||
self.input_channels) # Tracks the channels that have been closed
|
||||
self.all_closed = False
|
||||
|
||||
# Fetches records from input channels in a round-robin fashion
|
||||
# TODO (john): Make sure the instance is not blocked on any of its input
|
||||
# channels
|
||||
# TODO (john): In case of input skew, it might be better to pull from
|
||||
# the largest queue more often
|
||||
def _pull(self):
|
||||
while True:
|
||||
if self.max_index == 0:
|
||||
# TODO (john): We should detect this earlier
|
||||
return None
|
||||
# Channel to pull from
|
||||
channel = self.input_channels[self.channel_index]
|
||||
self.channel_index += 1
|
||||
if self.channel_index == self.max_index: # Reset channel index
|
||||
self.channel_index = 0
|
||||
if self.closed[self.channel_index - 1]:
|
||||
continue # Channel has been 'closed', check next
|
||||
record = channel.queue.read_next()
|
||||
logger.debug("Actor ({},{}) pulled '{}'.".format(
|
||||
channel.src_operator_id, channel.src_instance_id, record))
|
||||
if record is None:
|
||||
# Mark channel as 'closed' and pull from the next open one
|
||||
self.closed[self.channel_index - 1] = True
|
||||
self.all_closed = True
|
||||
for flag in self.closed:
|
||||
if flag is False:
|
||||
self.all_closed = False
|
||||
break
|
||||
if not self.all_closed:
|
||||
continue
|
||||
# Returns 'None' iff all input channels are 'closed'
|
||||
return record
|
||||
|
||||
|
||||
# Selects output channel(s) and pushes data
|
||||
class DataOutput(object):
|
||||
"""An output gate of an operator instance.
|
||||
|
||||
The output gate pushes records to output channels according to the
|
||||
user-defined partitioning scheme.
|
||||
|
||||
Attributes:
|
||||
partitioning_schemes (dict): A mapping from destination operator ids
|
||||
to partitioning schemes (see: PScheme in operator.py).
|
||||
forward_channels (list): A list of channels to forward records.
|
||||
shuffle_channels (list(list)): A list of output channels to shuffle
|
||||
records grouped by destination operator.
|
||||
shuffle_key_channels (list(list)): A list of output channels to
|
||||
shuffle records by a key grouped by destination operator.
|
||||
shuffle_exists (bool): A flag indicating that there exists at least
|
||||
one shuffle_channel.
|
||||
shuffle_key_exists (bool): A flag indicating that there exists at
|
||||
least one shuffle_key_channel.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, partitioning_schemes):
|
||||
self.key_selector = None
|
||||
self.round_robin_indexes = [0]
|
||||
self.partitioning_schemes = partitioning_schemes
|
||||
# Prepare output -- collect channels by type
|
||||
self.forward_channels = [] # Forward and broadcast channels
|
||||
slots = sum(1 for scheme in self.partitioning_schemes.values()
|
||||
if scheme.strategy == PStrategy.RoundRobin)
|
||||
self.round_robin_channels = [[]] * slots # RoundRobin channels
|
||||
self.round_robin_indexes = [-1] * slots
|
||||
slots = sum(1 for scheme in self.partitioning_schemes.values()
|
||||
if scheme.strategy == PStrategy.Shuffle)
|
||||
# Flag used to avoid hashing when there is no shuffling
|
||||
self.shuffle_exists = slots > 0
|
||||
self.shuffle_channels = [[]] * slots # Shuffle channels
|
||||
slots = sum(1 for scheme in self.partitioning_schemes.values()
|
||||
if scheme.strategy == PStrategy.ShuffleByKey)
|
||||
# Flag used to avoid hashing when there is no shuffling by key
|
||||
self.shuffle_key_exists = slots > 0
|
||||
self.shuffle_key_channels = [[]] * slots # Shuffle by key channels
|
||||
# Distinct shuffle destinations
|
||||
shuffle_destinations = {}
|
||||
# Distinct shuffle by key destinations
|
||||
shuffle_by_key_destinations = {}
|
||||
# Distinct round robin destinations
|
||||
round_robin_destinations = {}
|
||||
index_1 = 0
|
||||
index_2 = 0
|
||||
index_3 = 0
|
||||
for channel in channels:
|
||||
p_scheme = self.partitioning_schemes[channel.dst_operator_id]
|
||||
strategy = p_scheme.strategy
|
||||
if strategy in forward_broadcast_strategies:
|
||||
self.forward_channels.append(channel)
|
||||
elif strategy == PStrategy.Shuffle:
|
||||
pos = shuffle_destinations.setdefault(channel.dst_operator_id,
|
||||
index_1)
|
||||
self.shuffle_channels[pos].append(channel)
|
||||
if pos == index_1:
|
||||
index_1 += 1
|
||||
elif strategy == PStrategy.ShuffleByKey:
|
||||
pos = shuffle_by_key_destinations.setdefault(
|
||||
channel.dst_operator_id, index_2)
|
||||
self.shuffle_key_channels[pos].append(channel)
|
||||
if pos == index_2:
|
||||
index_2 += 1
|
||||
elif strategy == PStrategy.RoundRobin:
|
||||
pos = round_robin_destinations.setdefault(
|
||||
channel.dst_operator_id, index_3)
|
||||
self.round_robin_channels[pos].append(channel)
|
||||
if pos == index_3:
|
||||
index_3 += 1
|
||||
else: # TODO (john): Add support for other strategies
|
||||
sys.exit("Unrecognized or unsupported partitioning strategy.")
|
||||
# A KeyedDataStream can only be shuffled by key
|
||||
assert not (self.shuffle_exists and self.shuffle_key_exists)
|
||||
|
||||
# Flushes any remaining records in the output channels
|
||||
# 'close' indicates whether we should also 'close' the channel (True)
|
||||
# by propagating 'None'
|
||||
# or just flush the remaining records to plasma (False)
|
||||
def _flush(self, close=False):
|
||||
"""Flushes remaining output records in the output queues to plasma.
|
||||
|
||||
None is used as a special type of record that is propagated from sources
to sinks to signal the end of data in a stream.
|
||||
|
||||
Attributes:
|
||||
close (bool): A flag denoting whether the channel should be
|
||||
also marked as 'closed' (True) or not (False) after flushing.
|
||||
"""
|
||||
for channel in self.forward_channels:
|
||||
if close is True:
|
||||
channel.queue.put_next(None)
|
||||
channel.queue._flush_writes()
|
||||
for channels in self.shuffle_channels:
|
||||
for channel in channels:
|
||||
if close is True:
|
||||
channel.queue.put_next(None)
|
||||
channel.queue._flush_writes()
|
||||
for channels in self.shuffle_key_channels:
|
||||
for channel in channels:
|
||||
if close is True:
|
||||
channel.queue.put_next(None)
|
||||
channel.queue._flush_writes()
|
||||
for channels in self.round_robin_channels:
|
||||
for channel in channels:
|
||||
if close is True:
|
||||
channel.queue.put_next(None)
|
||||
channel.queue._flush_writes()
|
||||
# TODO (john): Add more channel types
|
||||
|
||||
# Returns all destination actor ids
|
||||
def _destination_actor_ids(self):
|
||||
destinations = []
|
||||
for channel in self.forward_channels:
|
||||
destinations.append((channel.dst_operator_id,
|
||||
channel.dst_instance_id))
|
||||
for channels in self.shuffle_channels:
|
||||
for channel in channels:
|
||||
destinations.append((channel.dst_operator_id,
|
||||
channel.dst_instance_id))
|
||||
for channels in self.shuffle_key_channels:
|
||||
for channel in channels:
|
||||
destinations.append((channel.dst_operator_id,
|
||||
channel.dst_instance_id))
|
||||
for channels in self.round_robin_channels:
|
||||
for channel in channels:
|
||||
destinations.append((channel.dst_operator_id,
|
||||
channel.dst_instance_id))
|
||||
# TODO (john): Add more channel types
|
||||
return destinations
|
||||
|
||||
# Pushes the record to the output
|
||||
# Each individual output queue flushes batches to plasma periodically
|
||||
# based on 'batch_max_size' and 'batch_max_time'
|
||||
def _push(self, record):
|
||||
# Forward record
|
||||
for channel in self.forward_channels:
|
||||
logger.debug("[writer] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
# Forward record
|
||||
index = 0
|
||||
for channels in self.round_robin_channels:
|
||||
self.round_robin_indexes[index] += 1
|
||||
if self.round_robin_indexes[index] == len(channels):
|
||||
self.round_robin_indexes[index] = 0 # Reset index
|
||||
channel = channels[self.round_robin_indexes[index]]
|
||||
logger.debug("[writer] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
index += 1
|
||||
# Hash-based shuffling by key
|
||||
if self.shuffle_key_exists:
|
||||
key, _ = record
|
||||
h = _hash(key)
|
||||
for channels in self.shuffle_key_channels:
|
||||
num_instances = len(channels) # Downstream instances
|
||||
channel = channels[h % num_instances]
|
||||
logger.debug(
|
||||
"[key_shuffle] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
elif self.shuffle_exists: # Hash-based shuffling per destination
|
||||
h = _hash(record)
|
||||
for channels in self.shuffle_channels:
|
||||
num_instances = len(channels) # Downstream instances
|
||||
channel = channels[h % num_instances]
|
||||
logger.debug("[shuffle] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
else: # TODO (john): Handle rescaling
|
||||
pass
|
||||
|
||||
# Pushes a list of records to the output
|
||||
# Each individual output queue flushes batches to plasma periodically
|
||||
# based on 'batch_max_size' and 'batch_max_time'
|
||||
def _push_all(self, records):
|
||||
# Forward records
|
||||
for record in records:
|
||||
for channel in self.forward_channels:
|
||||
logger.debug("[writer] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
# Hash-based shuffling by key per destination
|
||||
if self.shuffle_key_exists:
|
||||
for record in records:
|
||||
key, _ = record
|
||||
h = _hash(key)
|
||||
for channels in self.shuffle_key_channels:
|
||||
num_instances = len(channels) # Downstream instances
|
||||
channel = channels[h % num_instances]
|
||||
logger.debug(
|
||||
"[key_shuffle] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
elif self.shuffle_exists: # Hash-based shuffling per destination
|
||||
for record in records:
|
||||
h = _hash(record)
|
||||
for channels in self.shuffle_channels:
|
||||
num_instances = len(channels) # Downstream instances
|
||||
channel = channels[h % num_instances]
|
||||
logger.debug(
|
||||
"[shuffle] Push record '{}' to channel {}".format(
|
||||
record, channel))
|
||||
channel.queue.put_next(record)
|
||||
else: # TODO (john): Handle rescaling
|
||||
pass
|
||||
|
||||
|
||||
# Batched queue configuration
|
||||
class QueueConfig(object):
|
||||
"""The configuration of a batched queue.
|
||||
|
||||
Attributes:
|
||||
max_size (int): The maximum size of the queue in number of batches
|
||||
(if exceeded, backpressure kicks in).
|
||||
max_batch_size (int): The size of each batch in number of records.
|
||||
max_batch_time (float): The flush timeout per batch.
|
||||
prefetch_depth (int): The number of batches to prefetch from plasma.
|
||||
background_flush (bool): Denotes whether a daemon flush thread should
|
||||
be used (True) to flush batches to plasma.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_size=999999,
|
||||
max_batch_size=99999,
|
||||
max_batch_time=0.01,
|
||||
prefetch_depth=10,
|
||||
background_flush=False):
|
||||
self.max_size = max_size
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_batch_time = max_batch_time
|
||||
self.prefetch_depth = prefetch_depth
|
||||
self.background_flush = background_flush
|
|
@ -1,365 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
#
|
||||
# Each Ray actor corresponds to an operator instance in the physical dataflow
|
||||
# Actors communicate using batched queues as data channels (no standing TCP
|
||||
# connections)
|
||||
# Currently, batched queues are based on Eric's implementation (see:
|
||||
# batched_queue.py)
|
||||
|
||||
|
||||
def _identity(element):
|
||||
return element
|
||||
|
||||
|
||||
# TODO (john): Specify the interface of state keepers
|
||||
class OperatorInstance(object):
|
||||
"""A streaming operator instance.
|
||||
|
||||
Attributes:
|
||||
instance_id (UUID): The id of the instance.
|
||||
input (DataInput): The input gate that manages input channels of
|
||||
the instance (see: DataInput in communication.py).
|
||||
output (DataOutput): The output gate that manages output channels of
|
||||
the instance (see: DataOutput in communication.py).
|
||||
state_keepers (list): A list of actor handlers to query the state of
|
||||
the operator instance.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, input_gate, output_gate,
|
||||
state_keeper=None):
|
||||
self.key_index = None # Index for key selection
|
||||
self.key_attribute = None # Attribute name for key selection
|
||||
self.instance_id = instance_id
|
||||
self.input = input_gate
|
||||
self.output = output_gate
|
||||
# Handle(s) to one or more user-defined actors
|
||||
# that can retrieve actor's state
|
||||
self.state_keeper = state_keeper
|
||||
# Enable writes
|
||||
for channel in self.output.forward_channels:
|
||||
channel.queue.enable_writes()
|
||||
for channels in self.output.shuffle_channels:
|
||||
for channel in channels:
|
||||
channel.queue.enable_writes()
|
||||
for channels in self.output.shuffle_key_channels:
|
||||
for channel in channels:
|
||||
channel.queue.enable_writes()
|
||||
for channels in self.output.round_robin_channels:
|
||||
for channel in channels:
|
||||
channel.queue.enable_writes()
|
||||
# TODO (john): Add more channel types here
|
||||
|
||||
# Registers actor's handle so that the actor can schedule itself
|
||||
def register_handle(self, actor_handle):
|
||||
self.this_actor = actor_handle
|
||||
|
||||
# Used for index-based key extraction, e.g. for tuples
|
||||
def index_based_selector(self, record):
|
||||
return record[self.key_index]
|
||||
|
||||
# Used for attribute-based key extraction, e.g. for classes
|
||||
def attribute_based_selector(self, record):
|
||||
return vars(record)[self.key_attribute]
|
||||
|
||||
# Starts the actor
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
|
||||
# A source actor that reads a text file line by line
|
||||
@ray.remote
|
||||
class ReadTextFile(OperatorInstance):
|
||||
"""A source operator instance that reads a text file line by line.
|
||||
|
||||
Attributes:
|
||||
filepath (string): The path to the input file.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
instance_id,
|
||||
operator_metadata,
|
||||
input_gate,
|
||||
output_gate,
|
||||
state_keepers=None):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate,
|
||||
state_keepers)
|
||||
self.filepath = operator_metadata.other_args
|
||||
# TODO (john): Handle possible exception here
|
||||
self.reader = open(self.filepath, "r")
|
||||
|
||||
# Read input file line by line
|
||||
def start(self):
|
||||
while True:
|
||||
record = self.reader.readline()
|
||||
# Reader returns empty string ('') on EOF
|
||||
if not record:
|
||||
# Flush any remaining records to plasma and close the file
|
||||
self.output._flush(close=True)
|
||||
self.reader.close()
|
||||
return
|
||||
self.output._push(
|
||||
record[:-1]) # Push after removing newline characters
|
||||
|
||||
|
||||
# Map actor
|
||||
@ray.remote
|
||||
class Map(OperatorInstance):
|
||||
"""A map operator instance that applies a user-defined
|
||||
stream transformation.
|
||||
|
||||
A map produces exactly one output record for each record in
|
||||
the input stream.
|
||||
|
||||
Attributes:
|
||||
map_fn (function): The user-defined function.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
|
||||
self.map_fn = operator_metadata.logic
|
||||
|
||||
# Applies the mapper to each record of the input stream(s)
|
||||
# and pushes resulting records to the output stream(s)
|
||||
def start(self):
|
||||
start = time.time()
|
||||
elements = 0
|
||||
while True:
|
||||
record = self.input._pull()
|
||||
if record is None:
|
||||
self.output._flush(close=True)
|
||||
logger.debug("[map {}] read/writes per second: {}".format(
|
||||
self.instance_id, elements / (time.time() - start)))
|
||||
return
|
||||
self.output._push(self.map_fn(record))
|
||||
elements += 1
|
||||
|
||||
|
||||
# Flatmap actor
|
||||
@ray.remote
|
||||
class FlatMap(OperatorInstance):
|
||||
"""A map operator instance that applies a user-defined
|
||||
stream transformation.
|
||||
|
||||
A flatmap produces one or more output records for each record in
|
||||
the input stream.
|
||||
|
||||
Attributes:
|
||||
flatmap_fn (function): The user-defined function.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
|
||||
self.flatmap_fn = operator_metadata.logic
|
||||
|
||||
# Applies the splitter to the records of the input stream(s)
|
||||
# and pushes resulting records to the output stream(s)
|
||||
def start(self):
|
||||
while True:
|
||||
record = self.input._pull()
|
||||
if record is None:
|
||||
self.output._flush(close=True)
|
||||
return
|
||||
self.output._push_all(self.flatmap_fn(record))
|
||||
|
||||
|
||||
# Filter actor
|
||||
@ray.remote
|
||||
class Filter(OperatorInstance):
|
||||
"""A filter operator instance that applies a user-defined filter to
|
||||
each record of the stream.
|
||||
|
||||
Output records are those that pass the filter, i.e. those for which
|
||||
the filter function returns True.
|
||||
|
||||
Attributes:
|
||||
filter_fn (function): The user-defined boolean function.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
|
||||
self.filter_fn = operator_metadata.logic
|
||||
|
||||
# Applies the filter to the records of the input stream(s)
|
||||
# and pushes resulting records to the output stream(s)
|
||||
def start(self):
|
||||
while True:
|
||||
record = self.input._pull()
|
||||
if record is None: # Close channel and return
|
||||
self.output._flush(close=True)
|
||||
return
|
||||
if self.filter_fn(record):
|
||||
self.output._push(record)
|
||||
|
||||
|
||||
# Inspect actor
|
||||
@ray.remote
|
||||
class Inspect(OperatorInstance):
|
||||
"""A inspect operator instance that inspects the content of the stream.
|
||||
|
||||
Inspect is useful for printing the records in the stream.
|
||||
|
||||
Attributes:
|
||||
inspect_fn (function): The user-defined inspect logic.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
|
||||
self.inspect_fn = operator_metadata.logic
|
||||
|
||||
# Applies the inspect logic (e.g. print) to the records of
|
||||
# the input stream(s)
|
||||
# and leaves stream unaffected by simply pushing the records to
|
||||
# the output stream(s)
def start(self):
while True:
|
||||
record = self.input._pull()
|
||||
if record is None:
|
||||
self.output._flush(close=True)
|
||||
return
|
||||
self.output._push(record)
|
||||
self.inspect_fn(record)
|
||||
|
||||
|
||||
# Reduce actor
|
||||
@ray.remote
|
||||
class Reduce(OperatorInstance):
|
||||
"""A reduce operator instance that combines a new value for a key
|
||||
with the last reduced one according to a user-defined logic.
|
||||
|
||||
Attributes:
|
||||
reduce_fn (function): The user-defined reduce logic.
|
||||
value_attribute (int): The index of the value to reduce
|
||||
(assuming tuple records).
|
||||
state (dict): A mapping from keys to values.
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate,
|
||||
operator_metadata.state_actor)
|
||||
self.reduce_fn = operator_metadata.logic
|
||||
# Set the attribute selector
|
||||
self.attribute_selector = operator_metadata.other_args
|
||||
if self.attribute_selector is None:
|
||||
self.attribute_selector = _identity
|
||||
elif isinstance(self.attribute_selector, int):
|
||||
self.key_index = self.attribute_selector
|
||||
self.attribute_selector = self.index_based_selector
|
||||
elif isinstance(self.attribute_selector, str):
|
||||
self.key_attribute = self.attribute_selector
|
||||
self.attribute_selector = self.attribute_based_selector
|
||||
elif not isinstance(self.attribute_selector, types.FunctionType):
|
||||
sys.exit("Unrecognized or unsupported key selector.")
|
||||
self.state = {} # key -> value
|
||||
|
||||
# Combines the input value for a key with the last reduced
|
||||
# value for that key to produce a new value.
|
||||
# Outputs the result as (key,new value)
|
||||
def start(self):
|
||||
while True:
|
||||
record = self.input._pull()
|
||||
if record is None:
|
||||
self.output._flush(close=True)
|
||||
del self.state
|
||||
return
|
||||
key, rest = record
|
||||
new_value = self.attribute_selector(rest)
|
||||
# TODO (john): Is there a way to update state with
|
||||
# a single dictionary lookup?
|
||||
try:
|
||||
old_value = self.state[key]
|
||||
new_value = self.reduce_fn(old_value, new_value)
|
||||
self.state[key] = new_value
|
||||
except KeyError: # Key does not exist in state
|
||||
self.state.setdefault(key, new_value)
|
||||
self.output._push((key, new_value))
|
||||
|
||||
# Returns the state of the actor
|
||||
def get_state(self):
|
||||
return self.state
|
||||
|
||||
|
||||
@ray.remote
|
||||
class KeyBy(OperatorInstance):
|
||||
"""A key_by operator instance that physically partitions the
|
||||
stream based on a key.
|
||||
|
||||
Attributes:
|
||||
key_attribute (int): The index of the value to reduce
|
||||
(assuming tuple records).
|
||||
"""
|
||||
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
|
||||
# Set the key selector
|
||||
self.key_selector = operator_metadata.other_args
|
||||
if isinstance(self.key_selector, int):
|
||||
self.key_index = self.key_selector
|
||||
self.key_selector = self.index_based_selector
|
||||
elif isinstance(self.key_selector, str):
|
||||
self.key_attribute = self.key_selector
|
||||
self.key_selector = self.attribute_based_selector
|
||||
elif not isinstance(self.key_selector, types.FunctionType):
|
||||
sys.exit("Unrecognized or unsupported key selector.")
|
||||
|
||||
# The actual partitioning is done by the output gate
|
||||
def start(self):
|
||||
while True:
|
||||
record = self.input._pull()
|
||||
if record is None:
|
||||
self.output._flush(close=True)
|
||||
return
|
||||
key = self.key_selector(record)
|
||||
self.output._push((key, record))
|
||||
|
||||
|
||||
# A custom source actor
|
||||
@ray.remote
|
||||
class Source(OperatorInstance):
|
||||
def __init__(self, instance_id, operator_metadata, input_gate,
|
||||
output_gate):
|
||||
OperatorInstance.__init__(self, instance_id, input_gate, output_gate)
|
||||
# The user-defined source with a get_next() method
|
||||
self.source = operator_metadata.other_args
|
||||
|
||||
# Starts the source by calling get_next() repeatedly
|
||||
def start(self):
|
||||
start = time.time()
|
||||
elements = 0
|
||||
while True:
|
||||
next = self.source.get_next()
|
||||
if next is None:
|
||||
self.output._flush(close=True)
|
||||
logger.debug("[writer {}] puts per second: {}".format(
|
||||
self.instance_id, elements / (time.time() - start)))
|
||||
return
|
||||
self.output._push(next)
|
||||
elements += 1
|
||||
|
||||
|
||||
# TODO(john): Time window actor (uses system time)
|
||||
@ray.remote
|
||||
class TimeWindow(OperatorInstance):
|
||||
def __init__(self, queue, width):
|
||||
self.width = width # In milliseconds
|
||||
|
||||
def time_window(self):
|
||||
while True:
|
||||
pass
|
|
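For reference, a minimal sketch (not part of this commit) of a user-defined source that satisfies the contract `Source` expects: any object with a `get_next()` method that returns `None` once the stream is exhausted. The class name and constructor below are hypothetical.

```python
class ListSource:
    """Hypothetical in-memory source; Source.start() calls get_next()."""

    def __init__(self, items):
        self.items = list(items)
        self.pos = 0

    def get_next(self):
        if self.pos >= len(self.items):
            return None  # None signals end of stream; Source flushes and exits
        item = self.items[self.pos]
        self.pos += 1
        return item
```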
@ -15,11 +15,6 @@ cdef class Buffer:

    See https://docs.python.org/3/c-api/buffer.html for details.
    """
    cdef:
        shared_ptr[CBuffer] buffer
        Py_ssize_t shape
        Py_ssize_t strides

    @staticmethod
    cdef make(const shared_ptr[CBuffer]& buffer):
        cdef Buffer self = Buffer.__new__(Buffer)

@ -19,7 +19,7 @@ from ray.includes.unique_ids cimport (
    CObjectID,
    CTaskID,
    CUniqueID,
    CWorkerID,
    CWorkerID
)

import ray

@ -40,8 +40,6 @@ cdef extern from "ray/common/constants.h" nogil:

cdef class BaseID:

    # To avoid the error of "Python int too large to convert to C ssize_t",
    # here `cdef size_t` is required.
    cdef size_t hash(self):
        pass

@ -129,13 +127,6 @@ cdef class UniqueID(BaseID):


cdef class ObjectID(BaseID):
    cdef:
        CObjectID data
        object buffer_ref
        # Flag indicating whether or not this object ID was added to the set
        # of active IDs in the core worker so we know whether we should clean
        # it up.
        c_bool in_core_worker

    def __init__(self, id):
        check_id(id)

@ -332,8 +323,6 @@ cdef class WorkerID(UniqueID):
        return <CWorkerID>self.data

cdef class ActorID(BaseID):
    cdef CActorID data

    def __init__(self, id):
        check_id(id, CActorID.Size())
        self.data = CActorID.FromBinary(<c_string>id)

1
python/ray/streaming
Symbolic link

@ -0,0 +1 @@
../../streaming/python/

@ -233,14 +233,6 @@ py_test(
    deps = ["//:ray_lib"],
)

py_test(
    name = "test_logical_graph",
    size = "small",
    srcs = ["test_logical_graph.py"],
    tags = ["exclusive"],
    deps = ["//:ray_lib"],
)

py_test(
    name = "test_memory_limits",
    size = "medium",

@ -177,6 +177,7 @@ requires = [
    "six >= 1.0.0",
    "faulthandler;python_version<'3.3'",
    "protobuf >= 3.8.0",
    "cloudpickle",
]

setup(

@ -603,11 +603,13 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
      TaskID::ForNormalTask(worker_context_.GetCurrentJobID(),
                            worker_context_.GetCurrentTaskID(), next_task_index);

  const std::unordered_map<std::string, double> required_resources;
  // TODO(ekl) offload task building onto a thread pool for performance
  BuildCommonTaskSpec(
      builder, worker_context_.GetCurrentJobID(), task_id,
      worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_,
      function, args, task_options.num_returns, task_options.resources, {},
      function, args, task_options.num_returns, task_options.resources,
      required_resources,
      task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET,
      return_ids);
  TaskSpecification task_spec = builder.Build();

@ -681,10 +683,11 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
  const TaskID actor_task_id = TaskID::ForActorTask(
      worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(),
      next_task_index, actor_handle->GetActorID());
  const std::unordered_map<std::string, double> required_resources;
  BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id,
                      worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
                      rpc_address_, function, args, num_returns, task_options.resources,
                      {}, transport_type, return_ids);
                      required_resources, transport_type, return_ids);

  const ObjectID new_cursor = return_ids->back();
  actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor);

27
src/ray/ray_exported_symbols.lds
Normal file

@ -0,0 +1,27 @@
# This file defines the C++ symbols that need to be exported (aka ABI, application binary interface).
# These symbols will be used by other libraries (e.g., streaming).
# Note: This file is used for macOS only, and should be kept in sync with `ray_version_script.lds`.
# Ray ABI is not finalized, the exact set of exported (C/C++) APIs is subject to change.
# common
*ray*Language*
*ray*RayObject*
*ray*Status*
*ray*RayFunction*
*ray*TaskArg*
*ray*TaskOptions*
*ray*Buffer*
*ray*LocalMemoryBuffer*
# util
*ray*RayLog*
*ray*RayLogLevel*
# id
*ray*MurmurHash64A*
*ray*JobID*
*ray*TaskID*
*ray*ActorID*
*ray*ObjectID*
# Others
*ray*CoreWorker*
*PyInit*
*init_raylet*
*Java*
31
src/ray/ray_version_script.lds
Normal file

@ -0,0 +1,31 @@
# This file defines the C++ symbols that need to be exported (aka ABI, application binary interface).
# These symbols will be used by other libraries (e.g., streaming).
# Note: This file is used for linux only, and should be kept in sync with `ray_exported_symbols.lds`.
# Ray ABI is not finalized, the exact set of exported (C/C++) APIs is subject to change.
VERSION_1.0 {
  global:
    # common
    *ray*Language*;
    *ray*RayObject*;
    *ray*Status*;
    *ray*RayFunction*;
    *ray*TaskArg*;
    *ray*TaskOptions*;
    *ray*Buffer*;
    *ray*LocalMemoryBuffer*;
    # util
    *ray*RayLog*;
    *ray*RayLogLevel*;
    # id
    *ray*MurmurHash64A*;
    *ray*JobID*;
    *ray*TaskID*;
    *ray*ActorID*;
    *ray*ObjectID*;
    # Others
    *ray*CoreWorker*;
    *PyInit*;
    *init_raylet*;
    *Java*;
  local: *;
};
235
streaming/BUILD.bazel
Normal file
235
streaming/BUILD.bazel
Normal file
|
@ -0,0 +1,235 @@
|
|||
# Bazel build
|
||||
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library")
|
||||
load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile")
|
||||
|
||||
proto_library(
|
||||
name = "streaming_proto",
|
||||
srcs = ["src/protobuf/streaming.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "streaming_cc_proto",
|
||||
deps = [":streaming_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "streaming_queue_proto",
|
||||
srcs = ["src/protobuf/streaming_queue.proto"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "streaming_queue_cc_proto",
|
||||
deps = ["streaming_queue_proto"],
|
||||
)
|
||||
|
||||
# Use `linkshared` to ensure ray related symbols are not packed into streaming libs
|
||||
# to avoid duplicate symbols. In runtime we expose ray related symbols, which can
|
||||
# be linked into streaming libs by dynamic linker. See bazel rule `//:_raylet`
|
||||
cc_binary(
|
||||
name = "ray_util.so",
|
||||
linkshared = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//:ray_util"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "ray_common.so",
|
||||
linkshared = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//:ray_common"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "core_worker_lib.so",
|
||||
linkshared = 1,
|
||||
deps = ["//:core_worker_lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "streaming_util",
|
||||
srcs = glob([
|
||||
"src/util/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/util/*.h",
|
||||
]),
|
||||
includes = [
|
||||
"src",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"ray_util.so",
|
||||
"@boost//:any",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "streaming_config",
|
||||
srcs = glob([
|
||||
"src/config/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/config/*.h",
|
||||
]),
|
||||
deps = [
|
||||
"ray_common.so",
|
||||
":streaming_cc_proto",
|
||||
":streaming_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "streaming_message",
|
||||
srcs = glob([
|
||||
"src/message/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/message/*.h",
|
||||
]),
|
||||
deps = [
|
||||
"ray_common.so",
|
||||
":streaming_config",
|
||||
":streaming_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "streaming_queue",
|
||||
srcs = glob([
|
||||
"src/queue/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/queue/*.h",
|
||||
]),
|
||||
deps = [
|
||||
"core_worker_lib.so",
|
||||
"ray_common.so",
|
||||
"ray_util.so",
|
||||
":streaming_config",
|
||||
":streaming_message",
|
||||
":streaming_queue_cc_proto",
|
||||
":streaming_util",
|
||||
"@boost//:asio",
|
||||
"@boost//:thread",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "streaming_lib",
|
||||
srcs = glob([
|
||||
"src/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/*.h",
|
||||
"src/queue/*.h",
|
||||
"src/test/*.h",
|
||||
]),
|
||||
includes = ["src"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"ray_common.so",
|
||||
"ray_util.so",
|
||||
":streaming_config",
|
||||
":streaming_message",
|
||||
":streaming_queue",
|
||||
":streaming_util",
|
||||
"@boost//:circular_buffer",
|
||||
],
|
||||
)
|
||||
|
||||
test_common_deps = [
|
||||
":streaming_lib",
|
||||
"//:ray_common",
|
||||
"//:ray_util",
|
||||
"//:core_worker_lib",
|
||||
]
|
||||
|
||||
# streaming queue mock actor binary
|
||||
cc_binary(
|
||||
name = "streaming_test_worker",
|
||||
srcs = glob(["src/test/*.h"]) + [
|
||||
"src/test/mock_actor.cc",
|
||||
],
|
||||
includes = [
|
||||
"streaming/src/test",
|
||||
],
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
# use src/test/run_streaming_queue_test.sh to run this test
|
||||
cc_binary(
|
||||
name = "streaming_queue_tests",
|
||||
srcs = glob(["src/test/*.h"]) + [
|
||||
"src/test/streaming_queue_tests.cc",
|
||||
],
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "streaming_message_ring_buffer_tests",
|
||||
srcs = [
|
||||
"src/test/ring_buffer_tests.cc",
|
||||
],
|
||||
includes = [
|
||||
"streaming/src/test",
|
||||
],
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "streaming_message_serialization_tests",
|
||||
srcs = [
|
||||
"src/test/message_serialization_tests.cc",
|
||||
],
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "streaming_mock_transfer",
|
||||
srcs = [
|
||||
"src/test/mock_transfer_tests.cc",
|
||||
],
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "streaming_util_tests",
|
||||
srcs = [
|
||||
"src/test/streaming_util_tests.cc",
|
||||
],
|
||||
deps = test_common_deps,
|
||||
)
|
||||
|
||||
python_proto_compile(
|
||||
name = "streaming_py_proto",
|
||||
deps = ["//streaming:streaming_proto"],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "copy_streaming_py_proto",
|
||||
srcs = [
|
||||
":streaming_py_proto",
|
||||
],
|
||||
outs = [
|
||||
"copy_streaming_py_proto.out",
|
||||
],
|
||||
cmd = """
|
||||
set -e
|
||||
set -x
|
||||
WORK_DIR=$$(pwd)
|
||||
# Copy generated files.
|
||||
GENERATED_DIR=$$WORK_DIR/streaming/python/generated
|
||||
rm -rf $$GENERATED_DIR
|
||||
mkdir -p $$GENERATED_DIR
|
||||
for f in $(locations //streaming:streaming_py_proto); do
|
||||
cp $$f $$GENERATED_DIR
|
||||
done
|
||||
echo $$(date) > $@
|
||||
""",
|
||||
local = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
28
streaming/README.md
Normal file

@ -0,0 +1,28 @@
# Ray Streaming

1. Build streaming java
    * build ray
        * `sh build.sh -l java`
        * `cd java && mvn clean install -Dmaven.test.skip=true`
    * build streaming
        * `cd ray/streaming/java && bazel build all_modules`
        * `mvn clean install -Dmaven.test.skip=true`

2. Building Ray also builds the Ray streaming Python package.

3. Run examples
```bash
# c++ test
cd streaming/ && bazel test ...
sh src/test/run_streaming_queue_test.sh
cd ..

# python test
cd python/ray/streaming/
pushd examples
python simple.py --input-file toy.txt
popd
pushd tests
pytest .
popd
```
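For orientation, a minimal driver sketch distilled from the word-count/simple examples updated later in this commit; the input file name is a placeholder and the `Conf`/`Environment` signatures are only assumed from those examples, not guaranteed by this README.

```python
import ray
from ray.streaming.config import Config
from ray.streaming.streaming import Conf, Environment

ray.init()
# Use the native (C++) data channel added by this commit;
# Config.MEMORY_CHANNEL selects the pure-Python fallback.
env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
stream = env.read_text_file("words.txt") \
             .shuffle() \
             .flat_map(lambda line: line.split()) \
             .set_parallelism(2) \
             .inspect(lambda record: print("result", record))
env.execute()      # Deploy and start the dataflow
env.wait_finish()  # Block until all sources are exhausted
```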
3
streaming/python/__init__.py
Normal file

@ -0,0 +1,3 @@
# flake8: noqa
# Ray should be imported before streaming
import ray

6
streaming/python/_streaming.pyx
Normal file

@ -0,0 +1,6 @@
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

include "includes/transfer.pxi"

283
streaming/python/communication.py
Normal file
283
streaming/python/communication.py
Normal file
|
@ -0,0 +1,283 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
|
||||
import ray
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.operator import PStrategy
|
||||
from ray.streaming.runtime.transfer import ChannelID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Forward and broadcast stream partitioning strategies
|
||||
forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast]
|
||||
|
||||
|
||||
# Used to choose output channel in case of hash-based shuffling
|
||||
def _hash(value):
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
try:
|
||||
return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16)
|
||||
except AttributeError:
|
||||
return int(hashlib.sha1(value).hexdigest(), 16)
|
||||
|
||||
|
||||
class DataChannel(object):
|
||||
"""A data channel for actor-to-actor communication.
|
||||
|
||||
Attributes:
|
||||
env (Environment): The environment the channel belongs to.
|
||||
src_operator_id (UUID): The id of the source operator of the channel.
|
||||
src_instance_index (int): The id of the source instance.
|
||||
dst_operator_id (UUID): The id of the destination operator of the
|
||||
channel.
|
||||
dst_instance_index (int): The id of the destination instance.
|
||||
"""
|
||||
|
||||
def __init__(self, src_operator_id, src_instance_index, dst_operator_id,
|
||||
dst_instance_index, str_qid):
|
||||
self.src_operator_id = src_operator_id
|
||||
self.src_instance_index = src_instance_index
|
||||
self.dst_operator_id = dst_operator_id
|
||||
self.dst_instance_index = dst_instance_index
|
||||
self.str_qid = str_qid
|
||||
self.qid = ChannelID(str_qid)
|
||||
|
||||
def __repr__(self):
|
||||
return "(src({},{}),dst({},{}), qid({}))".format(
|
||||
self.src_operator_id, self.src_instance_index,
|
||||
self.dst_operator_id, self.dst_instance_index, self.str_qid)
|
||||
|
||||
|
||||
_CLOSE_FLAG = b" "
|
||||
|
||||
|
||||
# Pulls and merges data from multiple input channels
|
||||
class DataInput(object):
|
||||
"""An input gate of an operator instance.
|
||||
|
||||
The input gate pulls records from all input channels in a round-robin
|
||||
fashion.
|
||||
|
||||
Attributes:
|
||||
input_channels (list): The list of input channels.
|
||||
channel_index (int): The index of the next channel to pull from.
|
||||
max_index (int): The number of input channels.
|
||||
closed (list): A list of flags indicating whether an input channel
|
||||
has been marked as 'closed'.
|
||||
all_closed (bool): Denotes whether all input channels have been
|
||||
closed (True) or not (False).
|
||||
"""
|
||||
|
||||
def __init__(self, env, channels):
|
||||
assert len(channels) > 0
|
||||
self.env = env
|
||||
self.reader = None # created in `init` method
|
||||
self.input_channels = channels
|
||||
self.channel_index = 0
|
||||
self.max_index = len(channels)
|
||||
# Tracks the channels that have been closed. qid: close status
|
||||
self.closed = {}
|
||||
|
||||
def init(self):
|
||||
channels = [c.str_qid for c in self.input_channels]
|
||||
input_actors = []
|
||||
for c in self.input_channels:
|
||||
actor = self.env.execution_graph.get_actor(c.src_operator_id,
|
||||
c.src_instance_index)
|
||||
input_actors.append(actor)
|
||||
logger.info("DataInput input_actors %s", input_actors)
|
||||
conf = {
|
||||
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
|
||||
.current_driver_id,
|
||||
Config.CHANNEL_TYPE: self.env.config.channel_type
|
||||
}
|
||||
self.reader = transfer.DataReader(channels, input_actors, conf)
|
||||
|
||||
def pull(self):
|
||||
# pull from channel
|
||||
item = self.reader.read(100)
|
||||
while item is None:
|
||||
time.sleep(0.001)
|
||||
item = self.reader.read(100)
|
||||
msg_data = item.body()
|
||||
if msg_data == _CLOSE_FLAG:
|
||||
self.closed[item.channel_id] = True
|
||||
if len(self.closed) == len(self.input_channels):
|
||||
return None
|
||||
else:
|
||||
return self.pull()
|
||||
else:
|
||||
return pickle.loads(msg_data)
|
||||
|
||||
def close(self):
|
||||
self.reader.stop()
|
||||
|
||||
|
||||
# Selects output channel(s) and pushes data
|
||||
class DataOutput(object):
|
||||
"""An output gate of an operator instance.
|
||||
|
||||
The output gate pushes records to output channels according to the
|
||||
user-defined partitioning scheme.
|
||||
|
||||
Attributes:
|
||||
partitioning_schemes (dict): A mapping from destination operator ids
|
||||
to partitioning schemes (see: PScheme in operator.py).
|
||||
forward_channels (list): A list of channels to forward records.
|
||||
shuffle_channels (list(list)): A list of output channels to shuffle
|
||||
records grouped by destination operator.
|
||||
shuffle_key_channels (list(list)): A list of output channels to
|
||||
shuffle records by a key grouped by destination operator.
|
||||
shuffle_exists (bool): A flag indicating that there exists at least
|
||||
one shuffle_channel.
|
||||
shuffle_key_exists (bool): A flag indicating that there exists at
|
||||
least one shuffle_key_channel.
|
||||
"""
|
||||
|
||||
def __init__(self, env, channels, partitioning_schemes):
|
||||
assert len(channels) > 0
|
||||
self.env = env
|
||||
self.writer = None # created in `init` method
|
||||
self.channels = channels
|
||||
self.key_selector = None
|
||||
self.round_robin_indexes = [0]
|
||||
self.partitioning_schemes = partitioning_schemes
|
||||
# Prepare output -- collect channels by type
|
||||
self.forward_channels = [] # Forward and broadcast channels
|
||||
slots = sum(1 for scheme in self.partitioning_schemes.values()
|
||||
if scheme.strategy == PStrategy.RoundRobin)
|
||||
self.round_robin_channels = [[]] * slots # RoundRobin channels
|
||||
self.round_robin_indexes = [-1] * slots
|
||||
slots = sum(1 for scheme in self.partitioning_schemes.values()
|
||||
if scheme.strategy == PStrategy.Shuffle)
|
||||
# Flag used to avoid hashing when there is no shuffling
|
||||
self.shuffle_exists = slots > 0
|
||||
self.shuffle_channels = [[]] * slots # Shuffle channels
|
||||
slots = sum(1 for scheme in self.partitioning_schemes.values()
|
||||
if scheme.strategy == PStrategy.ShuffleByKey)
|
||||
# Flag used to avoid hashing when there is no shuffling by key
|
||||
self.shuffle_key_exists = slots > 0
|
||||
self.shuffle_key_channels = [[]] * slots # Shuffle by key channels
|
||||
# Distinct shuffle destinations
|
||||
shuffle_destinations = {}
|
||||
# Distinct shuffle by key destinations
|
||||
shuffle_by_key_destinations = {}
|
||||
# Distinct round robin destinations
|
||||
round_robin_destinations = {}
|
||||
index_1 = 0
|
||||
index_2 = 0
|
||||
index_3 = 0
|
||||
for channel in channels:
|
||||
p_scheme = self.partitioning_schemes[channel.dst_operator_id]
|
||||
strategy = p_scheme.strategy
|
||||
if strategy in forward_broadcast_strategies:
|
||||
self.forward_channels.append(channel)
|
||||
elif strategy == PStrategy.Shuffle:
|
||||
pos = shuffle_destinations.setdefault(channel.dst_operator_id,
|
||||
index_1)
|
||||
self.shuffle_channels[pos].append(channel)
|
||||
if pos == index_1:
|
||||
index_1 += 1
|
||||
elif strategy == PStrategy.ShuffleByKey:
|
||||
pos = shuffle_by_key_destinations.setdefault(
|
||||
channel.dst_operator_id, index_2)
|
||||
self.shuffle_key_channels[pos].append(channel)
|
||||
if pos == index_2:
|
||||
index_2 += 1
|
||||
elif strategy == PStrategy.RoundRobin:
|
||||
pos = round_robin_destinations.setdefault(
|
||||
channel.dst_operator_id, index_3)
|
||||
self.round_robin_channels[pos].append(channel)
|
||||
if pos == index_3:
|
||||
index_3 += 1
|
||||
else: # TODO (john): Add support for other strategies
|
||||
sys.exit("Unrecognized or unsupported partitioning strategy.")
|
||||
# A KeyedDataStream can only be shuffled by key
|
||||
assert not (self.shuffle_exists and self.shuffle_key_exists)
|
||||
|
||||
def init(self):
|
||||
"""init DataOutput which creates DataWriter"""
|
||||
channel_ids = [c.str_qid for c in self.channels]
|
||||
to_actors = []
|
||||
for c in self.channels:
|
||||
actor = self.env.execution_graph.get_actor(c.dst_operator_id,
|
||||
c.dst_instance_index)
|
||||
to_actors.append(actor)
|
||||
logger.info("DataOutput output_actors %s", to_actors)
|
||||
|
||||
conf = {
|
||||
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
|
||||
.current_driver_id,
|
||||
Config.CHANNEL_TYPE: self.env.config.channel_type
|
||||
}
|
||||
self.writer = transfer.DataWriter(channel_ids, to_actors, conf)
|
||||
|
||||
def close(self):
|
||||
"""Close the channel (True) by propagating _CLOSE_FLAG
|
||||
|
||||
_CLOSE_FLAG is used as special type of record that is propagated from
|
||||
sources to sink to notify that the end of data in a stream.
|
||||
"""
|
||||
for c in self.channels:
|
||||
self.writer.write(c.qid, _CLOSE_FLAG)
|
||||
# must ensure DataWriter send None flag to peer actor
|
||||
self.writer.stop()
|
||||
|
||||
def push(self, record):
|
||||
target_channels = []
|
||||
# Forward record
|
||||
for c in self.forward_channels:
|
||||
logger.debug("[writer] Push record '{}' to channel {}".format(
|
||||
record, c))
|
||||
target_channels.append(c)
|
||||
# Forward record
|
||||
index = 0
|
||||
for channels in self.round_robin_channels:
|
||||
self.round_robin_indexes[index] += 1
|
||||
if self.round_robin_indexes[index] == len(channels):
|
||||
self.round_robin_indexes[index] = 0 # Reset index
|
||||
c = channels[self.round_robin_indexes[index]]
|
||||
logger.debug("[writer] Push record '{}' to channel {}".format(
|
||||
record, c))
|
||||
target_channels.append(c)
|
||||
index += 1
|
||||
# Hash-based shuffling by key
|
||||
if self.shuffle_key_exists:
|
||||
key, _ = record
|
||||
h = _hash(key)
|
||||
for channels in self.shuffle_key_channels:
|
||||
num_instances = len(channels) # Downstream instances
|
||||
c = channels[h % num_instances]
|
||||
logger.debug(
|
||||
"[key_shuffle] Push record '{}' to channel {}".format(
|
||||
record, c))
|
||||
target_channels.append(c)
|
||||
elif self.shuffle_exists: # Hash-based shuffling per destination
|
||||
h = _hash(record)
|
||||
for channels in self.shuffle_channels:
|
||||
num_instances = len(channels) # Downstream instances
|
||||
c = channels[h % num_instances]
|
||||
logger.debug("[shuffle] Push record '{}' to channel {}".format(
|
||||
record, c))
|
||||
target_channels.append(c)
|
||||
else: # TODO (john): Handle rescaling
|
||||
pass
|
||||
|
||||
msg_data = pickle.dumps(record)
|
||||
for c in target_channels:
|
||||
# send data to channel
|
||||
self.writer.write(c.qid, msg_data)
|
||||
|
||||
def push_all(self, records):
|
||||
for record in records:
|
||||
self.push(record)
|
23
streaming/python/config.py
Normal file

@ -0,0 +1,23 @@
class Config:
    STREAMING_JOB_NAME = "streaming.job.name"
    STREAMING_OP_NAME = "streaming.op_name"
    TASK_JOB_ID = "streaming.task_job_id"
    STREAMING_WORKER_NAME = "streaming.worker_name"
    # channel
    CHANNEL_TYPE = "channel_type"
    MEMORY_CHANNEL = "memory_channel"
    NATIVE_CHANNEL = "native_channel"
    CHANNEL_SIZE = "channel_size"
    CHANNEL_SIZE_DEFAULT = 10**8
    IS_RECREATE = "streaming.is_recreate"
    # return from StreamingReader.getBundle if only empty message read in this
    # interval.
    TIMER_INTERVAL_MS = "timer_interval_ms"

    STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity"
    # write an empty message if there is no data to be written in this
    # interval.
    STREAMING_EMPTY_MESSAGE_INTERVAL = "streaming.empty_message_interval"

    # operator type
    OPERATOR_TYPE = "operator_type"

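As a small illustration (not part of this commit's files), this is roughly how these keys are used by `DataInput.init()` and `DataOutput.init()` in communication.py above to configure the transfer layer; the dict contents mirror that code.

```python
import ray
from ray.streaming.config import Config

# Channel configuration handed to transfer.DataReader / transfer.DataWriter.
conf = {
    Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
        .current_driver_id,
    Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL,  # or Config.MEMORY_CHANNEL
}
```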
@ -7,9 +7,7 @@ import logging
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.experimental.streaming.streaming import Environment
|
||||
from ray.experimental.streaming.batched_queue import BatchedQueue
|
||||
from ray.experimental.streaming.operator import OpType, PStrategy
|
||||
from ray.streaming.streaming import Environment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -48,9 +46,6 @@ if __name__ == "__main__":
|
|||
|
||||
ray.init()
|
||||
ray.register_custom_serializer(Record, use_dict=True)
|
||||
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
|
||||
ray.register_custom_serializer(OpType, use_pickle=True)
|
||||
ray.register_custom_serializer(PStrategy, use_pickle=True)
|
||||
|
||||
# A Ray streaming environment with the default configuration
|
||||
env = Environment()
|
|
@ -7,9 +7,8 @@ import logging
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.experimental.streaming.streaming import Environment
|
||||
from ray.experimental.streaming.batched_queue import BatchedQueue
|
||||
from ray.experimental.streaming.operator import OpType, PStrategy
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.streaming import Environment, Conf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -33,27 +32,25 @@ if __name__ == "__main__":
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init()
|
||||
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
|
||||
ray.register_custom_serializer(OpType, use_pickle=True)
|
||||
ray.register_custom_serializer(PStrategy, use_pickle=True)
|
||||
ray.init(local_mode=False)
|
||||
|
||||
# A Ray streaming environment with the default configuration
|
||||
env = Environment()
|
||||
env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
|
||||
|
||||
# Stream represents the output of the filter and
|
||||
# can be forked into other dataflows
|
||||
stream = env.read_text_file(args.input_file) \
|
||||
.shuffle() \
|
||||
.flat_map(splitter) \
|
||||
.set_parallelism(4) \
|
||||
.set_parallelism(2) \
|
||||
.filter(filter_fn) \
|
||||
.set_parallelism(2) \
|
||||
.inspect(print) # Prints the contents of the
|
||||
.inspect(lambda x: print("result", x)) # Prints the contents of the
|
||||
# stream to stdout
|
||||
start = time.time()
|
||||
env_handle = env.execute()
|
||||
ray.get(env_handle) # Stay alive until execution finishes
|
||||
env.wait_finish()
|
||||
end = time.time()
|
||||
logger.info("Elapsed time: {} secs".format(end - start))
|
||||
logger.debug("Output stream id: {}".format(stream.id))
|
|
@ -5,12 +5,10 @@ from __future__ import print_function
|
|||
import argparse
|
||||
import logging
|
||||
import time
|
||||
import wikipedia
|
||||
|
||||
import ray
|
||||
from ray.experimental.streaming.streaming import Environment
|
||||
from ray.experimental.streaming.batched_queue import BatchedQueue
|
||||
from ray.experimental.streaming.operator import OpType, PStrategy
|
||||
import wikipedia
|
||||
from ray.streaming.streaming import Environment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -86,9 +84,6 @@ if __name__ == "__main__":
|
|||
titles_file = str(args.titles_file)
|
||||
|
||||
ray.init()
|
||||
ray.register_custom_serializer(BatchedQueue, use_pickle=True)
|
||||
ray.register_custom_serializer(OpType, use_pickle=True)
|
||||
ray.register_custom_serializer(PStrategy, use_pickle=True)
|
||||
|
||||
# A Ray streaming environment with the default configuration
|
||||
env = Environment()
|
||||
|
@ -108,6 +103,7 @@ if __name__ == "__main__":
|
|||
start = time.time()
|
||||
env_handle = env.execute() # Deploys and executes the dataflow
|
||||
ray.get(env_handle) # Stay alive until execution finishes
|
||||
env.wait_finish()
|
||||
end = time.time()
|
||||
logger.info("Elapsed time: {} secs".format(end - start))
|
||||
logger.debug("Output stream id: {}".format(stream.id))
|
0
streaming/python/includes/__init__.pxd
Normal file
0
streaming/python/includes/__init__.pxd
Normal file
153
streaming/python/includes/libstreaming.pxd
Normal file
153
streaming/python/includes/libstreaming.pxd
Normal file
|
@ -0,0 +1,153 @@
|
|||
# cython: profile=False
|
||||
# distutils: language = c++
|
||||
# cython: embedsignature = True
|
||||
# cython: language_level = 3
|
||||
# flake8: noqa
|
||||
|
||||
from libc.stdint cimport *
|
||||
from libcpp cimport bool as c_bool
|
||||
from libcpp.memory cimport shared_ptr
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
from libcpp.list cimport list as c_list
|
||||
from cpython cimport PyObject
|
||||
cimport cpython
|
||||
|
||||
cdef inline object PyObject_to_object(PyObject* o):
|
||||
# Cast to "object" increments reference count
|
||||
cdef object result = <object> o
|
||||
cpython.Py_DECREF(result)
|
||||
return result
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CLanguage,
|
||||
CRayObject,
|
||||
CRayStatus,
|
||||
CRayFunction
|
||||
)
|
||||
|
||||
from ray.includes.unique_ids cimport (
|
||||
CActorID,
|
||||
CJobID,
|
||||
CTaskID,
|
||||
CObjectID,
|
||||
)
|
||||
from ray.includes.libcoreworker cimport CCoreWorker
|
||||
|
||||
cdef extern from "status.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus":
|
||||
pass
|
||||
cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK"
|
||||
cdef CStreamingStatus StatusReconstructTimeOut "ray::streaming::StreamingStatus::ReconstructTimeOut"
|
||||
cdef CStreamingStatus StatusQueueIdNotFound "ray::streaming::StreamingStatus::QueueIdNotFound"
|
||||
cdef CStreamingStatus StatusResubscribeFailed "ray::streaming::StreamingStatus::ResubscribeFailed"
|
||||
cdef CStreamingStatus StatusEmptyRingBuffer "ray::streaming::StreamingStatus::EmptyRingBuffer"
|
||||
cdef CStreamingStatus StatusFullChannel "ray::streaming::StreamingStatus::FullChannel"
|
||||
cdef CStreamingStatus StatusNoSuchItem "ray::streaming::StreamingStatus::NoSuchItem"
|
||||
cdef CStreamingStatus StatusInitQueueFailed "ray::streaming::StreamingStatus::InitQueueFailed"
|
||||
cdef CStreamingStatus StatusGetBundleTimeOut "ray::streaming::StreamingStatus::GetBundleTimeOut"
|
||||
cdef CStreamingStatus StatusSkipSendEmptyMessage "ray::streaming::StreamingStatus::SkipSendEmptyMessage"
|
||||
cdef CStreamingStatus StatusInterrupted "ray::streaming::StreamingStatus::Interrupted"
|
||||
cdef CStreamingStatus StatusWaitQueueTimeOut "ray::streaming::StreamingStatus::WaitQueueTimeOut"
|
||||
cdef CStreamingStatus StatusOutOfMemory "ray::streaming::StreamingStatus::OutOfMemory"
|
||||
cdef CStreamingStatus StatusInvalid "ray::streaming::StreamingStatus::Invalid"
|
||||
cdef CStreamingStatus StatusUnknownError "ray::streaming::StreamingStatus::UnknownError"
|
||||
cdef CStreamingStatus StatusTailStatus "ray::streaming::StreamingStatus::TailStatus"
|
||||
|
||||
cdef cppclass CStreamingCommon "ray::streaming::StreamingCommon":
|
||||
void SetConfig(const uint8_t *, uint32_t size)
|
||||
|
||||
|
||||
cdef extern from "runtime_context.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CRuntimeContext "ray::streaming::RuntimeContext":
|
||||
CRuntimeContext()
|
||||
void SetConfig(const uint8_t *data, uint32_t size)
|
||||
inline void MarkMockTest()
|
||||
inline c_bool IsMockTest()
|
||||
|
||||
cdef extern from "message/message.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CStreamingMessageType "ray::streaming::StreamingMessageType":
|
||||
pass
|
||||
cdef CStreamingMessageType MessageTypeBarrier "ray::streaming::StreamingMessageType::Barrier"
|
||||
cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message"
|
||||
cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage":
|
||||
inline uint8_t *RawData() const
|
||||
inline uint32_t GetDataSize() const
|
||||
inline CStreamingMessageType GetMessageType() const
|
||||
inline uint64_t GetMessageSeqId() const
|
||||
|
||||
cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CStreamingMessageBundleType "ray::streaming::StreamingMessageBundleType":
|
||||
pass
|
||||
cdef CStreamingMessageBundleType BundleTypeEmpty "ray::streaming::StreamingMessageBundleType::Empty"
|
||||
cdef CStreamingMessageBundleType BundleTypeBarrier "ray::streaming::StreamingMessageBundleType::Barrier"
|
||||
cdef CStreamingMessageBundleType BundleTypeBundle "ray::streaming::StreamingMessageBundleType::Bundle"
|
||||
|
||||
cdef cppclass CStreamingMessageBundleMeta "ray::streaming::StreamingMessageBundleMeta":
|
||||
CStreamingMessageBundleMeta()
|
||||
inline uint64_t GetMessageBundleTs() const
|
||||
inline uint64_t GetLastMessageId() const
|
||||
inline uint32_t GetMessageListSize() const
|
||||
inline CStreamingMessageBundleType GetBundleType() const
|
||||
inline c_bool IsBarrier()
|
||||
inline c_bool IsBundle()
|
||||
|
||||
ctypedef shared_ptr[CStreamingMessageBundleMeta] CStreamingMessageBundleMetaPtr
|
||||
uint32_t kMessageBundleHeaderSize "ray::streaming::kMessageBundleHeaderSize"
|
||||
cdef cppclass CStreamingMessageBundle "ray::streaming::StreamingMessageBundle"(CStreamingMessageBundleMeta):
|
||||
@staticmethod
|
||||
void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums,
|
||||
c_list[shared_ptr[CStreamingMessage]] &msg_list);
|
||||
|
||||
cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CReaderClient "ray::streaming::ReaderClient":
|
||||
CReaderClient(CCoreWorker *core_worker,
|
||||
CRayFunction &async_func,
|
||||
CRayFunction &sync_func)
|
||||
void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
|
||||
cdef cppclass CWriterClient "ray::streaming::WriterClient":
|
||||
CWriterClient(CCoreWorker *core_worker,
|
||||
CRayFunction &async_func,
|
||||
CRayFunction &sync_func)
|
||||
void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer);
|
||||
|
||||
|
||||
cdef extern from "data_reader.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CDataBundle "ray::streaming::DataBundle":
|
||||
uint8_t *data
|
||||
uint32_t data_size
|
||||
CObjectID c_from "from"
|
||||
uint64_t seq_id
|
||||
CStreamingMessageBundleMetaPtr meta
|
||||
|
||||
cdef cppclass CDataReader "ray::streaming::DataReader"(CStreamingCommon):
|
||||
CDataReader(shared_ptr[CRuntimeContext] &runtime_context)
|
||||
void Init(const c_vector[CObjectID] &input_ids,
|
||||
const c_vector[CActorID] &actor_ids,
|
||||
const c_vector[uint64_t] &seq_ids,
|
||||
const c_vector[uint64_t] &msg_ids,
|
||||
int64_t timer_interval);
|
||||
CStreamingStatus GetBundle(const uint32_t timeout_ms,
|
||||
shared_ptr[CDataBundle] &message)
|
||||
void Stop()
|
||||
|
||||
|
||||
cdef extern from "data_writer.h" namespace "ray::streaming" nogil:
|
||||
cdef cppclass CDataWriter "ray::streaming::DataWriter"(CStreamingCommon):
|
||||
CDataWriter(shared_ptr[CRuntimeContext] &runtime_context)
|
||||
CStreamingStatus Init(const c_vector[CObjectID] &channel_ids,
|
||||
const c_vector[CActorID] &actor_ids,
|
||||
const c_vector[uint64_t] &message_ids,
|
||||
const c_vector[uint64_t] &queue_size_vec);
|
||||
long WriteMessageToBufferRing(
|
||||
const CObjectID &q_id, uint8_t *data, uint32_t data_size)
|
||||
void Run()
|
||||
void Stop()
|
||||
|
||||
|
||||
cdef extern from "ray/common/buffer.h" nogil:
|
||||
cdef cppclass CLocalMemoryBuffer "ray::LocalMemoryBuffer":
|
||||
uint8_t *Data() const
|
||||
size_t Size() const
|
323
streaming/python/includes/transfer.pxi
Normal file
323
streaming/python/includes/transfer.pxi
Normal file
|
@ -0,0 +1,323 @@
|
|||
# flake8: noqa
|
||||
|
||||
from libc.stdint cimport *
|
||||
from libcpp cimport bool as c_bool
|
||||
from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
|
||||
from libcpp.string cimport string as c_string
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
from libcpp.list cimport list as c_list
|
||||
|
||||
from ray.includes.common cimport (
|
||||
CRayFunction,
|
||||
LANGUAGE_PYTHON,
|
||||
CBuffer
|
||||
)
|
||||
|
||||
from ray.includes.unique_ids cimport (
|
||||
CActorID,
|
||||
CObjectID
|
||||
)
|
||||
from ray._raylet cimport (
|
||||
Buffer,
|
||||
CoreWorker,
|
||||
ActorID,
|
||||
ObjectID,
|
||||
string_vector_from_list
|
||||
)
|
||||
|
||||
from ray.includes.libcoreworker cimport CCoreWorker
|
||||
|
||||
cimport ray.streaming.includes.libstreaming as libstreaming
|
||||
from ray.streaming.includes.libstreaming cimport (
|
||||
CStreamingStatus,
|
||||
CStreamingMessage,
|
||||
CStreamingMessageBundle,
|
||||
CRuntimeContext,
|
||||
CDataBundle,
|
||||
CDataWriter,
|
||||
CDataReader,
|
||||
CReaderClient,
|
||||
CWriterClient,
|
||||
CLocalMemoryBuffer,
|
||||
)
|
||||
|
||||
import logging
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
|
||||
|
||||
channel_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
cdef class ReaderClient:
|
||||
cdef:
|
||||
CReaderClient *client
|
||||
|
||||
def __cinit__(self,
|
||||
CoreWorker worker,
|
||||
async_func: FunctionDescriptor,
|
||||
sync_func: FunctionDescriptor):
|
||||
cdef:
|
||||
CCoreWorker *core_worker = worker.core_worker.get()
|
||||
CRayFunction async_native_func
|
||||
CRayFunction sync_native_func
|
||||
async_native_func = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list()))
|
||||
sync_native_func = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list()))
|
||||
self.client = new CReaderClient(core_worker, async_native_func, sync_native_func)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.client
|
||||
self.client = NULL
|
||||
|
||||
def on_reader_message(self, const unsigned char[:] value):
|
||||
cdef:
|
||||
size_t size = value.nbytes
|
||||
shared_ptr[CLocalMemoryBuffer] local_buf = \
|
||||
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
|
||||
with nogil:
|
||||
self.client.OnReaderMessage(local_buf)
|
||||
|
||||
def on_reader_message_sync(self, const unsigned char[:] value):
|
||||
cdef:
|
||||
size_t size = value.nbytes
|
||||
shared_ptr[CLocalMemoryBuffer] local_buf = \
|
||||
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
|
||||
shared_ptr[CLocalMemoryBuffer] result_buffer
|
||||
with nogil:
|
||||
result_buffer = self.client.OnReaderMessageSync(local_buf)
|
||||
return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))
|
||||
|
||||
|
||||
cdef class WriterClient:
|
||||
cdef:
|
||||
CWriterClient * client
|
||||
|
||||
def __cinit__(self,
|
||||
CoreWorker worker,
|
||||
async_func: FunctionDescriptor,
|
||||
sync_func: FunctionDescriptor):
|
||||
cdef:
|
||||
CCoreWorker *core_worker = worker.core_worker.get()
|
||||
CRayFunction async_native_func
|
||||
CRayFunction sync_native_func
|
||||
async_native_func = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list()))
|
||||
sync_native_func = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list()))
|
||||
self.client = new CWriterClient(core_worker, async_native_func, sync_native_func)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.client
|
||||
self.client = NULL
|
||||
|
||||
def on_writer_message(self, const unsigned char[:] value):
|
||||
cdef:
|
||||
size_t size = value.nbytes
|
||||
shared_ptr[CLocalMemoryBuffer] local_buf = \
|
||||
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
|
||||
with nogil:
|
||||
self.client.OnWriterMessage(local_buf)
|
||||
|
||||
def on_writer_message_sync(self, const unsigned char[:] value):
|
||||
cdef:
|
||||
size_t size = value.nbytes
|
||||
shared_ptr[CLocalMemoryBuffer] local_buf = \
|
||||
make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
|
||||
shared_ptr[CLocalMemoryBuffer] result_buffer
|
||||
with nogil:
|
||||
result_buffer = self.client.OnWriterMessageSync(local_buf)
|
||||
return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))
|
||||
|
||||
|
||||
cdef class DataWriter:
|
||||
cdef:
|
||||
CDataWriter *writer
|
||||
|
||||
def __init__(self):
|
||||
raise Exception("use create() to create DataWriter")
|
||||
|
||||
@staticmethod
|
||||
def create(list py_output_channels,
|
||||
list output_actor_ids: list[ActorID],
|
||||
uint64_t queue_size,
|
||||
list py_msg_ids,
|
||||
bytes config_bytes,
|
||||
c_bool is_mock):
|
||||
cdef:
|
||||
c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
|
||||
c_vector[CActorID] actor_ids
|
||||
c_vector[uint64_t] msg_ids
|
||||
CDataWriter *c_writer
|
||||
cdef const unsigned char[:] config_data
|
||||
for actor_id in output_actor_ids:
|
||||
actor_ids.push_back((<ActorID>actor_id).data)
|
||||
for py_msg_id in py_msg_ids:
|
||||
msg_ids.push_back(<uint64_t>py_msg_id)
|
||||
|
||||
cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
|
||||
if is_mock:
|
||||
ctx.get().MarkMockTest()
|
||||
if config_bytes:
|
||||
config_data = config_bytes
|
||||
channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
|
||||
ctx.get().SetConfig(<uint8_t *>(&config_data[0]), config_data.nbytes)
|
||||
c_writer = new CDataWriter(ctx)
|
||||
cdef:
|
||||
c_vector[CObjectID] remain_id_vec
|
||||
c_vector[uint64_t] queue_size_vec
|
||||
for i in range(channel_ids.size()):
|
||||
queue_size_vec.push_back(queue_size)
|
||||
cdef CStreamingStatus status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec)
|
||||
if remain_id_vec.size() != 0:
|
||||
channel_logger.warning("failed queue amounts => %s", remain_id_vec.size())
|
||||
if <uint32_t>status != <uint32_t> libstreaming.StatusOK:
|
||||
msg = "initialize writer failed, status={}".format(<uint32_t>status)
|
||||
channel_logger.error(msg)
|
||||
del c_writer
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
raise transfer.ChannelInitException(msg, qid_vector_to_list(remain_id_vec))
|
||||
|
||||
c_writer.Run()
|
||||
channel_logger.info("create native writer succeed")
|
||||
cdef DataWriter writer = DataWriter.__new__(DataWriter)
|
||||
writer.writer = c_writer
|
||||
return writer
|
||||
|
||||
def __dealloc__(self):
|
||||
if self.writer != NULL:
|
||||
del self.writer
|
||||
channel_logger.info("deleted DataWriter")
|
||||
self.writer = NULL
|
||||
|
||||
def write(self, ObjectID qid, const unsigned char[:] value):
|
||||
"""support zero-copy bytes, bytearray, array of unsigned char"""
|
||||
cdef:
|
||||
CObjectID native_id = qid.data
|
||||
uint64_t msg_id
|
||||
uint8_t *data = <uint8_t *>(&value[0])
|
||||
uint32_t size = value.nbytes
|
||||
with nogil:
|
||||
msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size)
|
||||
return msg_id
|
||||
|
||||
def stop(self):
|
||||
self.writer.Stop()
|
||||
channel_logger.info("stopped DataWriter")
|
||||
|
||||
|
||||
cdef class DataReader:
|
||||
cdef:
|
||||
CDataReader *reader
|
||||
readonly bytes meta
|
||||
readonly bytes data
|
||||
|
||||
def __init__(self):
|
||||
raise Exception("use create() to create DataReader")
|
||||
|
||||
@staticmethod
|
||||
def create(list py_input_queues,
|
||||
list input_actor_ids: list[ActorID],
|
||||
list py_seq_ids,
|
||||
list py_msg_ids,
|
||||
int64_t timer_interval,
|
||||
c_bool is_recreate,
|
||||
bytes config_bytes,
|
||||
c_bool is_mock):
|
||||
cdef:
|
||||
c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
|
||||
c_vector[CActorID] actor_ids
|
||||
c_vector[uint64_t] seq_ids
|
||||
c_vector[uint64_t] msg_ids
|
||||
CDataReader *c_reader
|
||||
cdef const unsigned char[:] config_data
|
||||
for actor_id in input_actor_ids:
|
||||
actor_ids.push_back((<ActorID>actor_id).data)
|
||||
for py_seq_id in py_seq_ids:
|
||||
seq_ids.push_back(<uint64_t>py_seq_id)
|
||||
for py_msg_id in py_msg_ids:
|
||||
msg_ids.push_back(<uint64_t>py_msg_id)
|
||||
cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
|
||||
if config_bytes:
|
||||
config_data = config_bytes
|
||||
channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
|
||||
ctx.get().SetConfig(<uint8_t *>(&(config_data[0])), config_data.nbytes)
|
||||
if is_mock:
|
||||
ctx.get().MarkMockTest()
|
||||
c_reader = new CDataReader(ctx)
|
||||
c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval)
|
||||
channel_logger.info("create native reader succeed")
|
||||
cdef DataReader reader = DataReader.__new__(DataReader)
|
||||
reader.reader = c_reader
|
||||
return reader
|
||||
|
||||
def __dealloc__(self):
|
||||
if self.reader != NULL:
|
||||
del self.reader
|
||||
channel_logger.info("deleted DataReader")
|
||||
self.reader = NULL
|
||||
|
||||
def read(self, uint32_t timeout_millis):
|
||||
cdef:
|
||||
shared_ptr[CDataBundle] bundle
|
||||
CStreamingStatus status
|
||||
with nogil:
|
||||
status = self.reader.GetBundle(timeout_millis, bundle)
|
||||
cdef uint32_t bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())
|
||||
if <uint32_t> status != <uint32_t> libstreaming.StatusOK:
|
||||
if <uint32_t> status == <uint32_t> libstreaming.StatusInterrupted:
|
||||
# avoid cyclic import
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
raise transfer.ChannelInterruptException("reader interrupted")
|
||||
elif <uint32_t> status == <uint32_t> libstreaming.StatusInitQueueFailed:
|
||||
raise Exception("init channel failed")
|
||||
elif <uint32_t> status == <uint32_t> libstreaming.StatusWaitQueueTimeOut:
|
||||
raise Exception("wait channel object timeout")
|
||||
cdef:
|
||||
uint32_t msg_nums
|
||||
CObjectID queue_id
|
||||
c_list[shared_ptr[CStreamingMessage]] msg_list
|
||||
list msgs = []
|
||||
uint64_t timestamp
|
||||
uint64_t msg_id
|
||||
if bundle_type == <uint32_t> libstreaming.BundleTypeBundle:
|
||||
msg_nums = bundle.get().meta.get().GetMessageListSize()
|
||||
CStreamingMessageBundle.GetMessageListFromRawData(
|
||||
bundle.get().data + libstreaming.kMessageBundleHeaderSize,
|
||||
bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
|
||||
msg_nums,
|
||||
msg_list)
|
||||
timestamp = bundle.get().meta.get().GetMessageBundleTs()
|
||||
for msg in msg_list:
|
||||
msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()]
|
||||
qid_bytes = queue_id.Binary()
|
||||
msg_id = msg.get().GetMessageSeqId()
|
||||
msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
|
||||
return msgs
|
||||
elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
|
||||
return []
|
||||
else:
|
||||
raise Exception("Unsupported bundle type {}".format(bundle_type))
|
||||
|
||||
def stop(self):
|
||||
self.reader.Stop()
|
||||
channel_logger.info("stopped DataReader")
|
||||
|
||||
|
||||
cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *:
|
||||
assert len(py_queue_ids) > 0
|
||||
cdef:
|
||||
c_vector[CObjectID] queue_id_vec
|
||||
c_string q_id_data
|
||||
for q_id in py_queue_ids:
|
||||
q_id_data = q_id
|
||||
assert q_id_data.size() == CObjectID.Size()
|
||||
obj_id = CObjectID.FromBinary(q_id_data)
|
||||
queue_id_vec.push_back(obj_id)
|
||||
return queue_id_vec
|
||||
|
||||
cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec):
|
||||
queues = []
|
||||
for obj_id in queue_id_vec:
|
||||
queues.append(obj_id.Binary())
|
||||
return queues
|
124
streaming/python/jobworker.py
Normal file
124
streaming/python/jobworker.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
|
||||
import ray
|
||||
import ray.streaming._streaming as _streaming
|
||||
from ray.streaming.config import Config
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
from ray.streaming.communication import DataInput, DataOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ray.remote
|
||||
class JobWorker(object):
|
||||
"""A streaming job worker.
|
||||
|
||||
Attributes:
|
||||
worker_id: The id of the instance.
|
||||
input_channels: The input gate that manages input channels of
|
||||
the instance (see: DataInput in communication.py).
|
||||
output_channels (DataOutput): The output gate that manages output
|
||||
channels of the instance (see: DataOutput in communication.py).
|
||||
the operator instance.
|
||||
"""
|
||||
|
||||
def __init__(self, worker_id, operator, input_channels, output_channels):
|
||||
self.env = None
|
||||
self.worker_id = worker_id
|
||||
self.operator = operator
|
||||
processor_name = operator.processor_class.__name__
|
||||
processor_instance = operator.processor_class(operator)
|
||||
self.processor_name = processor_name
|
||||
self.processor_instance = processor_instance
|
||||
self.input_channels = input_channels
|
||||
self.output_channels = output_channels
|
||||
self.input_gate = None
|
||||
self.output_gate = None
|
||||
self.reader_client = None
|
||||
self.writer_client = None
|
||||
|
||||
def init(self, env):
|
||||
"""init streaming actor"""
|
||||
env = pickle.loads(env)
|
||||
self.env = env
|
||||
logger.info("init operator instance %s", self.processor_name)
|
||||
|
||||
if env.config.channel_type == Config.NATIVE_CHANNEL:
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
reader_async_func = FunctionDescriptor(
|
||||
__name__, self.on_reader_message.__name__,
|
||||
self.__class__.__name__)
|
||||
reader_sync_func = FunctionDescriptor(
|
||||
__name__, self.on_reader_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.reader_client = _streaming.ReaderClient(
|
||||
core_worker, reader_async_func, reader_sync_func)
|
||||
writer_async_func = FunctionDescriptor(
|
||||
__name__, self.on_writer_message.__name__,
|
||||
self.__class__.__name__)
|
||||
writer_sync_func = FunctionDescriptor(
|
||||
__name__, self.on_writer_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.writer_client = _streaming.WriterClient(
|
||||
core_worker, writer_async_func, writer_sync_func)
|
||||
if len(self.input_channels) > 0:
|
||||
self.input_gate = DataInput(env, self.input_channels)
|
||||
self.input_gate.init()
|
||||
if len(self.output_channels) > 0:
|
||||
self.output_gate = DataOutput(
|
||||
env, self.output_channels,
|
||||
self.operator.partitioning_strategies)
|
||||
self.output_gate.init()
|
||||
logger.info("init operator instance %s succeed", self.processor_name)
|
||||
return True
|
||||
|
||||
# Starts the actor
|
||||
def start(self):
|
||||
self.t = threading.Thread(target=self.run, daemon=True)
|
||||
self.t.start()
|
||||
actor_id = ray.worker.global_worker.actor_id
|
||||
logger.info("%s %s started, actor id %s", self.__class__.__name__,
|
||||
self.processor_name, actor_id)
|
||||
|
||||
def run(self):
|
||||
logger.info("%s start running", self.processor_name)
|
||||
self.processor_instance.run(self.input_gate, self.output_gate)
|
||||
logger.info("%s finished running", self.processor_name)
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
if self.input_gate:
|
||||
self.input_gate.close()
|
||||
if self.output_gate:
|
||||
self.output_gate.close()
|
||||
|
||||
def is_finished(self):
|
||||
return not self.t.is_alive()
|
||||
|
||||
def on_reader_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
self.reader_client.on_reader_message(buffer)
|
||||
|
||||
def on_reader_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
if self.reader_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
result = self.reader_client.on_reader_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
def on_writer_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
self.writer_client.on_writer_message(buffer)
|
||||
|
||||
def on_writer_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
if self.writer_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
result = self.writer_client.on_writer_message_sync(buffer)
|
||||
return result.to_pybytes()
|
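For reference, a minimal sketch of how a JobWorker actor is deployed and initialized in direct-call mode, mirroring ExecutionGraph.__generate_actor and Environment.execute further below; the variable names (worker_id, operator, channels, env) are assumptions for the sketch:

    # Sketch only: one JobWorker per operator instance, arguments passed by value.
    import pickle
    import ray

    worker = JobWorker._remote(
        args=[worker_id, operator, input_channels, output_channels],
        is_direct_call=True)
    # init() expects the pickled Environment and returns True on success.
    assert ray.get(worker.init.remote(pickle.dumps(env))) is True
    worker.start.remote()  # spawns the run() thread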
|
@ -5,6 +5,8 @@ from __future__ import print_function
|
|||
import enum
|
||||
import logging
|
||||
|
||||
import cloudpickle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
|
@ -52,16 +54,18 @@ class OpType(enum.Enum):
|
|||
class Operator(object):
|
||||
def __init__(self,
|
||||
id,
|
||||
type,
|
||||
op_type,
|
||||
processor_class,
|
||||
name="",
|
||||
logic=None,
|
||||
num_instances=1,
|
||||
other=None,
|
||||
state_actor=None):
|
||||
self.id = id
|
||||
self.type = type
|
||||
self.type = op_type
|
||||
self.processor_class = processor_class
|
||||
self.name = name
|
||||
self.logic = logic # The operator's logic
|
||||
self._logic = cloudpickle.dumps(logic) # The operator's logic
|
||||
self.num_instances = num_instances
|
||||
# One partitioning strategy per downstream operator (default: forward)
|
||||
self.partitioning_strategies = {}
|
||||
|
@ -96,10 +100,14 @@ class Operator(object):
|
|||
self.partitioning_strategies = strategies
|
||||
|
||||
def print(self):
|
||||
log = "Operator<\nID = {}\nName = {}\nType = {}\n"
|
||||
log = "Operator<\nID = {}\nName = {}\nprocessor_class = {}\n"
|
||||
log += "Logic = {}\nNumber_of_Instances = {}\n"
|
||||
log += "Partitioning_Scheme = {}\nOther_Args = {}>\n"
|
||||
logger.debug(
|
||||
log.format(self.id, self.name, self.type, self.logic,
|
||||
log.format(self.id, self.name, self.processor_class, self.logic,
|
||||
self.num_instances, self.partitioning_strategies,
|
||||
self.other_args))
|
||||
|
||||
@property
|
||||
def logic(self):
|
||||
return cloudpickle.loads(self._logic)
|
226
streaming/python/processor.py
Normal file
|
@ -0,0 +1,226 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("INFO")
|
||||
|
||||
|
||||
def _identity(element):
|
||||
return element
|
||||
|
||||
|
||||
class ReadTextFile:
|
||||
"""A source operator instance that reads a text file line by line.
|
||||
|
||||
Attributes:
|
||||
filepath (string): The path to the input file.
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
self.filepath = operator.other_args
|
||||
# TODO (john): Handle possible exception here
|
||||
self.reader = open(self.filepath, "r")
|
||||
|
||||
# Read input file line by line
|
||||
def run(self, input_gate, output_gate):
|
||||
while True:
|
||||
record = self.reader.readline()
|
||||
# Reader returns empty string ('') on EOF
|
||||
if not record:
|
||||
self.reader.close()
|
||||
return
|
||||
output_gate.push(
|
||||
record[:-1]) # Push after removing newline characters
|
||||
|
||||
|
||||
class Map:
|
||||
"""A map operator instance that applies a user-defined
|
||||
stream transformation.
|
||||
|
||||
A map produces exactly one output record for each record in
|
||||
the input stream.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
self.map_fn = operator.logic
|
||||
|
||||
# Applies the mapper each record of the input stream(s)
|
||||
# and pushes resulting records to the output stream(s)
|
||||
def run(self, input_gate, output_gate):
|
||||
elements = 0
|
||||
while True:
|
||||
record = input_gate.pull()
|
||||
if record is None:
|
||||
return
|
||||
output_gate.push(self.map_fn(record))
|
||||
elements += 1
|
||||
|
||||
|
||||
class FlatMap:
|
||||
"""A map operator instance that applies a user-defined
|
||||
stream transformation.
|
||||
|
||||
A flatmap produces one or more output records for each record in
|
||||
the input stream.
|
||||
|
||||
Attributes:
|
||||
flatmap_fn (function): The user-defined function.
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
self.flatmap_fn = operator.logic
|
||||
|
||||
# Applies the splitter to the records of the input stream(s)
|
||||
# and pushes resulting records to the output stream(s)
|
||||
def run(self, input_gate, output_gate):
|
||||
while True:
|
||||
record = input_gate.pull()
|
||||
if record is None:
|
||||
return
|
||||
output_gate.push_all(self.flatmap_fn(record))
|
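To make the Map/FlatMap contract concrete, a small sketch of typical user logic; the functions below are illustrative and not part of this change:

    # Map emits exactly one record per input; FlatMap may emit zero or more.
    map_fn = lambda line: line.strip().lower()      # 1 -> 1
    flatmap_fn = lambda line: line.split()          # 1 -> 0..n

    assert map_fn(" Hello World ") == "hello world"
    assert flatmap_fn("hello world") == ["hello", "world"]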
||||
|
||||
|
||||
class Filter:
|
||||
"""A filter operator instance that applies a user-defined filter to
|
||||
each record of the stream.
|
||||
|
||||
Output records are those that pass the filter, i.e. those for which
|
||||
the filter function returns True.
|
||||
|
||||
Attributes:
|
||||
filter_fn (function): The user-defined boolean function.
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
self.filter_fn = operator.logic
|
||||
|
||||
# Applies the filter to the records of the input stream(s)
|
||||
# and pushes resulting records to the output stream(s)
|
||||
def run(self, input_gate, output_gate):
|
||||
while True:
|
||||
record = input_gate.pull()
|
||||
if record is None:
|
||||
return
|
||||
if self.filter_fn(record):
|
||||
output_gate.push(record)
|
||||
|
||||
|
||||
class Inspect:
|
||||
"""A inspect operator instance that inspects the content of the stream.
|
||||
Inspect is useful for printing the records in the stream.
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
self.inspect_fn = operator.logic
|
||||
|
||||
def run(self, input_gate, output_gate):
|
||||
# Applies the inspect logic (e.g. print) to the records of
|
||||
# the input stream(s)
|
||||
# and leaves stream unaffected by simply pushing the records to
|
||||
# the output stream(s)
|
||||
while True:
|
||||
record = input_gate.pull()
|
||||
if record is None:
|
||||
return
|
||||
if output_gate:
|
||||
output_gate.push(record)
|
||||
self.inspect_fn(record)
|
||||
|
||||
|
||||
class Reduce:
|
||||
"""A reduce operator instance that combines a new value for a key
|
||||
with the last reduced one according to a user-defined logic.
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
self.reduce_fn = operator.logic
|
||||
# Set the attribute selector
|
||||
self.attribute_selector = operator.other_args
|
||||
if self.attribute_selector is None:
|
||||
self.attribute_selector = _identity
|
||||
elif isinstance(self.attribute_selector, int):
|
||||
self.key_index = self.attribute_selector
|
||||
self.attribute_selector =\
|
||||
lambda record: record[self.key_index]
|
||||
elif isinstance(self.attribute_selector, str):
|
||||
self.attribute_selector =\
|
||||
lambda record, attr=self.attribute_selector: vars(record)[attr]
|
||||
elif not isinstance(self.attribute_selector, types.FunctionType):
|
||||
sys.exit("Unrecognized or unsupported key selector.")
|
||||
self.state = {} # key -> value
|
||||
|
||||
# Combines the input value for a key with the last reduced
|
||||
# value for that key to produce a new value.
|
||||
# Outputs the result as (key,new value)
|
||||
def run(self, input_gate, output_gate):
|
||||
while True:
|
||||
record = input_gate.pull()
|
||||
if record is None:
|
||||
return
|
||||
key, rest = record
|
||||
new_value = self.attribute_selector(rest)
|
||||
# TODO (john): Is there a way to update state with
|
||||
# a single dictionary lookup?
|
||||
try:
|
||||
old_value = self.state[key]
|
||||
new_value = self.reduce_fn(old_value, new_value)
|
||||
self.state[key] = new_value
|
||||
except KeyError: # Key does not exist in state
|
||||
self.state.setdefault(key, new_value)
|
||||
output_gate.push((key, new_value))
|
||||
|
||||
# Returns the state of the actor
|
||||
def get_state(self):
|
||||
return self.state
|
||||
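As an illustration of the rolling reduce semantics above, a standalone sketch; the records and reduce function are made up:

    # Per-key rolling reduce: combine the new value with the last reduced one.
    state = {}
    reduce_fn = lambda old, new: old + new  # e.g. a rolling sum

    for key, value in [("a", 1), ("b", 2), ("a", 3)]:
        if key in state:
            state[key] = reduce_fn(state[key], value)
        else:
            state[key] = value
        print((key, state[key]))  # ("a", 1), ("b", 2), ("a", 4)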
|
||||
|
||||
class KeyBy:
|
||||
"""A key_by operator instance that physically partitions the
|
||||
stream based on a key.
|
||||
"""
|
||||
|
||||
def __init__(self, operator):
|
||||
# Set the key selector
|
||||
self.key_selector = operator.other_args
|
||||
if isinstance(self.key_selector, int):
|
||||
self.key_selector = lambda r, k=self.key_selector: r[k]
|
||||
elif isinstance(self.key_selector, str):
|
||||
self.key_selector = lambda record, k=self.key_selector: vars(record)[k]
|
||||
elif not isinstance(self.key_selector, types.FunctionType):
|
||||
sys.exit("Unrecognized or unsupported key selector.")
|
||||
|
||||
# The actual partitioning is done by the output gate
|
||||
def run(self, input_gate, output_gate):
|
||||
while True:
|
||||
record = input_gate.pull()
|
||||
if record is None:
|
||||
return
|
||||
key = self.key_selector(record)
|
||||
output_gate.push((key, record))
|
||||
|
||||
|
||||
# A custom source actor
|
||||
class Source:
|
||||
def __init__(self, operator):
|
||||
# The user-defined source with a get_next() method
|
||||
self.source = operator.logic
|
||||
|
||||
# Starts the source by calling get_next() repeatedly
|
||||
def run(self, input_gate, output_gate):
|
||||
start = time.time()
|
||||
elements = 0
|
||||
while True:
|
||||
record = self.source.get_next()
|
||||
if not record:
|
||||
logger.debug("[writer] puts per second: {}".format(
|
||||
elements / (time.time() - start)))
|
||||
return
|
||||
output_gate.push(record)
|
||||
elements += 1
|
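A minimal sketch of a user-defined source that this Source processor can drive; the class and its records are hypothetical:

    # Anything with a get_next() method that returns a falsy value on exhaustion.
    class WordSource:
        def __init__(self, words):
            self.words = list(words)

        def get_next(self):
            return self.words.pop(0) if self.words else None

    # Registered via env.source(WordSource(["hello", "world"])),
    # see Environment.source further below.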
0
streaming/python/runtime/__init__.py
Normal file
291
streaming/python/runtime/transfer.py
Normal file
|
@ -0,0 +1,291 @@
|
|||
import logging
|
||||
import random
|
||||
from queue import Queue
|
||||
from typing import List
|
||||
|
||||
import ray
|
||||
import ray.streaming._streaming as _streaming
|
||||
import ray.streaming.generated.streaming_pb2 as streaming_pb
|
||||
from ray.actor import ActorHandle, ActorID
|
||||
from ray.streaming.config import Config
|
||||
|
||||
CHANNEL_ID_LEN = 20
|
||||
|
||||
|
||||
class ChannelID:
|
||||
"""
|
||||
ChannelID is used to identify a transfer channel between
|
||||
an upstream worker and a downstream worker.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_id_str: str):
|
||||
"""
|
||||
Args:
|
||||
channel_id_str: string representation of channel id
|
||||
"""
|
||||
self.channel_id_str = channel_id_str
|
||||
self.object_qid = ray.ObjectID(channel_id_str_to_bytes(channel_id_str))
|
||||
|
||||
def __eq__(self, other):
|
||||
if other is None:
|
||||
return False
|
||||
if type(other) is ChannelID:
|
||||
return self.channel_id_str == other.channel_id_str
|
||||
else:
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.channel_id_str)
|
||||
|
||||
def __repr__(self):
|
||||
return self.channel_id_str
|
||||
|
||||
@staticmethod
|
||||
def gen_random_id():
|
||||
"""Generate a random channel id string
|
||||
"""
|
||||
res = ""
|
||||
for i in range(CHANNEL_ID_LEN * 2):
|
||||
res += str(chr(random.randint(0, 5) + ord("A")))
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def gen_id(from_index, to_index, ts):
|
||||
"""Generate channel id, which is 20 character"""
|
||||
channel_id = bytearray(20)
|
||||
for i in range(11, 7, -1):
|
||||
channel_id[i] = ts & 0xff
|
||||
ts >>= 8
|
||||
channel_id[16] = (from_index & 0xffff) >> 8
|
||||
channel_id[17] = (from_index & 0xff)
|
||||
channel_id[18] = (to_index & 0xffff) >> 8
|
||||
channel_id[19] = (to_index & 0xff)
|
||||
return channel_bytes_to_str(bytes(channel_id))
|
||||
|
||||
|
||||
def channel_id_str_to_bytes(channel_id_str):
|
||||
"""
|
||||
Args:
|
||||
channel_id_str: string representation of channel id
|
||||
|
||||
Returns:
|
||||
bytes representation of channel id
|
||||
"""
|
||||
assert type(channel_id_str) in [str, bytes]
|
||||
if isinstance(channel_id_str, bytes):
|
||||
return channel_id_str
|
||||
qid_bytes = bytes.fromhex(channel_id_str)
|
||||
assert len(qid_bytes) == CHANNEL_ID_LEN
|
||||
return qid_bytes
|
||||
|
||||
|
||||
def channel_bytes_to_str(id_bytes):
|
||||
"""
|
||||
Args:
|
||||
id_bytes: bytes representation of channel id
|
||||
|
||||
Returns:
|
||||
string representation of channel id
|
||||
"""
|
||||
assert type(id_bytes) in [str, bytes]
|
||||
if isinstance(id_bytes, str):
|
||||
return id_bytes
|
||||
return bytes.hex(id_bytes)
|
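For illustration, a short sketch of how these helpers relate; the task indices and timestamp are made up:

    # Round trip between the hex string and raw-bytes representations.
    qid_str = ChannelID.gen_id(1, 2, 1573000000000)
    qid_bytes = channel_id_str_to_bytes(qid_str)   # 20 raw bytes
    assert len(qid_bytes) == CHANNEL_ID_LEN
    assert channel_bytes_to_str(qid_bytes) == qid_str
    channel = ChannelID(qid_str)                   # wraps it as an ObjectID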
||||
|
||||
|
||||
class DataMessage:
|
||||
"""
|
||||
DataMessage represents data between upstream and downstream operator
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
body,
|
||||
timestamp,
|
||||
channel_id,
|
||||
message_id_,
|
||||
is_empty_message=False):
|
||||
self.__body = body
|
||||
self.__timestamp = timestamp
|
||||
self.__channel_id = channel_id
|
||||
self.__message_id = message_id_
|
||||
self.__is_empty_message = is_empty_message
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__body)
|
||||
|
||||
def body(self):
|
||||
"""Message data"""
|
||||
return self.__body
|
||||
|
||||
def timestamp(self):
|
||||
"""Get timestamp when item is written by upstream DataWriter
|
||||
"""
|
||||
return self.__timestamp
|
||||
|
||||
def channel_id(self):
|
||||
"""Get string id of channel where data is coming from
|
||||
"""
|
||||
return self.__channel_id
|
||||
|
||||
def is_empty_message(self):
|
||||
"""Whether this message is an empty message.
|
||||
Upstream DataWriter sends an empty message when there is no data
|
||||
in the specified interval.
|
||||
"""
|
||||
return self.__is_empty_message
|
||||
|
||||
@property
|
||||
def message_id(self):
|
||||
return self.__message_id
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataWriter:
|
||||
"""Data Writer is a wrapper of streaming c++ DataWriter, which sends data
|
||||
to downstream workers
|
||||
"""
|
||||
|
||||
def __init__(self, output_channels, to_actors: List[ActorHandle],
|
||||
conf: dict):
|
||||
"""Get DataWriter of output channels
|
||||
Args:
|
||||
output_channels: output channel ids
|
||||
to_actors: downstream output actors
|
||||
Returns:
|
||||
DataWriter
|
||||
"""
|
||||
assert len(output_channels) > 0
|
||||
py_output_channels = [
|
||||
channel_id_str_to_bytes(qid_str) for qid_str in output_channels
|
||||
]
|
||||
output_actor_ids: List[ActorID] = [
|
||||
handle._ray_actor_id for handle in to_actors
|
||||
]
|
||||
channel_size = conf.get(Config.CHANNEL_SIZE,
|
||||
Config.CHANNEL_SIZE_DEFAULT)
|
||||
py_msg_ids = [0 for _ in range(len(output_channels))]
|
||||
config_bytes = _to_native_conf(conf)
|
||||
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
|
||||
self.writer = _streaming.DataWriter.create(
|
||||
py_output_channels, output_actor_ids, channel_size, py_msg_ids,
|
||||
config_bytes, is_mock)
|
||||
|
||||
logger.info("create DataWriter succeed")
|
||||
|
||||
def write(self, channel_id: ChannelID, item: bytes):
|
||||
"""Write data into native channel
|
||||
Args:
|
||||
channel_id: channel id
|
||||
item: bytes data
|
||||
Returns:
|
||||
msg_id
|
||||
"""
|
||||
assert type(item) == bytes
|
||||
msg_id = self.writer.write(channel_id.object_qid, item)
|
||||
return msg_id
|
||||
|
||||
def stop(self):
|
||||
logger.info("stopping channel writer.")
|
||||
self.writer.stop()
|
||||
# destruct DataWriter
|
||||
self.writer = None
|
||||
|
||||
def close(self):
|
||||
logger.info("closing channel writer.")
|
||||
|
||||
|
||||
class DataReader:
|
||||
"""Data Reader is wrapper of streaming c++ DataReader, which read data
|
||||
from channels of upstream workers
|
||||
"""
|
||||
|
||||
def __init__(self, input_channels: List, from_actors: List[ActorHandle],
|
||||
conf: dict):
|
||||
"""Get DataReader of input channels
|
||||
Args:
|
||||
input_channels: input channels
|
||||
from_actors: upstream input actors
|
||||
Returns:
|
||||
DataReader
|
||||
"""
|
||||
assert len(input_channels) > 0
|
||||
py_input_channels = [
|
||||
channel_id_str_to_bytes(qid_str) for qid_str in input_channels
|
||||
]
|
||||
input_actor_ids: List[ActorID] = [
|
||||
handle._ray_actor_id for handle in from_actors
|
||||
]
|
||||
py_seq_ids = [0 for _ in range(len(input_channels))]
|
||||
py_msg_ids = [0 for _ in range(len(input_channels))]
|
||||
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
|
||||
is_recreate = bool(conf.get(Config.IS_RECREATE, False))
|
||||
config_bytes = _to_native_conf(conf)
|
||||
self.__queue = Queue(10000)
|
||||
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
|
||||
self.reader = _streaming.DataReader.create(
|
||||
py_input_channels, input_actor_ids, py_seq_ids, py_msg_ids,
|
||||
timer_interval, is_recreate, config_bytes, is_mock)
|
||||
logger.info("create DataReader succeed")
|
||||
|
||||
def read(self, timeout_millis):
|
||||
"""Read data from channel
|
||||
Args:
|
||||
timeout_millis: how long to block, in milliseconds, when there is no data
|
||||
in the channel.
|
||||
Returns:
|
||||
channel item
|
||||
"""
|
||||
if self.__queue.empty():
|
||||
msgs = self.reader.read(timeout_millis)
|
||||
for msg in msgs:
|
||||
msg_bytes, msg_id, timestamp, qid_bytes = msg
|
||||
data_msg = DataMessage(msg_bytes, timestamp,
|
||||
channel_bytes_to_str(qid_bytes), msg_id)
|
||||
self.__queue.put(data_msg)
|
||||
if self.__queue.empty():
|
||||
return None
|
||||
return self.__queue.get()
|
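A sketch of a typical consumer loop over DataReader; channel setup is omitted and handle() is a placeholder (compare the full example in test_direct_transfer.py below):

    # Poll with a timeout; read() returns None when nothing arrived in time.
    while True:
        msg = reader.read(100)        # wait up to 100 ms
        if msg is None:
            continue
        payload = msg.body()          # bytes written by the upstream DataWriter
        src_channel = msg.channel_id()
        handle(payload, src_channel)  # user-defined processing (assumed)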
||||
|
||||
def stop(self):
|
||||
logger.info("stopping Data Reader.")
|
||||
self.reader.stop()
|
||||
# destruct DataReader
|
||||
self.reader = None
|
||||
|
||||
def close(self):
|
||||
logger.info("closing Data Reader.")
|
||||
|
||||
|
||||
def _to_native_conf(conf):
|
||||
config = streaming_pb.StreamingConfig()
|
||||
if Config.STREAMING_JOB_NAME in conf:
|
||||
config.job_name = conf[Config.STREAMING_JOB_NAME]
|
||||
if Config.TASK_JOB_ID in conf:
|
||||
job_id = conf[Config.TASK_JOB_ID]
|
||||
config.task_job_id = job_id.hex()
|
||||
if Config.STREAMING_WORKER_NAME in conf:
|
||||
config.worker_name = conf[Config.STREAMING_WORKER_NAME]
|
||||
if Config.STREAMING_OP_NAME in conf:
|
||||
config.op_name = conf[Config.STREAMING_OP_NAME]
|
||||
# TODO set operator type
|
||||
if Config.STREAMING_RING_BUFFER_CAPACITY in conf:
|
||||
config.ring_buffer_capacity = \
|
||||
conf[Config.STREAMING_RING_BUFFER_CAPACITY]
|
||||
if Config.STREAMING_EMPTY_MESSAGE_INTERVAL in conf:
|
||||
config.empty_message_interval = \
|
||||
conf[Config.STREAMING_EMPTY_MESSAGE_INTERVAL]
|
||||
logger.info("conf: %s", str(config))
|
||||
return config.SerializeToString()
|
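For reference, a sketch of a conf dict that _to_native_conf accepts; the values are illustrative only and reuse Config keys referenced above (import ray assumed):

    conf = {
        Config.STREAMING_JOB_NAME: "wordcount",
        Config.TASK_JOB_ID:
            ray.runtime_context._get_runtime_context().current_driver_id,
        Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL,
        Config.STREAMING_RING_BUFFER_CAPACITY: 1024,
    }
    config_bytes = _to_native_conf(conf)  # serialized StreamingConfig proto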
||||
|
||||
|
||||
class ChannelInitException(Exception):
|
||||
def __init__(self, msg, abnormal_channels):
|
||||
self.abnormal_channels = abnormal_channels
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class ChannelInterruptException(Exception):
|
||||
def __init__(self, msg=None):
|
||||
self.msg = msg
|
|
@ -3,26 +3,24 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ray.experimental.streaming.communication import DataChannel, DataInput
|
||||
from ray.experimental.streaming.communication import DataOutput, QueueConfig
|
||||
from ray.experimental.streaming.operator import Operator, OpType
|
||||
from ray.experimental.streaming.operator import PScheme, PStrategy
|
||||
import ray.experimental.streaming.operator_instance as operator_instance
|
||||
import ray
|
||||
import ray.streaming.processor as processor
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
from ray.streaming.communication import DataChannel
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.jobworker import JobWorker
|
||||
from ray.streaming.operator import Operator, OpType
|
||||
from ray.streaming.operator import PScheme, PStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("INFO")
|
||||
|
||||
|
||||
# Generates UUIDs
|
||||
def _generate_uuid():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
# Rolling sum's logic
|
||||
def _sum(value_1, value_2):
|
||||
return value_1 + value_2
|
||||
|
@ -36,28 +34,224 @@ all_to_all_strategies = [
|
|||
|
||||
|
||||
# Environment configuration
|
||||
class Config(object):
|
||||
class Conf(object):
|
||||
"""Environment configuration.
|
||||
|
||||
This class includes all information about the configuration of the
|
||||
streaming environment.
|
||||
|
||||
Attributes:
|
||||
queue_config (QueueConfig): Batched Queue configuration
|
||||
(see: communication.py)
|
||||
A batched queue configuration includes the max queue size,
|
||||
the size of each batch (in number of elements), the batch flush
|
||||
timeout, and the number of batches to prefetch from plasma
|
||||
parallelism (int): The number of instances (actors) for each logical
|
||||
dataflow operator (default: 1)
|
||||
"""
|
||||
|
||||
def __init__(self, parallelism=1):
|
||||
self.queue_config = QueueConfig()
|
||||
def __init__(self, parallelism=1, channel_type=Config.MEMORY_CHANNEL):
|
||||
self.parallelism = parallelism
|
||||
self.channel_type = channel_type
|
||||
# ...
|
||||
|
||||
|
||||
class ExecutionGraph:
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
self.physical_topo = nx.DiGraph() # DAG
|
||||
# Handles to all actors in the physical dataflow
|
||||
self.actor_handles = []
|
||||
# (op_id, op_instance_index) -> ActorID
|
||||
self.actors_map = {}
|
||||
# execution graph build time: milliseconds since epoch
|
||||
self.build_time = 0
|
||||
self.task_id_counter = 0
|
||||
self.task_ids = {}
|
||||
self.input_channels = {} # operator id -> input channels
|
||||
self.output_channels = {} # operator id -> output channels
|
||||
|
||||
# Constructs and deploys a Ray actor of a specific type
|
||||
# TODO (john): Actor placement information should be specified in
|
||||
# the environment's configuration
|
||||
def __generate_actor(self, instance_index, operator, input_channels,
|
||||
output_channels):
|
||||
"""Generates an actor that will execute a particular instance of
|
||||
the logical operator
|
||||
|
||||
Attributes:
|
||||
instance_index: The index of the instance the actor will execute.
|
||||
operator: The metadata of the logical operator.
|
||||
input_channels: The input channels of the instance.
|
||||
output_channels: The output channels of the instance.
|
||||
"""
|
||||
worker_id = (operator.id, instance_index)
|
||||
# Record the physical dataflow graph (for debugging purposes)
|
||||
self.__add_channel(worker_id, output_channels)
|
||||
# Note: direct_call actors only support pass-by-value arguments
|
||||
return JobWorker._remote(
|
||||
args=[worker_id, operator, input_channels, output_channels],
|
||||
is_direct_call=True)
|
||||
|
||||
# Constructs and deploys a Ray actor for each instance of
|
||||
# the given operator
|
||||
def __generate_actors(self, operator, upstream_channels,
|
||||
downstream_channels):
|
||||
"""Generates one actor for each instance of the given logical
|
||||
operator.
|
||||
|
||||
Attributes:
|
||||
operator (Operator): The logical operator metadata.
|
||||
upstream_channels (list): A list of all upstream channels for
|
||||
all instances of the operator.
|
||||
downstream_channels (list): A list of all downstream channels
|
||||
for all instances of the operator.
|
||||
"""
|
||||
num_instances = operator.num_instances
|
||||
logger.info("Generating {} actors of type {}...".format(
|
||||
num_instances, operator.type))
|
||||
handles = []
|
||||
for i in range(num_instances):
|
||||
# Collect input and output channels for the particular instance
|
||||
ip = [c for c in upstream_channels if c.dst_instance_index == i]
|
||||
op = [c for c in downstream_channels if c.src_instance_index == i]
|
||||
log = "Constructed {} input and {} output channels "
|
||||
log += "for the {}-th instance of the {} operator."
|
||||
logger.debug(log.format(len(ip), len(op), i, operator.type))
|
||||
handle = self.__generate_actor(i, operator, ip, op)
|
||||
if handle:
|
||||
handles.append(handle)
|
||||
self.actors_map[(operator.id, i)] = handle
|
||||
return handles
|
||||
|
||||
# Adds a channel/edge to the physical dataflow graph
|
||||
def __add_channel(self, actor_id, output_channels):
|
||||
for c in output_channels:
|
||||
dest_actor_id = (c.dst_operator_id, c.dst_instance_index)
|
||||
self.physical_topo.add_edge(actor_id, dest_actor_id)
|
||||
|
||||
# Generates all required data channels between an operator
|
||||
# and its downstream operators
|
||||
def _generate_channels(self, operator):
|
||||
"""Generates all output data channels
|
||||
(see: DataChannel in communication.py) for all instances of
|
||||
the given logical operator.
|
||||
|
||||
The function constructs one data channel for each pair of
|
||||
communicating operator instances (instance_1,instance_2),
|
||||
where instance_1 is an instance of the given operator and instance_2
|
||||
is an instance of a direct downstream operator.
|
||||
|
||||
The number of total channels generated depends on the partitioning
|
||||
strategy specified by the user.
|
||||
"""
|
||||
channels = {} # destination operator id -> channels
|
||||
strategies = operator.partitioning_strategies
|
||||
for dst_operator, p_scheme in strategies.items():
|
||||
num_dest_instances = self.env.operators[dst_operator].num_instances
|
||||
entry = channels.setdefault(dst_operator, [])
|
||||
if p_scheme.strategy == PStrategy.Forward:
|
||||
for i in range(operator.num_instances):
|
||||
# ID of destination instance to connect
|
||||
id = i % num_dest_instances
|
||||
qid = self._gen_str_qid(operator.id, i, dst_operator, id)
|
||||
c = DataChannel(operator.id, i, dst_operator, id, qid)
|
||||
entry.append(c)
|
||||
elif p_scheme.strategy in all_to_all_strategies:
|
||||
for i in range(operator.num_instances):
|
||||
for j in range(num_dest_instances):
|
||||
qid = self._gen_str_qid(operator.id, i, dst_operator,
|
||||
j)
|
||||
c = DataChannel(operator.id, i, dst_operator, j, qid)
|
||||
entry.append(c)
|
||||
else:
|
||||
# TODO (john): Add support for other partitioning strategies
|
||||
sys.exit("Unrecognized or unsupported partitioning strategy.")
|
||||
return channels
|
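As a concrete illustration of how many channels this produces, a small sketch with assumed instance counts:

    # Forward: instance i connects to i % num_dest_instances (one channel each).
    # All-to-all (e.g. shuffle): every (i, j) pair gets its own channel.
    upstream, downstream = 2, 3
    forward = [(i, i % downstream) for i in range(upstream)]                # 2 channels
    shuffle = [(i, j) for i in range(upstream) for j in range(downstream)]  # 6 channels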
||||
|
||||
def _gen_str_qid(self, src_operator_id, src_instance_index,
|
||||
dst_operator_id, dst_instance_index):
|
||||
from_task_id = self.env.execution_graph.get_task_id(
|
||||
src_operator_id, src_instance_index)
|
||||
to_task_id = self.env.execution_graph.get_task_id(
|
||||
dst_operator_id, dst_instance_index)
|
||||
return transfer.ChannelID.gen_id(from_task_id, to_task_id,
|
||||
self.build_time)
|
||||
|
||||
def _gen_task_id(self):
|
||||
task_id = self.task_id_counter
|
||||
self.task_id_counter += 1
|
||||
return task_id
|
||||
|
||||
def get_task_id(self, op_id, op_instance_id):
|
||||
return self.task_ids[(op_id, op_instance_id)]
|
||||
|
||||
def get_actor(self, op_id, op_instance_id):
|
||||
return self.actors_map[(op_id, op_instance_id)]
|
||||
|
||||
# Prints the physical dataflow graph
|
||||
def print_physical_graph(self):
|
||||
logger.info("===================================")
|
||||
logger.info("======Physical Dataflow Graph======")
|
||||
logger.info("===================================")
|
||||
# Print all data channels between operator instances
|
||||
log = "(Source Operator ID,Source Operator Name,Source Instance ID)"
|
||||
log += " --> "
|
||||
log += "(Destination Operator ID,Destination Operator Name,"
|
||||
log += "Destination Instance ID)"
|
||||
logger.info(log)
|
||||
for src_actor_id, dst_actor_id in self.physical_topo.edges:
|
||||
src_operator_id, src_instance_index = src_actor_id
|
||||
dst_operator_id, dst_instance_index = dst_actor_id
|
||||
logger.info("({},{},{}) --> ({},{},{})".format(
|
||||
src_operator_id, self.env.operators[src_operator_id].name,
|
||||
src_instance_index, dst_operator_id,
|
||||
self.env.operators[dst_operator_id].name, dst_instance_index))
|
||||
|
||||
def build_graph(self):
|
||||
self.build_channels()
|
||||
|
||||
# to support cyclic reference serialization
|
||||
try:
|
||||
ray.register_custom_serializer(Environment, use_pickle=True)
|
||||
ray.register_custom_serializer(ExecutionGraph, use_pickle=True)
|
||||
ray.register_custom_serializer(OpType, use_pickle=True)
|
||||
ray.register_custom_serializer(PStrategy, use_pickle=True)
|
||||
except Exception:
|
||||
# local mode can't use pickle
|
||||
pass
|
||||
|
||||
# Each operator instance is implemented as a Ray actor
|
||||
# Actors are deployed in topological order, as we traverse the
|
||||
# logical dataflow from sources to sinks.
|
||||
for node in nx.topological_sort(self.env.logical_topo):
|
||||
operator = self.env.operators[node]
|
||||
# Instantiate Ray actors
|
||||
handles = self.__generate_actors(
|
||||
operator, self.input_channels.get(node, []),
|
||||
self.output_channels.get(node, []))
|
||||
if handles:
|
||||
self.actor_handles.extend(handles)
|
||||
|
||||
def build_channels(self):
|
||||
self.build_time = int(time.time() * 1000)
|
||||
# gen auto-incremented unique task id for every operator instance
|
||||
for node in nx.topological_sort(self.env.logical_topo):
|
||||
operator = self.env.operators[node]
|
||||
for i in range(operator.num_instances):
|
||||
operator_instance_id = (operator.id, i)
|
||||
self.task_ids[operator_instance_id] = self._gen_task_id()
|
||||
channels = {}
|
||||
for node in nx.topological_sort(self.env.logical_topo):
|
||||
operator = self.env.operators[node]
|
||||
# Generate downstream data channels
|
||||
downstream_channels = self._generate_channels(operator)
|
||||
channels[node] = downstream_channels
|
||||
# op_id -> channels
|
||||
input_channels = {}
|
||||
output_channels = {}
|
||||
for op_id, all_downstream_channels in channels.items():
|
||||
for dst_op_channels in all_downstream_channels.values():
|
||||
for c in dst_op_channels:
|
||||
dst = input_channels.setdefault(c.dst_operator_id, [])
|
||||
dst.append(c)
|
||||
src = output_channels.setdefault(c.src_operator_id, [])
|
||||
src.append(c)
|
||||
self.input_channels = input_channels
|
||||
self.output_channels = output_channels
|
||||
|
||||
|
||||
# The execution environment for a streaming job
|
||||
class Environment(object):
|
||||
"""A streaming environment.
|
||||
|
@ -81,176 +275,18 @@ class Environment(object):
|
|||
the streaming dataflow.
|
||||
"""
|
||||
|
||||
def __init__(self, config=Config()):
|
||||
def __init__(self, config=Conf()):
|
||||
self.logical_topo = nx.DiGraph() # DAG
|
||||
self.physical_topo = nx.DiGraph() # DAG
|
||||
self.operators = {} # operator id --> operator object
|
||||
self.config = config # Environment's configuration
|
||||
self.topo_cleaned = False
|
||||
# Handles to all actors in the physical dataflow
|
||||
self.actor_handles = []
|
||||
self.operator_id_counter = 0
|
||||
self.execution_graph = None # set when executed
|
||||
|
||||
# Constructs and deploys a Ray actor of a specific type
|
||||
# TODO (john): Actor placement information should be specified in
|
||||
# the environment's configuration
|
||||
def __generate_actor(self, instance_id, operator, input, output):
|
||||
"""Generates an actor that will execute a particular instance of
|
||||
the logical operator
|
||||
|
||||
Attributes:
|
||||
instance_id (UUID): The id of the instance the actor will execute.
|
||||
operator (Operator): The metadata of the logical operator.
|
||||
input (DataInput): The input gate that manages input channels of
|
||||
the instance (see: DataInput in communication.py).
|
||||
input (DataOutput): The output gate that manages output channels
|
||||
of the instance (see: DataOutput in communication.py).
|
||||
"""
|
||||
actor_id = (operator.id, instance_id)
|
||||
# Record the physical dataflow graph (for debugging purposes)
|
||||
self.__add_channel(actor_id, input, output)
|
||||
# Select actor to construct
|
||||
if operator.type == OpType.Source:
|
||||
source = operator_instance.Source.remote(actor_id, operator, input,
|
||||
output)
|
||||
source.register_handle.remote(source)
|
||||
return source.start.remote()
|
||||
elif operator.type == OpType.Map:
|
||||
map = operator_instance.Map.remote(actor_id, operator, input,
|
||||
output)
|
||||
map.register_handle.remote(map)
|
||||
return map.start.remote()
|
||||
elif operator.type == OpType.FlatMap:
|
||||
flatmap = operator_instance.FlatMap.remote(actor_id, operator,
|
||||
input, output)
|
||||
flatmap.register_handle.remote(flatmap)
|
||||
return flatmap.start.remote()
|
||||
elif operator.type == OpType.Filter:
|
||||
filter = operator_instance.Filter.remote(actor_id, operator, input,
|
||||
output)
|
||||
filter.register_handle.remote(filter)
|
||||
return filter.start.remote()
|
||||
elif operator.type == OpType.Reduce:
|
||||
reduce = operator_instance.Reduce.remote(actor_id, operator, input,
|
||||
output)
|
||||
reduce.register_handle.remote(reduce)
|
||||
return reduce.start.remote()
|
||||
elif operator.type == OpType.TimeWindow:
|
||||
pass
|
||||
elif operator.type == OpType.KeyBy:
|
||||
keyby = operator_instance.KeyBy.remote(actor_id, operator, input,
|
||||
output)
|
||||
keyby.register_handle.remote(keyby)
|
||||
return keyby.start.remote()
|
||||
elif operator.type == OpType.Sum:
|
||||
sum = operator_instance.Reduce.remote(actor_id, operator, input,
|
||||
output)
|
||||
# Register target handle at state actor
|
||||
state_actor = operator.state_actor
|
||||
if state_actor is not None:
|
||||
state_actor.register_target.remote(sum)
|
||||
# Register own handle
|
||||
sum.register_handle.remote(sum)
|
||||
return sum.start.remote()
|
||||
elif operator.type == OpType.Sink:
|
||||
pass
|
||||
elif operator.type == OpType.Inspect:
|
||||
inspect = operator_instance.Inspect.remote(actor_id, operator,
|
||||
input, output)
|
||||
inspect.register_handle.remote(inspect)
|
||||
return inspect.start.remote()
|
||||
elif operator.type == OpType.ReadTextFile:
|
||||
# TODO (john): Colocate the source with the input file
|
||||
read = operator_instance.ReadTextFile.remote(
|
||||
actor_id, operator, input, output)
|
||||
read.register_handle.remote(read)
|
||||
return read.start.remote()
|
||||
else: # TODO (john): Add support for other types of operators
|
||||
sys.exit("Unrecognized or unsupported {} operator type.".format(
|
||||
operator.type))
|
||||
|
||||
# Constructs and deploys a Ray actor for each instance of
|
||||
# the given operator
|
||||
def __generate_actors(self, operator, upstream_channels,
|
||||
downstream_channels):
|
||||
"""Generates one actor for each instance of the given logical
|
||||
operator.
|
||||
|
||||
Attributes:
|
||||
operator (Operator): The logical operator metadata.
|
||||
upstream_channels (list): A list of all upstream channels for
|
||||
all instances of the operator.
|
||||
downstream_channels (list): A list of all downstream channels
|
||||
for all instances of the operator.
|
||||
"""
|
||||
num_instances = operator.num_instances
|
||||
logger.info("Generating {} actors of type {}...".format(
|
||||
num_instances, operator.type))
|
||||
in_channels = upstream_channels.pop(
|
||||
operator.id) if upstream_channels else []
|
||||
handles = []
|
||||
for i in range(num_instances):
|
||||
# Collect input and output channels for the particular instance
|
||||
ip = [
|
||||
channel for channel in in_channels
|
||||
if channel.dst_instance_id == i
|
||||
] if in_channels else []
|
||||
op = [
|
||||
channel for channels_list in downstream_channels.values()
|
||||
for channel in channels_list if channel.src_instance_id == i
|
||||
]
|
||||
log = "Constructed {} input and {} output channels "
|
||||
log += "for the {}-th instance of the {} operator."
|
||||
logger.debug(log.format(len(ip), len(op), i, operator.type))
|
||||
input_gate = DataInput(ip)
|
||||
output_gate = DataOutput(op, operator.partitioning_strategies)
|
||||
handle = self.__generate_actor(i, operator, input_gate,
|
||||
output_gate)
|
||||
if handle:
|
||||
handles.append(handle)
|
||||
return handles
|
||||
|
||||
# Adds a channel/edge to the physical dataflow graph
|
||||
def __add_channel(self, actor_id, input, output):
|
||||
for dest_actor_id in output._destination_actor_ids():
|
||||
self.physical_topo.add_edge(actor_id, dest_actor_id)
|
||||
|
||||
# Generates all required data channels between an operator
|
||||
# and its downstream operators
|
||||
def _generate_channels(self, operator):
|
||||
"""Generates all output data channels
|
||||
(see: DataChannel in communication.py) for all instances of
|
||||
the given logical operator.
|
||||
|
||||
The function constructs one data channel for each pair of
|
||||
communicating operator instances (instance_1,instance_2),
|
||||
where instance_1 is an instance of the given operator and instance_2
|
||||
is an instance of a direct downstream operator.
|
||||
|
||||
The number of total channels generated depends on the partitioning
|
||||
strategy specified by the user.
|
||||
"""
|
||||
channels = {} # destination operator id -> channels
|
||||
strategies = operator.partitioning_strategies
|
||||
for dst_operator, p_scheme in strategies.items():
|
||||
num_dest_instances = self.operators[dst_operator].num_instances
|
||||
entry = channels.setdefault(dst_operator, [])
|
||||
if p_scheme.strategy == PStrategy.Forward:
|
||||
for i in range(operator.num_instances):
|
||||
# ID of destination instance to connect
|
||||
id = i % num_dest_instances
|
||||
channel = DataChannel(self, operator.id, dst_operator, i,
|
||||
id)
|
||||
entry.append(channel)
|
||||
elif p_scheme.strategy in all_to_all_strategies:
|
||||
for i in range(operator.num_instances):
|
||||
for j in range(num_dest_instances):
|
||||
channel = DataChannel(self, operator.id, dst_operator,
|
||||
i, j)
|
||||
entry.append(channel)
|
||||
else:
|
||||
# TODO (john): Add support for other partitioning strategies
|
||||
sys.exit("Unrecognized or unsupported partitioning strategy.")
|
||||
return channels
|
||||
def gen_operator_id(self):
|
||||
op_id = self.operator_id_counter
|
||||
self.operator_id_counter += 1
|
||||
return op_id
|
||||
|
||||
# An edge denotes a flow of data between logical operators
|
||||
# and may correspond to multiple data channels in the physical dataflow
|
||||
|
@ -275,19 +311,15 @@ class Environment(object):
|
|||
def set_parallelism(self, parallelism):
|
||||
self.config.parallelism = parallelism
|
||||
|
||||
# Sets batched queue configuration for the environment
|
||||
def set_queue_config(self, queue_config):
|
||||
self.config.queue_config = queue_config
|
||||
|
||||
# Creates and registers a user-defined data source
|
||||
# TODO (john): There should be different types of sources, e.g. sources
|
||||
# reading from Kafka, text files, etc.
|
||||
# TODO (john): Handle case where environment parallelism is set
|
||||
def source(self, source):
|
||||
source_id = _generate_uuid()
|
||||
source_id = self.gen_operator_id()
|
||||
source_stream = DataStream(self, source_id)
|
||||
self.operators[source_id] = Operator(
|
||||
source_id, OpType.Source, "Source", other=source)
|
||||
source_id, OpType.Source, processor.Source, "Source", logic=source)
|
||||
return source_stream
|
||||
|
||||
# Creates and registers a new data source that reads a
|
||||
|
@ -296,10 +328,14 @@ class Environment(object):
|
|||
# e.g. sources reading from Kafka, text files, etc.
|
||||
# TODO (john): Handle case where environment parallelism is set
|
||||
def read_text_file(self, filepath):
|
||||
source_id = _generate_uuid()
|
||||
source_id = self.gen_operator_id()
|
||||
source_stream = DataStream(self, source_id)
|
||||
self.operators[source_id] = Operator(
|
||||
source_id, OpType.ReadTextFile, "Read Text File", other=filepath)
|
||||
source_id,
|
||||
OpType.ReadTextFile,
|
||||
processor.ReadTextFile,
|
||||
"Read Text File",
|
||||
other=filepath)
|
||||
return source_stream
|
||||
|
||||
# Constructs and deploys the physical dataflow
|
||||
|
@ -312,24 +348,27 @@ class Environment(object):
|
|||
# upstream instances, some of the downstream instances will not be
|
||||
# used at all
|
||||
|
||||
# Each operator instance is implemented as a Ray actor
|
||||
# Actors are deployed in topological order, as we traverse the
|
||||
# logical dataflow from sources to sinks. At each step, data
|
||||
# producers wait for acknowledge from consumers before starting
|
||||
# generating data.
|
||||
upstream_channels = {}
|
||||
for node in nx.topological_sort(self.logical_topo):
|
||||
operator = self.operators[node]
|
||||
# Generate downstream data channels
|
||||
downstream_channels = self._generate_channels(operator)
|
||||
# Instantiate Ray actors
|
||||
handles = self.__generate_actors(operator, upstream_channels,
|
||||
downstream_channels)
|
||||
if handles:
|
||||
self.actor_handles.extend(handles)
|
||||
upstream_channels.update(downstream_channels)
|
||||
logger.debug("Running...")
|
||||
return self.actor_handles
|
||||
self.execution_graph = ExecutionGraph(self)
|
||||
self.execution_graph.build_graph()
|
||||
logger.info("init...")
|
||||
# init
|
||||
init_waits = []
|
||||
for actor_handle in self.execution_graph.actor_handles:
|
||||
init_waits.append(actor_handle.init.remote(pickle.dumps(self)))
|
||||
for wait in init_waits:
|
||||
assert ray.get(wait) is True
|
||||
logger.info("running...")
|
||||
# start
|
||||
exec_handles = []
|
||||
for actor_handle in self.execution_graph.actor_handles:
|
||||
exec_handles.append(actor_handle.start.remote())
|
||||
|
||||
return exec_handles
|
||||
|
||||
def wait_finish(self):
|
||||
for actor_handle in self.execution_graph.actor_handles:
|
||||
while not ray.get(actor_handle.is_finished.remote()):
|
||||
time.sleep(1)
|
||||
|
||||
# Prints the logical dataflow graph
|
||||
def print_logical_graph(self):
|
||||
|
@ -349,25 +388,6 @@ class Environment(object):
|
|||
for downstream_node in downstream_neighbors:
|
||||
self.operators[downstream_node].print()
|
||||
|
||||
# Prints the physical dataflow graph
|
||||
def print_physical_graph(self):
|
||||
logger.info("===================================")
|
||||
logger.info("======Physical Dataflow Graph======")
|
||||
logger.info("===================================")
|
||||
# Print all data channels between operator instances
|
||||
log = "(Source Operator ID,Source Operator Name,Source Instance ID)"
|
||||
log += " --> "
|
||||
log += "(Destination Operator ID,Destination Operator Name,"
|
||||
log += "Destination Instance ID)"
|
||||
logger.info(log)
|
||||
for src_actor_id, dst_actor_id in self.physical_topo.edges:
|
||||
src_operator_id, src_instance_id = src_actor_id
|
||||
dst_operator_id, dst_instance_id = dst_actor_id
|
||||
logger.info("({},{},{}) --> ({},{},{})".format(
|
||||
src_operator_id, self.operators[src_operator_id].name,
|
||||
src_instance_id, dst_operator_id,
|
||||
self.operators[dst_operator_id].name, dst_instance_id))
|
||||
|
||||
|
||||
# TODO (john): We also need KeyedDataStream and WindowedDataStream as
|
||||
# subclasses of DataStream to prevent ill-defined logical dataflows
|
||||
|
@ -389,14 +409,16 @@ class DataStream(object):
|
|||
is_partitioned (bool): Denotes if there is a partitioning strategy
|
||||
(e.g. shuffle) for the stream or not (default strategy: Forward).
|
||||
"""
|
||||
stream_id_counter = 0
|
||||
|
||||
def __init__(self,
|
||||
environment,
|
||||
source_id=None,
|
||||
dest_id=None,
|
||||
is_partitioned=False):
|
||||
self.id = _generate_uuid()
|
||||
self.env = environment
|
||||
self.id = DataStream.stream_id_counter
|
||||
DataStream.stream_id_counter += 1
|
||||
self.src_operator_id = source_id
|
||||
self.dst_operator_id = dest_id
|
||||
# True if a partitioning strategy for this stream exists,
|
||||
|
@ -448,17 +470,17 @@ class DataStream(object):
|
|||
src_operator = self.env.operators[self.src_operator_id]
|
||||
if self.is_partitioned is True:
|
||||
partitioning, _ = src_operator._get_partition_strategy(self.id)
|
||||
src_operator._set_partition_strategy(_generate_uuid(),
|
||||
partitioning, operator.id)
|
||||
src_operator._set_partition_strategy(self.id, partitioning,
|
||||
operator.id)
|
||||
elif src_operator.type == OpType.KeyBy:
|
||||
# Set the output partitioning strategy to shuffle by key
|
||||
partitioning = PScheme(PStrategy.ShuffleByKey)
|
||||
src_operator._set_partition_strategy(_generate_uuid(),
|
||||
partitioning, operator.id)
|
||||
src_operator._set_partition_strategy(self.id, partitioning,
|
||||
operator.id)
|
||||
else: # No partitioning strategy has been defined - set default
|
||||
partitioning = PScheme(PStrategy.Forward)
|
||||
src_operator._set_partition_strategy(_generate_uuid(),
|
||||
partitioning, operator.id)
|
||||
src_operator._set_partition_strategy(self.id, partitioning,
|
||||
operator.id)
|
||||
return self.__expand()
|
||||
|
||||
# Sets the level of parallelism for an operator, i.e. its total
|
||||
|
@ -525,8 +547,9 @@ class DataStream(object):
|
|||
map_fn (function): The user-defined logic of the map.
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.Map,
|
||||
processor.Map,
|
||||
name,
|
||||
map_fn,
|
||||
num_instances=self.env.config.parallelism)
|
||||
|
@ -541,8 +564,9 @@ class DataStream(object):
|
|||
(e.g. split()).
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.FlatMap,
|
||||
processor.FlatMap,
|
||||
"FlatMap",
|
||||
flatmap_fn,
|
||||
num_instances=self.env.config.parallelism)
|
||||
|
@ -558,8 +582,9 @@ class DataStream(object):
|
|||
(assuming tuple records).
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.KeyBy,
|
||||
processor.KeyBy,
|
||||
"KeyBy",
|
||||
other=key_selector,
|
||||
num_instances=self.env.config.parallelism)
|
||||
|
@ -574,8 +599,9 @@ class DataStream(object):
|
|||
(assuming tuple records).
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.Reduce,
|
||||
processor.Reduce,
|
||||
"Sum",
|
||||
reduce_fn,
|
||||
num_instances=self.env.config.parallelism)
|
||||
|
@ -590,8 +616,9 @@ class DataStream(object):
|
|||
(assuming tuple records).
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.Sum,
|
||||
processor.Reduce,
|
||||
"Sum",
|
||||
_sum,
|
||||
other=attribute_selector,
|
||||
|
@ -608,13 +635,7 @@ class DataStream(object):
|
|||
Attributes:
|
||||
window_width_ms (int): The length of the window in ms.
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
OpType.TimeWindow,
|
||||
"TimeWindow",
|
||||
num_instances=self.env.config.parallelism,
|
||||
other=window_width_ms)
|
||||
return self.__register(op)
|
||||
raise Exception("time_window is unsupported")
|
||||
|
||||
# Registers filter operator to the environment
|
||||
def filter(self, filter_fn):
|
||||
|
@ -624,8 +645,9 @@ class DataStream(object):
|
|||
filter_fn (function): The user-defined filter function.
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.Filter,
|
||||
processor.Filter,
|
||||
"Filter",
|
||||
filter_fn,
|
||||
num_instances=self.env.config.parallelism)
|
||||
|
@ -634,8 +656,9 @@ class DataStream(object):
|
|||
# TODO (john): Registers window join operator to the environment
|
||||
def window_join(self, other_stream, join_attribute, window_width):
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.WindowJoin,
|
||||
processor.WindowJoin,
|
||||
"WindowJoin",
|
||||
num_instances=self.env.config.parallelism)
|
||||
return self.__register(op)
|
||||
|
@ -648,8 +671,9 @@ class DataStream(object):
|
|||
inspect_logic (function): The user-defined inspect function.
|
||||
"""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.Inspect,
|
||||
processor.Inspect,
|
||||
"Inspect",
|
||||
inspect_logic,
|
||||
num_instances=self.env.config.parallelism)
|
||||
|
@ -661,8 +685,9 @@ class DataStream(object):
|
|||
def sink(self):
|
||||
"""Closes the stream with a sink operator."""
|
||||
op = Operator(
|
||||
_generate_uuid(),
|
||||
self.env.gen_operator_id(),
|
||||
OpType.Sink,
|
||||
processor.Sink,
|
||||
"Sink",
|
||||
num_instances=self.env.config.parallelism)
|
||||
return self.__register(op)
|
0
streaming/python/tests/__init__.py
Normal file
127
streaming/python/tests/test_direct_transfer.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
import pickle
|
||||
import threading
|
||||
import time
|
||||
|
||||
import ray
|
||||
import ray.streaming._streaming as _streaming
|
||||
import ray.streaming.runtime.transfer as transfer
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
from ray.streaming.config import Config
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Worker:
|
||||
def __init__(self):
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
writer_async_func = FunctionDescriptor(
|
||||
__name__, self.on_writer_message.__name__, self.__class__.__name__)
|
||||
writer_sync_func = FunctionDescriptor(
|
||||
__name__, self.on_writer_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.writer_client = _streaming.WriterClient(
|
||||
core_worker, writer_async_func, writer_sync_func)
|
||||
reader_async_func = FunctionDescriptor(
|
||||
__name__, self.on_reader_message.__name__, self.__class__.__name__)
|
||||
reader_sync_func = FunctionDescriptor(
|
||||
__name__, self.on_reader_message_sync.__name__,
|
||||
self.__class__.__name__)
|
||||
self.reader_client = _streaming.ReaderClient(
|
||||
core_worker, reader_async_func, reader_sync_func)
|
||||
self.writer = None
|
||||
self.output_channel_id = None
|
||||
self.reader = None
|
||||
|
||||
def init_writer(self, output_channel, reader_actor):
|
||||
conf = {
|
||||
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
|
||||
.current_driver_id,
|
||||
Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL
|
||||
}
|
||||
self.writer = transfer.DataWriter([output_channel],
|
||||
[pickle.loads(reader_actor)], conf)
|
||||
self.output_channel_id = transfer.ChannelID(output_channel)
|
||||
|
||||
def init_reader(self, input_channel, writer_actor):
|
||||
conf = {
|
||||
Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
|
||||
.current_driver_id,
|
||||
Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL
|
||||
}
|
||||
self.reader = transfer.DataReader([input_channel],
|
||||
[pickle.loads(writer_actor)], conf)
|
||||
|
||||
def start_write(self, msg_nums):
|
||||
self.t = threading.Thread(
|
||||
target=self.run_writer, args=[msg_nums], daemon=True)
|
||||
self.t.start()
|
||||
|
||||
def run_writer(self, msg_nums):
|
||||
for i in range(msg_nums):
|
||||
self.writer.write(self.output_channel_id, pickle.dumps(i))
|
||||
print("WriterWorker done.")
|
||||
|
||||
def start_read(self, msg_nums):
|
||||
self.t = threading.Thread(
|
||||
target=self.run_reader, args=[msg_nums], daemon=True)
|
||||
self.t.start()
|
||||
|
||||
def run_reader(self, msg_nums):
|
||||
count = 0
|
||||
msg = None
|
||||
while count != msg_nums:
|
||||
item = self.reader.read(100)
|
||||
if item is None:
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
msg = pickle.loads(item.body())
|
||||
count += 1
|
||||
assert msg == msg_nums - 1
|
||||
print("ReaderWorker done.")
|
||||
|
||||
def is_finished(self):
|
||||
return not self.t.is_alive()
|
||||
|
||||
def on_reader_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
self.reader_client.on_reader_message(buffer)
|
||||
|
||||
def on_reader_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
if self.reader_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
result = self.reader_client.on_reader_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
def on_writer_message(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
self.writer_client.on_writer_message(buffer)
|
||||
|
||||
def on_writer_message_sync(self, buffer: bytes):
|
||||
"""used in direct call mode"""
|
||||
if self.writer_client is None:
|
||||
return b" " * 4 # special flag to indicate this actor not ready
|
||||
result = self.writer_client.on_writer_message_sync(buffer)
|
||||
return result.to_pybytes()
|
||||
|
||||
|
||||
def test_queue():
|
||||
ray.init()
|
||||
writer = Worker._remote(is_direct_call=True)
|
||||
reader = Worker._remote(is_direct_call=True)
|
||||
channel_id_str = transfer.ChannelID.gen_random_id()
|
||||
inits = [
|
||||
writer.init_writer.remote(channel_id_str, pickle.dumps(reader)),
|
||||
reader.init_reader.remote(channel_id_str, pickle.dumps(writer))
|
||||
]
|
||||
ray.get(inits)
|
||||
msg_nums = 1000
|
||||
print("start read/write")
|
||||
reader.start_read.remote(msg_nums)
|
||||
writer.start_write.remote(msg_nums)
|
||||
while not ray.get(reader.is_finished.remote()):
|
||||
time.sleep(0.1)
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_queue()
|
|
@ -2,8 +2,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.experimental.streaming.streaming import Environment
|
||||
from ray.experimental.streaming.operator import OpType, PStrategy
|
||||
from ray.streaming.streaming import Environment, ExecutionGraph
|
||||
from ray.streaming.operator import OpType, PStrategy
|
||||
|
||||
|
||||
def test_parallelism():
|
||||
|
@ -169,18 +169,20 @@ def _test_channels(environment, expected_channels):
|
|||
if operator.type == OpType.Map:
|
||||
map_id = id
|
||||
# Collect channels
|
||||
environment.execution_graph = ExecutionGraph(environment)
|
||||
environment.execution_graph.build_channels()
|
||||
channels_per_destination = []
|
||||
for operator in environment.operators.values():
|
||||
channels_per_destination.append(
|
||||
environment._generate_channels(operator))
|
||||
environment.execution_graph._generate_channels(operator))
|
||||
# Check actual connectivity
|
||||
actual = []
|
||||
for destination in channels_per_destination:
|
||||
for channels in destination.values():
|
||||
for channel in channels:
|
||||
src_instance_id = channel.src_instance_id
|
||||
dst_instance_id = channel.dst_instance_id
|
||||
connection = (src_instance_id, dst_instance_id)
|
||||
src_instance_index = channel.src_instance_index
|
||||
dst_instance_index = channel.dst_instance_index
|
||||
connection = (src_instance_index, dst_instance_index)
|
||||
assert channel.dst_operator_id == map_id, (
|
||||
channel.dst_operator_id, map_id)
|
||||
actual.append(connection)
|
||||
|
@ -205,6 +207,4 @@ def test_wordcount():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
test_channel_generation()
|
20
streaming/python/tests/test_word_count.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import ray
|
||||
from ray.streaming.config import Config
|
||||
from ray.streaming.streaming import Environment, Conf
|
||||
|
||||
|
||||
def test_word_count():
|
||||
ray.init()
|
||||
env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL))
|
||||
env.read_text_file(__file__) \
|
||||
.set_parallelism(1) \
|
||||
.filter(lambda x: "word" in x) \
|
||||
.inspect(lambda x: print("result", x))
|
||||
env_handle = env.execute()
|
||||
ray.get(env_handle) # Stay alive until execution finishes
|
||||
env.wait_finish()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_word_count()
|
274
streaming/src/channel.cc
Normal file
|
@ -0,0 +1,274 @@
|
|||
#include "channel.h"
|
||||
#include <unordered_map>
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
ProducerChannel::ProducerChannel(std::shared_ptr<Config> &transfer_config,
|
||||
ProducerChannelInfo &p_channel_info)
|
||||
: transfer_config_(transfer_config), channel_info(p_channel_info) {}
|
||||
|
||||
ConsumerChannel::ConsumerChannel(std::shared_ptr<Config> &transfer_config,
|
||||
ConsumerChannelInfo &c_channel_info)
|
||||
: transfer_config_(transfer_config), channel_info(c_channel_info) {}
|
||||
|
||||
StreamingQueueProducer::StreamingQueueProducer(std::shared_ptr<Config> &transfer_config,
|
||||
ProducerChannelInfo &p_channel_info)
|
||||
: ProducerChannel(transfer_config, p_channel_info) {
|
||||
STREAMING_LOG(INFO) << "Producer Init";
|
||||
}
|
||||
|
||||
StreamingQueueProducer::~StreamingQueueProducer() {
|
||||
STREAMING_LOG(INFO) << "Producer Destory";
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueProducer::CreateTransferChannel() {
|
||||
CreateQueue();
|
||||
|
||||
uint64_t queue_last_seq_id = 0;
|
||||
uint64_t last_message_id_in_queue = 0;
|
||||
|
||||
if (!last_message_id_in_queue) {
|
||||
if (last_message_id_in_queue < channel_info.current_message_id) {
|
||||
STREAMING_LOG(WARNING) << "last message id in queue : " << last_message_id_in_queue
|
||||
<< " is less than message checkpoint loaded id : "
|
||||
<< channel_info.current_message_id
|
||||
<< ", an old queue object " << channel_info.channel_id
|
||||
<< " was fond in store";
|
||||
}
|
||||
last_message_id_in_queue = channel_info.current_message_id;
|
||||
}
|
||||
if (queue_last_seq_id == static_cast<uint64_t>(-1)) {
|
||||
queue_last_seq_id = 0;
|
||||
}
|
||||
channel_info.current_seq_id = queue_last_seq_id;
|
||||
|
||||
STREAMING_LOG(WARNING) << "existing last message id => " << last_message_id_in_queue
|
||||
<< ", message id in channel => "
|
||||
<< channel_info.current_message_id << ", queue last seq id => "
|
||||
<< queue_last_seq_id;
|
||||
|
||||
channel_info.message_last_commit_id = last_message_id_in_queue;
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueProducer::CreateQueue() {
|
||||
STREAMING_LOG(INFO) << "CreateQueue qid: " << channel_info.channel_id
|
||||
<< " data_size: " << channel_info.queue_size;
|
||||
auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService();
|
||||
if (upstream_handler->UpstreamQueueExists(channel_info.channel_id)) {
|
||||
RAY_LOG(INFO) << "StreamingQueueWriter::CreateQueue duplicate!!!";
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
upstream_handler->SetPeerActorID(channel_info.channel_id, channel_info.actor_id);
|
||||
queue_ = upstream_handler->CreateUpstreamQueue(
|
||||
channel_info.channel_id, channel_info.actor_id, channel_info.queue_size);
|
||||
STREAMING_CHECK(queue_ != nullptr);
|
||||
|
||||
std::vector<ObjectID> queue_ids, failed_queues;
|
||||
queue_ids.push_back(channel_info.channel_id);
|
||||
upstream_handler->WaitQueues(queue_ids, 10 * 1000, failed_queues);
|
||||
|
||||
STREAMING_LOG(INFO) << "q id => " << channel_info.channel_id << ", queue size => "
|
||||
<< channel_info.queue_size;
|
||||
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueProducer::DestroyTransferChannel() {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueProducer::ClearTransferCheckpoint(
|
||||
uint64_t checkpoint_id, uint64_t checkpoint_offset) {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t channel_offset) {
|
||||
queue_->SetQueueEvictionLimit(channel_offset);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data,
|
||||
uint32_t data_size) {
|
||||
Status status =
|
||||
PushQueueItem(channel_info.current_seq_id + 1, data, data_size, current_time_ms());
|
||||
|
||||
if (status.code() != StatusCode::OK) {
|
||||
STREAMING_LOG(DEBUG) << channel_info.channel_id << " => Queue is full"
|
||||
<< " meesage => " << status.message();
|
||||
|
||||
// Assume that only status OutOfMemory and OK are acceptable.
|
||||
// OutOfMemory means queue is full at that moment.
|
||||
STREAMING_CHECK(status.code() == StatusCode::OutOfMemory)
|
||||
<< "status => " << status.message()
|
||||
<< ", perhaps data block is so large that it can't be stored in"
|
||||
<< ", data block size => " << data_size;
|
||||
|
||||
return StreamingStatus::FullChannel;
|
||||
}
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data,
|
||||
uint32_t data_size, uint64_t timestamp) {
|
||||
STREAMING_LOG(INFO) << "StreamingQueueProducer::PushQueueItem:"
|
||||
<< " qid: " << channel_info.channel_id << " seq_id: " << seq_id
|
||||
<< " data_size: " << data_size;
|
||||
Status status = queue_->Push(seq_id, data, data_size, timestamp, false);
|
||||
if (status.IsOutOfMemory()) {
|
||||
status = queue_->TryEvictItems();
|
||||
if (!status.ok()) {
|
||||
STREAMING_LOG(INFO) << "Evict fail.";
|
||||
return status;
|
||||
}
|
||||
|
||||
status = queue_->Push(seq_id, data, data_size, timestamp, false);
|
||||
}
|
||||
|
||||
queue_->Send();
|
||||
return status;
|
||||
}
|
||||
|
||||
StreamingQueueConsumer::StreamingQueueConsumer(std::shared_ptr<Config> &transfer_config,
|
||||
ConsumerChannelInfo &c_channel_info)
|
||||
: ConsumerChannel(transfer_config, c_channel_info) {
|
||||
STREAMING_LOG(INFO) << "Consumer Init";
|
||||
}
|
||||
|
||||
StreamingQueueConsumer::~StreamingQueueConsumer() {
|
||||
STREAMING_LOG(INFO) << "Consumer Destroy";
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueConsumer::CreateTransferChannel() {
|
||||
auto downstream_handler = ray::streaming::DownstreamQueueMessageHandler::GetService();
|
||||
STREAMING_LOG(INFO) << "GetQueue qid: " << channel_info.channel_id
|
||||
<< " start_seq_id: " << channel_info.current_seq_id + 1;
|
||||
if (downstream_handler->DownstreamQueueExists(channel_info.channel_id)) {
|
||||
RAY_LOG(INFO) << "StreamingQueueReader::GetQueue duplicate!!!";
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
downstream_handler->SetPeerActorID(channel_info.channel_id, channel_info.actor_id);
|
||||
STREAMING_LOG(INFO) << "Create ReaderQueue " << channel_info.channel_id
|
||||
<< " pull from start_seq_id: " << channel_info.current_seq_id + 1;
|
||||
queue_ = downstream_handler->CreateDownstreamQueue(channel_info.channel_id,
|
||||
channel_info.actor_id);
|
||||
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueConsumer::DestroyTransferChannel() {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueConsumer::ClearTransferCheckpoint(
|
||||
uint64_t checkpoint_id, uint64_t checkpoint_offset) {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint64_t &offset_id,
|
||||
uint8_t *&data,
|
||||
uint32_t &data_size,
|
||||
uint32_t timeout) {
|
||||
STREAMING_LOG(INFO) << "GetQueueItem qid: " << channel_info.channel_id;
|
||||
STREAMING_CHECK(queue_ != nullptr);
|
||||
QueueItem item = queue_->PopPendingBlockTimeout(timeout * 1000);
|
||||
if (item.SeqId() == QUEUE_INVALID_SEQ_ID) {
|
||||
STREAMING_LOG(INFO) << "GetQueueItem timeout.";
|
||||
data = nullptr;
|
||||
data_size = 0;
|
||||
offset_id = QUEUE_INVALID_SEQ_ID;
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
data = item.Buffer()->Data();
|
||||
offset_id = item.SeqId();
|
||||
data_size = item.Buffer()->Size();
|
||||
|
||||
STREAMING_LOG(DEBUG) << "GetQueueItem qid: " << channel_info.channel_id
|
||||
<< " seq_id: " << offset_id << " msg_id: " << item.MaxMsgId()
|
||||
<< " data_size: " << data_size;
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus StreamingQueueConsumer::NotifyChannelConsumed(uint64_t offset_id) {
|
||||
STREAMING_CHECK(queue_ != nullptr);
|
||||
queue_->OnConsumed(offset_id);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
// For mock queue transfer
|
||||
struct MockQueueItem {
|
||||
uint64_t seq_id;
|
||||
uint32_t data_size;
|
||||
std::shared_ptr<uint8_t> data;
|
||||
};
|
||||
|
||||
struct MockQueue {
|
||||
std::unordered_map<ObjectID, std::shared_ptr<AbstractRingBufferImpl<MockQueueItem>>>
|
||||
message_buffer_;
|
||||
std::unordered_map<ObjectID, std::shared_ptr<AbstractRingBufferImpl<MockQueueItem>>>
|
||||
consumed_buffer_;
|
||||
};
|
||||
static MockQueue mock_queue;
|
||||
|
||||
StreamingStatus MockProducer::CreateTransferChannel() {
|
||||
mock_queue.message_buffer_[channel_info.channel_id] =
|
||||
std::make_shared<RingBufferImplThreadSafe<MockQueueItem>>(500);
|
||||
mock_queue.consumed_buffer_[channel_info.channel_id] =
|
||||
std::make_shared<RingBufferImplThreadSafe<MockQueueItem>>(500);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus MockProducer::DestroyTransferChannel() {
|
||||
mock_queue.message_buffer_.erase(channel_info.channel_id);
|
||||
mock_queue.consumed_buffer_.erase(channel_info.channel_id);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) {
|
||||
auto &ring_buffer = mock_queue.message_buffer_[channel_info.channel_id];
|
||||
if (ring_buffer->Full()) {
|
||||
return StreamingStatus::OutOfMemory;
|
||||
}
|
||||
MockQueueItem item;
|
||||
item.seq_id = channel_info.current_seq_id + 1;
|
||||
item.data.reset(new uint8_t[data_size]);
|
||||
item.data_size = data_size;
|
||||
std::memcpy(item.data.get(), data, data_size);
|
||||
ring_buffer->Push(item);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
|
||||
uint32_t &data_size,
|
||||
uint32_t timeout) {
|
||||
auto &channel_id = channel_info.channel_id;
|
||||
if (mock_queue.message_buffer_.find(channel_id) == mock_queue.message_buffer_.end()) {
|
||||
return StreamingStatus::NoSuchItem;
|
||||
}
|
||||
|
||||
if (mock_queue.message_buffer_[channel_id]->Empty()) {
|
||||
return StreamingStatus::NoSuchItem;
|
||||
}
|
||||
MockQueueItem item = mock_queue.message_buffer_[channel_id]->Front();
|
||||
mock_queue.message_buffer_[channel_id]->Pop();
|
||||
mock_queue.consumed_buffer_[channel_id]->Push(item);
|
||||
offset_id = item.seq_id;
|
||||
data = item.data.get();
|
||||
data_size = item.data_size;
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) {
|
||||
auto &channel_id = channel_info.channel_id;
|
||||
auto &ring_buffer = mock_queue.consumed_buffer_[channel_id];
|
||||
while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) {
|
||||
ring_buffer->Pop();
|
||||
}
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
176
streaming/src/channel.h
Normal file
176
streaming/src/channel.h
Normal file
|
@ -0,0 +1,176 @@
|
|||
#ifndef RAY_CHANNEL_H
|
||||
#define RAY_CHANNEL_H
|
||||
|
||||
#include "config/streaming_config.h"
|
||||
#include "queue/queue_handler.h"
|
||||
#include "ring_buffer.h"
|
||||
#include "status.h"
|
||||
#include "util/streaming_util.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
struct StreamingQueueInfo {
|
||||
uint64_t first_seq_id = 0;
|
||||
uint64_t last_seq_id = 0;
|
||||
uint64_t target_seq_id = 0;
|
||||
uint64_t consumed_seq_id = 0;
|
||||
};
|
||||
|
||||
/// ProducerChannelInfo and ConsumerChannelInfo contain channel information and
|
||||
/// their metrics that help us debug or show important messages in logging.
|
||||
struct ProducerChannelInfo {
|
||||
ObjectID channel_id;
|
||||
StreamingRingBufferPtr writer_ring_buffer;
|
||||
uint64_t current_message_id;
|
||||
uint64_t current_seq_id;
|
||||
uint64_t message_last_commit_id;
|
||||
StreamingQueueInfo queue_info;
|
||||
uint32_t queue_size;
|
||||
int64_t message_pass_by_ts;
|
||||
ActorID actor_id;
|
||||
};
|
||||
|
||||
struct ConsumerChannelInfo {
|
||||
ObjectID channel_id;
|
||||
uint64_t current_message_id;
|
||||
uint64_t current_seq_id;
|
||||
uint64_t barrier_id;
|
||||
uint64_t partial_barrier_id;
|
||||
|
||||
StreamingQueueInfo queue_info;
|
||||
|
||||
uint64_t last_queue_item_delay;
|
||||
uint64_t last_queue_item_latency;
|
||||
uint64_t last_queue_target_diff;
|
||||
uint64_t get_queue_item_times;
|
||||
ActorID actor_id;
|
||||
};
|
||||
|
||||
/// Two types of channel are presented:
|
||||
/// * ProducerChannel supports all writing operations for the upper level.
|
||||
/// * ConsumerChannel is for all reader operations.
|
||||
/// They share similar interfaces:
|
||||
/// * ClearTransferCheckpoint (it's empty and unsupported now; we will add an
|
||||
/// implementation in the next PR)
|
||||
/// * NotifyChannelConsumed (notify the channel owner which range of data can
|
||||
/// be released to avoid running out of memory)
|
||||
/// but they differ in the read/write functions (named ProduceItemToChannel and
|
||||
/// ConsumeItemFromChannel).
|
||||
class ProducerChannel {
|
||||
public:
|
||||
explicit ProducerChannel(std::shared_ptr<Config> &transfer_config,
|
||||
ProducerChannelInfo &p_channel_info);
|
||||
virtual ~ProducerChannel() = default;
|
||||
virtual StreamingStatus CreateTransferChannel() = 0;
|
||||
virtual StreamingStatus DestroyTransferChannel() = 0;
|
||||
virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
|
||||
uint64_t checkpoint_offset) = 0;
|
||||
virtual StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) = 0;
|
||||
virtual StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Config> transfer_config_;
|
||||
ProducerChannelInfo &channel_info;
|
||||
};
|
||||
|
||||
class ConsumerChannel {
|
||||
public:
|
||||
explicit ConsumerChannel(std::shared_ptr<Config> &transfer_config,
|
||||
ConsumerChannelInfo &c_channel_info);
|
||||
virtual ~ConsumerChannel() = default;
|
||||
virtual StreamingStatus CreateTransferChannel() = 0;
|
||||
virtual StreamingStatus DestroyTransferChannel() = 0;
|
||||
virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
|
||||
uint64_t checkpoint_offset) = 0;
|
||||
virtual StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
|
||||
uint32_t &data_size,
|
||||
uint32_t timeout) = 0;
|
||||
virtual StreamingStatus NotifyChannelConsumed(uint64_t offset_id) = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Config> transfer_config_;
|
||||
ConsumerChannelInfo &channel_info;
|
||||
};
|
||||
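/// Editor's illustrative sketch (not part of the original change): a producer
/// channel is typically driven as follows, assuming `transfer_config` and
/// `channel_info` have been prepared by the writer, and `bundle_bytes`,
/// `bundle_size` and `consumed_offset` are placeholder names:
///
///   StreamingQueueProducer producer(transfer_config, channel_info);
///   producer.CreateTransferChannel();
///   producer.ProduceItemToChannel(bundle_bytes, bundle_size);
///   // Once the consumer reports progress, release data up to that offset:
///   producer.NotifyChannelConsumed(consumed_offset);
///
/// The consumer side mirrors this with CreateTransferChannel,
/// ConsumeItemFromChannel and NotifyChannelConsumed.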
|
||||
class StreamingQueueProducer : public ProducerChannel {
|
||||
public:
|
||||
explicit StreamingQueueProducer(std::shared_ptr<Config> &transfer_config,
|
||||
ProducerChannelInfo &p_channel_info);
|
||||
~StreamingQueueProducer() override;
|
||||
StreamingStatus CreateTransferChannel() override;
|
||||
StreamingStatus DestroyTransferChannel() override;
|
||||
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
|
||||
uint64_t checkpoint_offset) override;
|
||||
StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override;
|
||||
StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override;
|
||||
|
||||
private:
|
||||
StreamingStatus CreateQueue();
|
||||
Status PushQueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size,
|
||||
uint64_t timestamp);
|
||||
|
||||
private:
|
||||
std::shared_ptr<WriterQueue> queue_;
|
||||
};
|
||||
|
||||
class StreamingQueueConsumer : public ConsumerChannel {
|
||||
public:
|
||||
explicit StreamingQueueConsumer(std::shared_ptr<Config> &transfer_config,
|
||||
ConsumerChannelInfo &c_channel_info);
|
||||
~StreamingQueueConsumer() override;
|
||||
StreamingStatus CreateTransferChannel() override;
|
||||
StreamingStatus DestroyTransferChannel() override;
|
||||
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
|
||||
uint64_t checkpoint_offset) override;
|
||||
StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
|
||||
uint32_t &data_size, uint32_t timeout) override;
|
||||
StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<ReaderQueue> queue_;
|
||||
};
|
||||
|
||||
/// MockProducer and MockConsumer are independent implementations of channels that
|
||||
/// provide a very simple in-memory channel for unit tests or integration tests.
|
||||
class MockProducer : public ProducerChannel {
|
||||
public:
|
||||
explicit MockProducer(std::shared_ptr<Config> &transfer_config,
|
||||
ProducerChannelInfo &channel_info)
|
||||
: ProducerChannel(transfer_config, channel_info){};
|
||||
StreamingStatus CreateTransferChannel() override;
|
||||
|
||||
StreamingStatus DestroyTransferChannel() override;
|
||||
|
||||
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
|
||||
uint64_t checkpoint_offset) override {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override;
|
||||
|
||||
StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) override {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class MockConsumer : public ConsumerChannel {
|
||||
public:
|
||||
explicit MockConsumer(std::shared_ptr<Config> &transfer_config,
|
||||
ConsumerChannelInfo &c_channel_info)
|
||||
: ConsumerChannel(transfer_config, c_channel_info){};
|
||||
StreamingStatus CreateTransferChannel() override { return StreamingStatus::OK; }
|
||||
StreamingStatus DestroyTransferChannel() override { return StreamingStatus::OK; }
|
||||
StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id,
|
||||
uint64_t checkpoint_offset) override {
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
|
||||
uint32_t &data_size, uint32_t timeout) override;
|
||||
StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override;
|
||||
};
|
||||
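/// Editor's note (not part of the original change): DataReader::InitChannel and
/// DataWriter::InitChannel pick MockConsumer/MockProducer instead of the
/// StreamingQueue-backed channels when RuntimeContext::IsMockTest() returns true,
/// so unit tests can exercise the reader/writer without a real transport.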
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_CHANNEL_H
|
89
streaming/src/config/streaming_config.cc
Normal file
89
streaming/src/config/streaming_config.cc
Normal file
|
@ -0,0 +1,89 @@
|
|||
#include <unistd.h>
|
||||
|
||||
#include "streaming_config.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
uint64_t StreamingConfig::TIME_WAIT_UINT = 1;
|
||||
uint32_t StreamingConfig::DEFAULT_RING_BUFFER_CAPACITY = 500;
|
||||
uint32_t StreamingConfig::DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL = 20;
|
||||
// Time to force clean if barrier in queue, default 0ms
|
||||
const uint32_t StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE = 2048;
|
||||
|
||||
void StreamingConfig::FromProto(const uint8_t *data, uint32_t size) {
|
||||
proto::StreamingConfig config;
|
||||
STREAMING_CHECK(config.ParseFromArray(data, size)) << "Parse streaming conf failed";
|
||||
if (!config.job_name().empty()) {
|
||||
SetJobName(config.job_name());
|
||||
}
|
||||
if (!config.task_job_id().empty()) {
|
||||
STREAMING_CHECK(config.task_job_id().size() == 2 * JobID::Size());
|
||||
SetTaskJobId(config.task_job_id());
|
||||
}
|
||||
if (!config.worker_name().empty()) {
|
||||
SetWorkerName(config.worker_name());
|
||||
}
|
||||
if (!config.op_name().empty()) {
|
||||
SetOpName(config.op_name());
|
||||
}
|
||||
if (config.role() != proto::OperatorType::UNKNOWN) {
|
||||
SetOperatorType(config.role());
|
||||
}
|
||||
if (config.ring_buffer_capacity() != 0) {
|
||||
SetRingBufferCapacity(config.ring_buffer_capacity());
|
||||
}
|
||||
if (config.empty_message_interval() != 0) {
|
||||
SetEmptyMessageTimeInterval(config.empty_message_interval());
|
||||
}
|
||||
}
|
||||
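// Editor's illustrative sketch (not part of the original change): FromProto expects
// a serialized streaming::proto::StreamingConfig. Assuming the usual
// protobuf-generated setters for the fields read above, a caller might do:
//
//   proto::StreamingConfig pb_config;
//   pb_config.set_job_name("wordcount");
//   pb_config.set_ring_buffer_capacity(1024);
//   std::string bytes = pb_config.SerializeAsString();
//
//   StreamingConfig config;
//   config.FromProto(reinterpret_cast<const uint8_t *>(bytes.data()), bytes.size());
//
// Values above MESSAGE_BUNDLE_MAX_SIZE are clamped by SetRingBufferCapacity below.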
|
||||
uint32_t StreamingConfig::GetRingBufferCapacity() const { return ring_buffer_capacity_; }
|
||||
|
||||
void StreamingConfig::SetRingBufferCapacity(uint32_t ring_buffer_capacity) {
|
||||
StreamingConfig::ring_buffer_capacity_ =
|
||||
std::min(ring_buffer_capacity, StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE);
|
||||
}
|
||||
|
||||
uint32_t StreamingConfig::GetEmptyMessageTimeInterval() const {
|
||||
return empty_message_time_interval_;
|
||||
}
|
||||
|
||||
void StreamingConfig::SetEmptyMessageTimeInterval(uint32_t empty_message_time_interval) {
|
||||
StreamingConfig::empty_message_time_interval_ = empty_message_time_interval;
|
||||
}
|
||||
|
||||
streaming::proto::OperatorType StreamingConfig::GetOperatorType() const {
|
||||
return operator_type_;
|
||||
}
|
||||
|
||||
void StreamingConfig::SetOperatorType(streaming::proto::OperatorType type) {
|
||||
StreamingConfig::operator_type_ = type;
|
||||
}
|
||||
|
||||
const std::string &StreamingConfig::GetJobName() const { return job_name_; }
|
||||
|
||||
void StreamingConfig::SetJobName(const std::string &job_name) {
|
||||
StreamingConfig::job_name_ = job_name;
|
||||
}
|
||||
|
||||
const std::string &StreamingConfig::GetOpName() const { return op_name_; }
|
||||
|
||||
void StreamingConfig::SetOpName(const std::string &op_name) {
|
||||
StreamingConfig::op_name_ = op_name;
|
||||
}
|
||||
|
||||
const std::string &StreamingConfig::GetWorkerName() const { return worker_name_; }
|
||||
void StreamingConfig::SetWorkerName(const std::string &worker_name) {
|
||||
StreamingConfig::worker_name_ = worker_name;
|
||||
}
|
||||
|
||||
const std::string &StreamingConfig::GetTaskJobId() const { return task_job_id_; }
|
||||
|
||||
void StreamingConfig::SetTaskJobId(const std::string &task_job_id) {
|
||||
StreamingConfig::task_job_id_ = task_job_id;
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
69
streaming/src/config/streaming_config.h
Normal file
69
streaming/src/config/streaming_config.h
Normal file
|
@ -0,0 +1,69 @@
|
|||
#ifndef RAY_STREAMING_CONFIG_H
|
||||
#define RAY_STREAMING_CONFIG_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "protobuf/streaming.pb.h"
|
||||
#include "ray/common/id.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
class StreamingConfig {
|
||||
public:
|
||||
static uint64_t TIME_WAIT_UINT;
|
||||
static uint32_t DEFAULT_RING_BUFFER_CAPACITY;
|
||||
static uint32_t DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL;
|
||||
static const uint32_t MESSAGE_BUNDLE_MAX_SIZE;
|
||||
|
||||
private:
|
||||
uint32_t ring_buffer_capacity_ = DEFAULT_RING_BUFFER_CAPACITY;
|
||||
|
||||
uint32_t empty_message_time_interval_ = DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL;
|
||||
|
||||
streaming::proto::OperatorType operator_type_ =
|
||||
streaming::proto::OperatorType::TRANSFORM;
|
||||
|
||||
std::string job_name_ = "DEFAULT_JOB_NAME";
|
||||
|
||||
std::string op_name_ = "DEFAULT_OP_NAME";
|
||||
|
||||
std::string worker_name_ = "DEFAULT_WORKER_NAME";
|
||||
|
||||
std::string task_job_id_ = "ffffffff";
|
||||
|
||||
public:
|
||||
void FromProto(const uint8_t *, uint32_t size);
|
||||
|
||||
const std::string &GetTaskJobId() const;
|
||||
|
||||
void SetTaskJobId(const std::string &task_job_id);
|
||||
|
||||
const std::string &GetWorkerName() const;
|
||||
|
||||
void SetWorkerName(const std::string &worker_name);
|
||||
|
||||
const std::string &GetOpName() const;
|
||||
|
||||
void SetOpName(const std::string &op_name);
|
||||
|
||||
uint32_t GetEmptyMessageTimeInterval() const;
|
||||
|
||||
void SetEmptyMessageTimeInterval(uint32_t empty_message_time_interval);
|
||||
|
||||
uint32_t GetRingBufferCapacity() const;
|
||||
|
||||
void SetRingBufferCapacity(uint32_t ring_buffer_capacity);
|
||||
|
||||
streaming::proto::OperatorType GetOperatorType() const;
|
||||
|
||||
void SetOperatorType(streaming::proto::OperatorType type);
|
||||
|
||||
const std::string &GetJobName() const;
|
||||
|
||||
void SetJobName(const std::string &job_name);
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif // RAY_STREAMING_CONFIG_H
|
297
streaming/src/data_reader.cc
Normal file
297
streaming/src/data_reader.cc
Normal file
|
@ -0,0 +1,297 @@
|
|||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
#include "ray/util/logging.h"
|
||||
#include "ray/util/util.h"
|
||||
|
||||
#include "data_reader.h"
|
||||
#include "message/message_bundle.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
const uint32_t DataReader::kReadItemTimeout = 1000;
|
||||
|
||||
void DataReader::Init(const std::vector<ObjectID> &input_ids,
|
||||
const std::vector<ActorID> &actor_ids,
|
||||
const std::vector<uint64_t> &queue_seq_ids,
|
||||
const std::vector<uint64_t> &streaming_msg_ids,
|
||||
int64_t timer_interval) {
|
||||
Init(input_ids, actor_ids, timer_interval);
|
||||
for (size_t i = 0; i < input_ids.size(); ++i) {
|
||||
auto &q_id = input_ids[i];
|
||||
channel_info_map_[q_id].current_seq_id = queue_seq_ids[i];
|
||||
channel_info_map_[q_id].current_message_id = streaming_msg_ids[i];
|
||||
}
|
||||
}
|
||||
|
||||
void DataReader::Init(const std::vector<ObjectID> &input_ids,
|
||||
const std::vector<ActorID> &actor_ids, int64_t timer_interval) {
|
||||
STREAMING_LOG(INFO) << input_ids.size() << " queues to init.";
|
||||
|
||||
transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, input_ids);
|
||||
|
||||
last_fetched_queue_item_ = nullptr;
|
||||
timer_interval_ = timer_interval;
|
||||
last_message_ts_ = 0;
|
||||
input_queue_ids_ = input_ids;
|
||||
last_message_latency_ = 0;
|
||||
last_bundle_unit_ = 0;
|
||||
|
||||
for (size_t i = 0; i < input_ids.size(); ++i) {
|
||||
ObjectID q_id = input_ids[i];
|
||||
STREAMING_LOG(INFO) << "[Reader] Init queue id: " << q_id;
|
||||
auto &channel_info = channel_info_map_[q_id];
|
||||
channel_info.channel_id = q_id;
|
||||
channel_info.actor_id = actor_ids[i];
|
||||
channel_info.last_queue_item_delay = 0;
|
||||
channel_info.last_queue_item_latency = 0;
|
||||
channel_info.last_queue_target_diff = 0;
|
||||
channel_info.get_queue_item_times = 0;
|
||||
}
|
||||
|
||||
/// Make the input id order stable.
|
||||
sort(input_queue_ids_.begin(), input_queue_ids_.end(),
|
||||
[](const ObjectID &a, const ObjectID &b) { return a.Hash() < b.Hash(); });
|
||||
std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(unready_queue_ids_));
|
||||
InitChannel();
|
||||
}
|
||||
|
||||
StreamingStatus DataReader::InitChannel() {
|
||||
STREAMING_LOG(INFO) << "[Reader] Getting queues. total queue num "
|
||||
<< input_queue_ids_.size() << ", unready queue num => "
|
||||
<< unready_queue_ids_.size();
|
||||
|
||||
for (const auto &input_channel : unready_queue_ids_) {
|
||||
auto &channel_info = channel_info_map_[input_channel];
|
||||
std::shared_ptr<ConsumerChannel> channel;
|
||||
if (runtime_context_->IsMockTest()) {
|
||||
channel = std::make_shared<MockConsumer>(transfer_config_, channel_info);
|
||||
} else {
|
||||
channel = std::make_shared<StreamingQueueConsumer>(transfer_config_, channel_info);
|
||||
}
|
||||
|
||||
channel_map_.emplace(input_channel, channel);
|
||||
StreamingStatus status = channel->CreateTransferChannel();
|
||||
if (StreamingStatus::OK != status) {
|
||||
STREAMING_LOG(ERROR) << "Initialize queue failed, id => " << input_channel;
|
||||
}
|
||||
}
|
||||
runtime_context_->SetRuntimeStatus(RuntimeStatus::Running);
|
||||
STREAMING_LOG(INFO) << "[Reader] Reader construction done!";
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataReader::InitChannelMerger() {
|
||||
STREAMING_LOG(INFO) << "[Reader] Initializing queue merger.";
|
||||
// Init reader merger by given comparator when it's first created.
|
||||
StreamingReaderMsgPtrComparator comparator;
|
||||
if (!reader_merger_) {
|
||||
reader_merger_.reset(
|
||||
new PriorityQueue<std::shared_ptr<DataBundle>, StreamingReaderMsgPtrComparator>(
|
||||
comparator));
|
||||
}
|
||||
|
||||
// An old item in the merger vector must be evicted before a new queue item is
|
||||
// pushed.
|
||||
if (!unready_queue_ids_.empty() && last_fetched_queue_item_) {
|
||||
STREAMING_LOG(INFO) << "pop old item from => " << last_fetched_queue_item_->from;
|
||||
RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_))
|
||||
last_fetched_queue_item_.reset();
|
||||
}
|
||||
// Create initial heap for priority queue.
|
||||
for (auto &input_queue : unready_queue_ids_) {
|
||||
std::shared_ptr<DataBundle> msg = std::make_shared<DataBundle>();
|
||||
RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info_map_[input_queue], msg))
|
||||
channel_info_map_[msg->from].current_seq_id = msg->seq_id;
|
||||
channel_info_map_[msg->from].current_message_id = msg->meta->GetLastMessageId();
|
||||
reader_merger_->push(msg);
|
||||
}
|
||||
STREAMING_LOG(INFO) << "[Reader] Initializing merger done.";
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataReader::GetMessageFromChannel(ConsumerChannelInfo &channel_info,
|
||||
std::shared_ptr<DataBundle> &message) {
|
||||
auto &qid = channel_info.channel_id;
|
||||
last_read_q_id_ = qid;
|
||||
STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id => " << qid;
|
||||
while (RuntimeStatus::Running == runtime_context_->GetRuntimeStatus() &&
|
||||
!message->data) {
|
||||
auto status = channel_map_[channel_info.channel_id]->ConsumeItemFromChannel(
|
||||
message->seq_id, message->data, message->data_size, kReadItemTimeout);
|
||||
channel_info.get_queue_item_times++;
|
||||
if (!message->data) {
|
||||
STREAMING_LOG(DEBUG) << "[Reader] Queue " << qid << " status " << status
|
||||
<< " get item timeout, resend notify "
|
||||
<< channel_info.current_seq_id;
|
||||
// TODO(lingxuan.zlx): notify consumed when it's timeout.
|
||||
}
|
||||
}
|
||||
if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) {
|
||||
return StreamingStatus::Interrupted;
|
||||
}
|
||||
STREAMING_LOG(DEBUG) << "[Reader] recevied queue seq id => " << message->seq_id
|
||||
<< ", queue id => " << qid;
|
||||
|
||||
message->from = qid;
|
||||
message->meta = StreamingMessageBundleMeta::FromBytes(message->data);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataReader::StashNextMessage(std::shared_ptr<DataBundle> &message) {
|
||||
// Push new message into priority queue and record the channel metrics in
|
||||
// channel info.
|
||||
std::shared_ptr<DataBundle> new_msg = std::make_shared<DataBundle>();
|
||||
auto &channel_info = channel_info_map_[message->from];
|
||||
reader_merger_->pop();
|
||||
int64_t cur_time = current_time_ms();
|
||||
RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg))
|
||||
reader_merger_->push(new_msg);
|
||||
channel_info.last_queue_item_delay =
|
||||
new_msg->meta->GetMessageBundleTs() - message->meta->GetMessageBundleTs();
|
||||
channel_info.last_queue_item_latency = current_time_ms() - cur_time;
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataReader::GetMergedMessageBundle(std::shared_ptr<DataBundle> &message,
|
||||
bool &is_valid_break) {
|
||||
int64_t cur_time = current_time_ms();
|
||||
if (last_fetched_queue_item_) {
|
||||
RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_))
|
||||
}
|
||||
message = reader_merger_->top();
|
||||
last_fetched_queue_item_ = message;
|
||||
auto &offset_info = channel_info_map_[message->from];
|
||||
|
||||
uint64_t cur_queue_previous_msg_id = offset_info.current_message_id;
|
||||
STREAMING_LOG(DEBUG) << "[Reader] [Bundle] from q_id =>" << message->from << "cur => "
|
||||
<< cur_queue_previous_msg_id << ", message list size"
|
||||
<< message->meta->GetMessageListSize() << ", last message id =>"
|
||||
<< message->meta->GetLastMessageId() << ", q seq id => "
|
||||
<< message->seq_id << ", last barrier id => " << message->data_size
|
||||
<< ", " << message->meta->GetMessageBundleTs();
|
||||
|
||||
if (message->meta->IsBundle()) {
|
||||
last_message_ts_ = cur_time;
|
||||
is_valid_break = true;
|
||||
} else if (timer_interval_ != -1 && cur_time - last_message_ts_ > timer_interval_) {
|
||||
// Throw empty message when reaching timer_interval.
|
||||
last_message_ts_ = cur_time;
|
||||
is_valid_break = true;
|
||||
}
|
||||
|
||||
offset_info.current_message_id = message->meta->GetLastMessageId();
|
||||
offset_info.current_seq_id = message->seq_id;
|
||||
last_bundle_ts_ = message->meta->GetMessageBundleTs();
|
||||
|
||||
STREAMING_LOG(DEBUG) << "[Reader] [Bundle] message type =>"
|
||||
<< static_cast<int>(message->meta->GetBundleType())
|
||||
<< " from id => " << message->from << ", queue seq id =>"
|
||||
<< message->seq_id << ", message id => "
|
||||
<< message->meta->GetLastMessageId();
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms,
|
||||
std::shared_ptr<DataBundle> &message) {
|
||||
// Notify consumed every item in this mode.
|
||||
if (last_fetched_queue_item_) {
|
||||
NotifyConsumedItem(channel_info_map_[last_fetched_queue_item_->from],
|
||||
last_fetched_queue_item_->seq_id);
|
||||
}
|
||||
|
||||
/// DataBundle will be returned to the upper layer in the following cases:
|
||||
/// a batch of data is returned when the real data is read, or an empty message
|
||||
/// is returned to the upper layer when the given timeout period is reached to
|
||||
/// avoid blocking for too long.
|
||||
auto start_time = current_time_ms();
|
||||
bool is_valid_break = false;
|
||||
uint32_t empty_bundle_cnt = 0;
|
||||
while (!is_valid_break) {
|
||||
if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) {
|
||||
return StreamingStatus::Interrupted;
|
||||
}
|
||||
auto cur_time = current_time_ms();
|
||||
auto dur = cur_time - start_time;
|
||||
if (dur > timeout_ms) {
|
||||
return StreamingStatus::GetBundleTimeOut;
|
||||
}
|
||||
if (!unready_queue_ids_.empty()) {
|
||||
StreamingStatus status = InitChannel();
|
||||
switch (status) {
|
||||
case StreamingStatus::InitQueueFailed:
|
||||
break;
|
||||
case StreamingStatus::WaitQueueTimeOut:
|
||||
STREAMING_LOG(ERROR)
|
||||
<< "Wait upstream queue timeout, maybe some actors in deadlock";
|
||||
break;
|
||||
default:
|
||||
STREAMING_LOG(INFO) << "Init reader queue in GetBundle";
|
||||
}
|
||||
if (StreamingStatus::OK != status) {
|
||||
return status;
|
||||
}
|
||||
RETURN_IF_NOT_OK(InitChannelMerger())
|
||||
unready_queue_ids_.clear();
|
||||
auto &merge_vec = reader_merger_->getRawVector();
|
||||
for (auto &bundle : merge_vec) {
|
||||
STREAMING_LOG(INFO) << "merger vector item => " << bundle->from;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break));
|
||||
if (!is_valid_break) {
|
||||
empty_bundle_cnt++;
|
||||
NotifyConsumedItem(channel_info_map_[message->from], message->seq_id);
|
||||
}
|
||||
}
|
||||
last_message_latency_ += current_time_ms() - start_time;
|
||||
if (message->meta->GetMessageListSize() > 0) {
|
||||
last_bundle_unit_ = message->data_size * 1.0 / message->meta->GetMessageListSize();
|
||||
}
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
void DataReader::GetOffsetInfo(
|
||||
std::unordered_map<ObjectID, ConsumerChannelInfo> *&offset_map) {
|
||||
offset_map = &channel_info_map_;
|
||||
for (auto &offset_info : channel_info_map_) {
|
||||
STREAMING_LOG(INFO) << "[Reader] [GetOffsetInfo], q id " << offset_info.first
|
||||
<< ", seq id => " << offset_info.second.current_seq_id
|
||||
<< ", message id => " << offset_info.second.current_message_id;
|
||||
}
|
||||
}
|
||||
|
||||
void DataReader::NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset) {
|
||||
channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset);
|
||||
if (offset == channel_info.queue_info.last_seq_id) {
|
||||
STREAMING_LOG(DEBUG) << "notify seq id equal to last seq id => " << offset;
|
||||
}
|
||||
}
|
||||
|
||||
DataReader::DataReader(std::shared_ptr<RuntimeContext> &runtime_context)
|
||||
: transfer_config_(new Config()), runtime_context_(runtime_context) {}
|
||||
|
||||
DataReader::~DataReader() { STREAMING_LOG(INFO) << "Streaming reader deconstruct."; }
|
||||
|
||||
void DataReader::Stop() {
|
||||
runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted);
|
||||
}
|
||||
|
||||
bool StreamingReaderMsgPtrComparator::operator()(const std::shared_ptr<DataBundle> &a,
|
||||
const std::shared_ptr<DataBundle> &b) {
|
||||
STREAMING_CHECK(a->meta);
|
||||
// We use hash value of id for stability of message in sorting.
|
||||
if (a->meta->GetMessageBundleTs() == b->meta->GetMessageBundleTs()) {
|
||||
return a->from.Hash() > b->from.Hash();
|
||||
}
|
||||
return a->meta->GetMessageBundleTs() > b->meta->GetMessageBundleTs();
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
|
||||
} // namespace ray
|
127
streaming/src/data_reader.h
Normal file
127
streaming/src/data_reader.h
Normal file
|
@ -0,0 +1,127 @@
|
|||
#ifndef RAY_DATA_READER_H
|
||||
#define RAY_DATA_READER_H
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "channel.h"
|
||||
#include "message/message_bundle.h"
|
||||
#include "message/priority_queue.h"
|
||||
#include "runtime_context.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// DataBundle is a super-bundle that contains channel information (upstream
|
||||
/// channel id & bundle meta data) and raw buffer pointer.
|
||||
struct DataBundle {
|
||||
uint8_t *data = nullptr;
|
||||
uint32_t data_size;
|
||||
ObjectID from;
|
||||
uint64_t seq_id;
|
||||
StreamingMessageBundleMetaPtr meta;
|
||||
};
|
||||
|
||||
/// This is the implementation of the merge policy in StreamingReaderMsgPtrComparator.
|
||||
struct StreamingReaderMsgPtrComparator {
|
||||
StreamingReaderMsgPtrComparator() = default;
|
||||
bool operator()(const std::shared_ptr<DataBundle> &a,
|
||||
const std::shared_ptr<DataBundle> &b);
|
||||
};
|
||||
|
||||
/// DataReader fetches data bundles from the channels of upstream workers once
|
||||
/// invoked by the user thread. It first puts them into a priority queue ordered by a
|
||||
/// comparator over the related bundle metadata, then pops the top bundle out to the user
|
||||
/// thread each time, so that message order can be guaranteed, which
|
||||
/// will also facilitate our future implementation of fault tolerance. Finally, the
|
||||
/// user thread can extract messages from the bundle and process them one by one.
|
||||
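///
/// Editor's illustrative sketch (not part of the original change): a reader is
/// typically driven by the user thread roughly as follows, assuming
/// `runtime_context`, `input_ids` and `actor_ids` are supplied by the caller:
///
///   DataReader reader(runtime_context);
///   reader.Init(input_ids, actor_ids, /*timer_interval=*/1000);
///   std::shared_ptr<DataBundle> bundle;
///   while (reader.GetBundle(/*timeout_ms=*/100, bundle) == StreamingStatus::OK) {
///     // Decode bundle->meta / bundle->data and process the contained messages.
///   }
///   reader.Stop();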
class DataReader {
|
||||
private:
|
||||
std::vector<ObjectID> input_queue_ids_;
|
||||
|
||||
std::vector<ObjectID> unready_queue_ids_;
|
||||
|
||||
std::unique_ptr<
|
||||
PriorityQueue<std::shared_ptr<DataBundle>, StreamingReaderMsgPtrComparator>>
|
||||
reader_merger_;
|
||||
|
||||
std::shared_ptr<DataBundle> last_fetched_queue_item_;
|
||||
|
||||
int64_t timer_interval_;
|
||||
int64_t last_bundle_ts_;
|
||||
int64_t last_message_ts_;
|
||||
int64_t last_message_latency_;
|
||||
int64_t last_bundle_unit_;
|
||||
|
||||
ObjectID last_read_q_id_;
|
||||
|
||||
static const uint32_t kReadItemTimeout;
|
||||
|
||||
protected:
|
||||
std::unordered_map<ObjectID, ConsumerChannelInfo> channel_info_map_;
|
||||
std::unordered_map<ObjectID, std::shared_ptr<ConsumerChannel>> channel_map_;
|
||||
std::shared_ptr<Config> transfer_config_;
|
||||
std::shared_ptr<RuntimeContext> runtime_context_;
|
||||
|
||||
public:
|
||||
explicit DataReader(std::shared_ptr<RuntimeContext> &runtime_context);
|
||||
virtual ~DataReader();
|
||||
|
||||
/// During initialization, only the channel parameters and necessary member properties
|
||||
/// are assigned. All channels will be connected in the first reading operation.
|
||||
/// \param input_ids
|
||||
/// \param actor_ids
|
||||
/// \param channel_seq_ids
|
||||
/// \param msg_ids
|
||||
/// \param timer_interval
|
||||
void Init(const std::vector<ObjectID> &input_ids, const std::vector<ActorID> &actor_ids,
|
||||
const std::vector<uint64_t> &channel_seq_ids,
|
||||
const std::vector<uint64_t> &msg_ids, int64_t timer_interval);
|
||||
|
||||
void Init(const std::vector<ObjectID> &input_ids, const std::vector<ActorID> &actor_ids,
|
||||
int64_t timer_interval);
|
||||
|
||||
/// Get latest message from input queues.
|
||||
/// \param timeout_ms
|
||||
/// \param message, return the latest message
|
||||
StreamingStatus GetBundle(uint32_t timeout_ms, std::shared_ptr<DataBundle> &message);
|
||||
|
||||
/// Get offset information about channels for checkpoint.
|
||||
/// \param offset_map (return value)
|
||||
void GetOffsetInfo(std::unordered_map<ObjectID, ConsumerChannelInfo> *&offset_map);
|
||||
|
||||
void Stop();
|
||||
|
||||
/// Notify input queues to clear data whose seq id is equal or less than offset.
|
||||
/// It's used when checkpoint is done.
|
||||
/// \param channel_info
|
||||
/// \param offset
|
||||
///
|
||||
void NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset);
|
||||
|
||||
private:
|
||||
/// Create channels and connect to all upstream.
|
||||
StreamingStatus InitChannel();
|
||||
|
||||
/// One item from every channel will be popped out, then collected
|
||||
/// into a merged queue. High priority items will be fetched one by one.
|
||||
/// When an item is popped from one channel, a new item from that channel must be
|
||||
/// produced as a placeholder in the merged queue.
|
||||
StreamingStatus InitChannelMerger();
|
||||
|
||||
StreamingStatus StashNextMessage(std::shared_ptr<DataBundle> &message);
|
||||
|
||||
StreamingStatus GetMessageFromChannel(ConsumerChannelInfo &channel_info,
|
||||
std::shared_ptr<DataBundle> &message);
|
||||
|
||||
/// Get top item from priority queue.
|
||||
StreamingStatus GetMergedMessageBundle(std::shared_ptr<DataBundle> &message,
|
||||
bool &is_valid_break);
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif // RAY_DATA_READER_H
|
310
streaming/src/data_writer.cc
Normal file
310
streaming/src/data_writer.cc
Normal file
|
@ -0,0 +1,310 @@
|
|||
#include <memory>
|
||||
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
|
||||
#include "data_writer.h"
|
||||
#include "util/streaming_util.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
void DataWriter::WriterLoopForward() {
|
||||
STREAMING_CHECK(RuntimeStatus::Running == runtime_context_->GetRuntimeStatus());
|
||||
while (true) {
|
||||
int64_t min_passby_message_ts = std::numeric_limits<int64_t>::max();
|
||||
uint32_t empty_message_send_count = 0;
|
||||
|
||||
for (auto &output_queue : output_queue_ids_) {
|
||||
if (RuntimeStatus::Running != runtime_context_->GetRuntimeStatus()) {
|
||||
return;
|
||||
}
|
||||
ProducerChannelInfo &channel_info = channel_info_map_[output_queue];
|
||||
bool is_push_empty_message = false;
|
||||
StreamingStatus write_status =
|
||||
WriteChannelProcess(channel_info, &is_push_empty_message);
|
||||
int64_t current_ts = current_time_ms();
|
||||
if (StreamingStatus::OK == write_status) {
|
||||
channel_info.message_pass_by_ts = current_ts;
|
||||
if (is_push_empty_message) {
|
||||
min_passby_message_ts =
|
||||
std::min(channel_info.message_pass_by_ts, min_passby_message_ts);
|
||||
empty_message_send_count++;
|
||||
}
|
||||
} else if (StreamingStatus::FullChannel == write_status) {
|
||||
} else {
|
||||
if (StreamingStatus::EmptyRingBuffer != write_status) {
|
||||
STREAMING_LOG(DEBUG) << "write buffer status => "
|
||||
<< static_cast<uint32_t>(write_status)
|
||||
<< ", is push empty message => " << is_push_empty_message;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (empty_message_send_count == output_queue_ids_.size()) {
|
||||
// Sleep if an empty message was sent on every channel.
|
||||
uint64_t sleep_time_ = current_time_ms() - min_passby_message_ts;
|
||||
// Sleep_time can be bigger than time interval because of network jitter.
|
||||
if (sleep_time_ <= runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(
|
||||
runtime_context_->GetConfig().GetEmptyMessageTimeInterval() - sleep_time_));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamingStatus DataWriter::WriteChannelProcess(ProducerChannelInfo &channel_info,
|
||||
bool *is_empty_message) {
|
||||
// No message in buffer, empty message will be sent to downstream queue.
|
||||
uint64_t buffer_remain = 0;
|
||||
StreamingStatus write_queue_flag = WriteBufferToChannel(channel_info, buffer_remain);
|
||||
int64_t current_ts = current_time_ms();
|
||||
if (write_queue_flag == StreamingStatus::EmptyRingBuffer &&
|
||||
current_ts - channel_info.message_pass_by_ts >=
|
||||
runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) {
|
||||
write_queue_flag = WriteEmptyMessage(channel_info);
|
||||
*is_empty_message = true;
|
||||
STREAMING_LOG(DEBUG) << "send empty message bundle in q_id =>"
|
||||
<< channel_info.channel_id;
|
||||
}
|
||||
return write_queue_flag;
|
||||
}
|
||||
|
||||
StreamingStatus DataWriter::WriteBufferToChannel(ProducerChannelInfo &channel_info,
|
||||
uint64_t &buffer_remain) {
|
||||
StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer;
|
||||
if (!IsMessageAvailableInBuffer(channel_info)) {
|
||||
return StreamingStatus::EmptyRingBuffer;
|
||||
}
|
||||
|
||||
// Flush transient buffer to queue first.
|
||||
if (buffer_ptr->IsTransientAvaliable()) {
|
||||
return WriteTransientBufferToChannel(channel_info);
|
||||
}
|
||||
|
||||
STREAMING_CHECK(CollectFromRingBuffer(channel_info, buffer_remain))
|
||||
<< "empty data in ringbuffer, q id => " << channel_info.channel_id;
|
||||
|
||||
return WriteTransientBufferToChannel(channel_info);
|
||||
}
|
||||
|
||||
void DataWriter::Run() {
|
||||
STREAMING_LOG(INFO) << "WriterLoopForward start";
|
||||
loop_thread_ = std::make_shared<std::thread>(&DataWriter::WriterLoopForward, this);
|
||||
}
|
||||
|
||||
/// Since every memory ring buffer's size is limited, when the writing buffer is
|
||||
/// full, the user thread will be blocked, which will cause backpressure
|
||||
/// naturally.
|
||||
uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *data,
|
||||
uint32_t data_size,
|
||||
StreamingMessageType message_type) {
|
||||
STREAMING_LOG(DEBUG) << "WriteMessageToBufferRing q_id: " << q_id
|
||||
<< " data_size: " << data_size;
|
||||
// TODO(lingxuan.zlx): currently unsafe when called from multiple threads.
|
||||
ProducerChannelInfo &channel_info = channel_info_map_[q_id];
|
||||
// Write message id stands for the current latest message id and differs from
|
||||
// channel.current_message_id if it's a barrier message.
|
||||
uint64_t &write_message_id = channel_info.current_message_id;
|
||||
write_message_id++;
|
||||
auto &ring_buffer_ptr = channel_info.writer_ring_buffer;
|
||||
while (ring_buffer_ptr->IsFull() &&
|
||||
runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running) {
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT));
|
||||
}
|
||||
if (runtime_context_->GetRuntimeStatus() != RuntimeStatus::Running) {
|
||||
STREAMING_LOG(WARNING) << "stop in write message to ringbuffer";
|
||||
return 0;
|
||||
}
|
||||
ring_buffer_ptr->Push(std::make_shared<StreamingMessage>(
|
||||
data, data_size, write_message_id, message_type));
|
||||
|
||||
return write_message_id;
|
||||
}
|
||||
|
||||
StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &actor_id,
|
||||
uint64_t channel_message_id,
|
||||
uint64_t queue_size) {
|
||||
ProducerChannelInfo &channel_info = channel_info_map_[q_id];
|
||||
channel_info.current_message_id = channel_message_id;
|
||||
channel_info.channel_id = q_id;
|
||||
channel_info.actor_id = actor_id;
|
||||
channel_info.queue_size = queue_size;
|
||||
STREAMING_LOG(WARNING) << " Init queue [" << q_id << "]";
|
||||
channel_info.writer_ring_buffer = std::make_shared<StreamingRingBuffer>(
|
||||
runtime_context_->GetConfig().GetRingBufferCapacity(),
|
||||
StreamingRingBufferType::SPSC);
|
||||
channel_info.message_pass_by_ts = current_time_ms();
|
||||
std::shared_ptr<ProducerChannel> channel;
|
||||
|
||||
if (runtime_context_->IsMockTest()) {
|
||||
channel = std::make_shared<MockProducer>(transfer_config_, channel_info);
|
||||
} else {
|
||||
channel = std::make_shared<StreamingQueueProducer>(transfer_config_, channel_info);
|
||||
}
|
||||
|
||||
channel_map_.emplace(q_id, channel);
|
||||
RETURN_IF_NOT_OK(channel->CreateTransferChannel())
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataWriter::Init(const std::vector<ObjectID> &queue_id_vec,
|
||||
const std::vector<ActorID> &actor_ids,
|
||||
const std::vector<uint64_t> &channel_message_id_vec,
|
||||
const std::vector<uint64_t> &queue_size_vec) {
|
||||
STREAMING_CHECK(!queue_id_vec.empty() && !channel_message_id_vec.empty());
|
||||
|
||||
ray::JobID job_id =
|
||||
JobID::FromBinary(Util::Hexqid2str(runtime_context_->GetConfig().GetTaskJobId()));
|
||||
|
||||
STREAMING_LOG(INFO) << "Job name => " << runtime_context_->GetConfig().GetJobName()
|
||||
<< ", job id => " << job_id;
|
||||
|
||||
output_queue_ids_ = queue_id_vec;
|
||||
transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, queue_id_vec);
|
||||
|
||||
for (size_t i = 0; i < queue_id_vec.size(); ++i) {
|
||||
StreamingStatus status = InitChannel(queue_id_vec[i], actor_ids[i],
|
||||
channel_message_id_vec[i], queue_size_vec[i]);
|
||||
if (status != StreamingStatus::OK) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
runtime_context_->SetRuntimeStatus(RuntimeStatus::Running);
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
DataWriter::DataWriter(std::shared_ptr<RuntimeContext> &runtime_context)
|
||||
: transfer_config_(new Config()), runtime_context_(runtime_context) {}
|
||||
|
||||
DataWriter::~DataWriter() {
|
||||
// Return early if the streaming writer failed to initialize.
|
||||
if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Init) {
|
||||
return;
|
||||
}
|
||||
runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted);
|
||||
if (loop_thread_->joinable()) {
|
||||
STREAMING_LOG(INFO) << "Writer loop thread waiting for join";
|
||||
loop_thread_->join();
|
||||
}
|
||||
STREAMING_LOG(INFO) << "Writer client queue disconnect.";
|
||||
}
|
||||
|
||||
bool DataWriter::IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info) {
|
||||
return channel_info.writer_ring_buffer->IsTransientAvaliable() ||
|
||||
!channel_info.writer_ring_buffer->IsEmpty();
|
||||
}
|
||||
|
||||
StreamingStatus DataWriter::WriteEmptyMessage(ProducerChannelInfo &channel_info) {
|
||||
auto &q_id = channel_info.channel_id;
|
||||
if (channel_info.message_last_commit_id < channel_info.current_message_id) {
|
||||
// Abort sending an empty message if the ring buffer is not empty now.
|
||||
STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " abort to send empty, last commit id =>"
|
||||
<< channel_info.message_last_commit_id << ", channel max id => "
|
||||
<< channel_info.current_message_id;
|
||||
return StreamingStatus::SkipSendEmptyMessage;
|
||||
}
|
||||
|
||||
// Make an empty bundle, use old ts from reloaded meta if it's not nullptr.
|
||||
StreamingMessageBundlePtr bundle_ptr = std::make_shared<StreamingMessageBundle>(
|
||||
channel_info.current_message_id, current_time_ms());
|
||||
auto &q_ringbuffer = channel_info.writer_ring_buffer;
|
||||
q_ringbuffer->ReallocTransientBuffer(bundle_ptr->ClassBytesSize());
|
||||
bundle_ptr->ToBytes(q_ringbuffer->GetTransientBufferMutable());
|
||||
|
||||
StreamingStatus status = channel_map_[q_id]->ProduceItemToChannel(
|
||||
const_cast<uint8_t *>(q_ringbuffer->GetTransientBuffer()),
|
||||
q_ringbuffer->GetTransientBufferSize());
|
||||
STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " send empty message, meta info =>"
|
||||
<< bundle_ptr->ToString();
|
||||
|
||||
q_ringbuffer->FreeTransientBuffer();
|
||||
RETURN_IF_NOT_OK(status)
|
||||
channel_info.current_seq_id++;
|
||||
channel_info.message_pass_by_ts = current_time_ms();
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
StreamingStatus DataWriter::WriteTransientBufferToChannel(
|
||||
ProducerChannelInfo &channel_info) {
|
||||
StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer;
|
||||
StreamingStatus status = channel_map_[channel_info.channel_id]->ProduceItemToChannel(
|
||||
buffer_ptr->GetTransientBufferMutable(), buffer_ptr->GetTransientBufferSize());
|
||||
RETURN_IF_NOT_OK(status)
|
||||
channel_info.current_seq_id++;
|
||||
auto transient_bundle_meta =
|
||||
StreamingMessageBundleMeta::FromBytes(buffer_ptr->GetTransientBuffer());
|
||||
bool is_barrier_bundle = transient_bundle_meta->IsBarrier();
|
||||
// Force delete so that the super block's memory is not held for too long
|
||||
// if it's a barrier bundle.
|
||||
buffer_ptr->FreeTransientBuffer(is_barrier_bundle);
|
||||
channel_info.message_last_commit_id = transient_bundle_meta->GetLastMessageId();
|
||||
return StreamingStatus::OK;
|
||||
}
|
||||
|
||||
bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info,
|
||||
uint64_t &buffer_remain) {
|
||||
StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer;
|
||||
auto &q_id = channel_info.channel_id;
|
||||
|
||||
std::list<StreamingMessagePtr> message_list;
|
||||
uint64_t bundle_buffer_size = 0;
|
||||
const uint32_t max_queue_item_size = channel_info.queue_size;
|
||||
while (message_list.size() < runtime_context_->GetConfig().GetRingBufferCapacity() &&
|
||||
!buffer_ptr->IsEmpty()) {
|
||||
StreamingMessagePtr &message_ptr = buffer_ptr->Front();
|
||||
uint32_t message_total_size = message_ptr->ClassBytesSize();
|
||||
if (!message_list.empty() &&
|
||||
bundle_buffer_size + message_total_size >= max_queue_item_size) {
|
||||
STREAMING_LOG(DEBUG) << "message total size " << message_total_size
|
||||
<< " max queue item size => " << max_queue_item_size;
|
||||
break;
|
||||
}
|
||||
if (!message_list.empty() &&
|
||||
message_list.back()->GetMessageType() != message_ptr->GetMessageType()) {
|
||||
break;
|
||||
}
|
||||
// ClassBytesSize = DataSize + MetaDataSize
|
||||
// bundle_buffer_size += message_ptr->GetDataSize();
|
||||
bundle_buffer_size += message_total_size;
|
||||
message_list.push_back(message_ptr);
|
||||
buffer_ptr->Pop();
|
||||
buffer_remain = buffer_ptr->Size();
|
||||
}
|
||||
|
||||
if (bundle_buffer_size >= channel_info.queue_size) {
|
||||
STREAMING_LOG(ERROR) << "bundle buffer is too large to store q id => " << q_id
|
||||
<< ", bundle size => " << bundle_buffer_size
|
||||
<< ", queue size => " << channel_info.queue_size;
|
||||
}
|
||||
|
||||
StreamingMessageBundlePtr bundle_ptr;
|
||||
bundle_ptr = std::make_shared<StreamingMessageBundle>(
|
||||
std::move(message_list), current_time_ms(), message_list.back()->GetMessageSeqId(),
|
||||
StreamingMessageBundleType::Bundle, bundle_buffer_size);
|
||||
buffer_ptr->ReallocTransientBuffer(bundle_ptr->ClassBytesSize());
|
||||
bundle_ptr->ToBytes(buffer_ptr->GetTransientBufferMutable());
|
||||
|
||||
STREAMING_CHECK(bundle_ptr->ClassBytesSize() == buffer_ptr->GetTransientBufferSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
void DataWriter::Stop() {
|
||||
for (auto &output_queue : output_queue_ids_) {
|
||||
ProducerChannelInfo &channel_info = channel_info_map_[output_queue];
|
||||
while (!channel_info.writer_ring_buffer->IsEmpty()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted);
|
||||
}
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
115
streaming/src/data_writer.h
Normal file
115
streaming/src/data_writer.h
Normal file
|
@ -0,0 +1,115 @@
|
|||
#ifndef RAY_DATA_WRITER_H
|
||||
#define RAY_DATA_WRITER_H
|
||||
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "channel.h"
|
||||
#include "config/streaming_config.h"
|
||||
#include "message/message_bundle.h"
|
||||
#include "runtime_context.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// DataWriter is designed for transporting data between upstream and downstream.
|
||||
/// After the user sends the data, it does not immediately send the data to
|
||||
/// downstream, but caches it in the corresponding memory ring buffer. There is
|
||||
/// a separate transfer thread (set up in the WriterLoopForward function) to collect
|
||||
/// the messages from all the ring buffers and write them to the corresponding
|
||||
/// transmission channels, which are backed by StreamingQueue. The
|
||||
/// advantage is that the user thread will not be affected by the transmission
|
||||
/// speed during the data transfer. The transfer thread can also automatically
|
||||
/// batch the cached data from the memory buffer into a data bundle to reduce
|
||||
/// transmission overhead. In addition, when there is no data in the ring buffer,
|
||||
/// it will also send an empty bundle, so downstream can know that and process
|
||||
/// accordingly. It will sleep for a short interval to save cpu if all ring
|
||||
/// buffers have no data in that moment.
|
||||
class DataWriter {
|
||||
private:
|
||||
std::shared_ptr<std::thread> loop_thread_;
|
||||
// One channel have unique identity.
|
||||
std::vector<ObjectID> output_queue_ids_;
|
||||
|
||||
protected:
|
||||
// ProducerTransfer is middle broker for data transporting.
|
||||
std::unordered_map<ObjectID, ProducerChannelInfo> channel_info_map_;
|
||||
std::unordered_map<ObjectID, std::shared_ptr<ProducerChannel>> channel_map_;
|
||||
std::shared_ptr<Config> transfer_config_;
|
||||
std::shared_ptr<RuntimeContext> runtime_context_;
|
||||
|
||||
private:
|
||||
bool IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info);
|
||||
|
||||
/// This function handles two scenarios. When there is data in the transient
|
||||
/// buffer, the existing data is written into the channel first, otherwise a
|
||||
/// certain amount of message is first collected from the buffer and serialized
|
||||
/// into the transient buffer, and finally written to the channel.
|
||||
/// \param channel_info
/// \param buffer_remain
|
||||
StreamingStatus WriteBufferToChannel(ProducerChannelInfo &channel_info,
|
||||
uint64_t &buffer_remain);
|
||||
|
||||
/// Start the loop forward thread for collecting messages from all channels.
|
||||
/// Invoking stack:
|
||||
/// WriterLoopForward
|
||||
/// -- WriteChannelProcess
|
||||
/// -- WriteBufferToChannel
|
||||
/// -- CollectFromRingBuffer
|
||||
/// -- WriteTransientBufferToChannel
|
||||
/// -- WriteEmptyMessage(if WriteChannelProcess return empty state)
|
||||
void WriterLoopForward();
|
||||
|
||||
/// Push empty message when no valid message or bundle was produced each time
|
||||
/// interval.
|
||||
/// \param channel_info
|
||||
StreamingStatus WriteEmptyMessage(ProducerChannelInfo &channel_info);
|
||||
|
||||
/// Flush all data from transient buffer to channel for transporting.
|
||||
/// \param channel_info
|
||||
StreamingStatus WriteTransientBufferToChannel(ProducerChannelInfo &channel_info);
|
||||
|
||||
bool CollectFromRingBuffer(ProducerChannelInfo &channel_info, uint64_t &buffer_remain);
|
||||
|
||||
StreamingStatus WriteChannelProcess(ProducerChannelInfo &channel_info,
|
||||
bool *is_empty_message);
|
||||
|
||||
StreamingStatus InitChannel(const ObjectID &q_id, const ActorID &actor_id,
|
||||
uint64_t channel_message_id, uint64_t queue_size);
|
||||
|
||||
public:
|
||||
explicit DataWriter(std::shared_ptr<RuntimeContext> &runtime_context);
|
||||
virtual ~DataWriter();
|
||||
|
||||
/// Streaming writer client initialization.
/// \param channel_ids queue id vector
/// \param actor_ids peer actor ids, one per channel
/// \param channel_message_id_vec channel message ids, related to message checkpoints
/// \param queue_size_vec queue size of each channel (memory size, not length)
|
||||
StreamingStatus Init(const std::vector<ObjectID> &channel_ids,
|
||||
const std::vector<ActorID> &actor_ids,
|
||||
const std::vector<uint64_t> &channel_message_id_vec,
|
||||
const std::vector<uint64_t> &queue_size_vec);
|
||||
|
||||
/// To increase throughput, we employ an output buffer for message transformation,
/// which means we merge many messages into a message bundle, and no message is
/// pushed into the queue directly until the daemon thread does so.
/// Additionally, writing intentionally blocks when the ring buffer is full.
/// \param q_id
/// \param data
/// \param data_size
/// \param message_type
/// \return message seq id
|
||||
uint64_t WriteMessageToBufferRing(
|
||||
const ObjectID &q_id, uint8_t *data, uint32_t data_size,
|
||||
StreamingMessageType message_type = StreamingMessageType::Message);
|
||||
|
||||
void Run();
|
||||
|
||||
void Stop();
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif // RAY_DATA_WRITER_H
|
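A minimal usage sketch for the DataWriter API above (not part of this diff). It assumes a RuntimeContext can be constructed for the job; the queue ids, actor ids, and queue sizes below are placeholders.

// Sketch only: drive a DataWriter for a single output channel.
std::shared_ptr<RuntimeContext> runtime_context = std::make_shared<RuntimeContext>();
DataWriter writer(runtime_context);

std::vector<ObjectID> queue_ids = {ObjectID::FromRandom()};   // placeholder channel id
std::vector<ActorID> actor_ids = {ActorID::Nil()};            // placeholder peer actor
std::vector<uint64_t> channel_message_ids = {0};
std::vector<uint64_t> queue_sizes = {10 * 1024 * 1024};       // 10 MB per channel

writer.Init(queue_ids, actor_ids, channel_message_ids, queue_sizes);
writer.Run();  // starts the WriterLoopForward transfer thread

uint8_t payload[] = {1, 2, 3};
// Returns the message seq id; blocks if the ring buffer is full.
uint64_t seq_id = writer.WriteMessageToBufferRing(queue_ids[0], payload, sizeof(payload));
(void)seq_id;

writer.Stop();  // waits for the ring buffers to drain, then interrupts the loop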
90
streaming/src/message/message.cc
Normal file
@ -0,0 +1,90 @@
#include <utility>
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
#include "message.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
StreamingMessage::StreamingMessage(std::shared_ptr<uint8_t> &data, uint32_t data_size,
|
||||
uint64_t seq_id, StreamingMessageType message_type)
|
||||
: message_data_(data),
|
||||
data_size_(data_size),
|
||||
message_type_(message_type),
|
||||
message_id_(seq_id) {}
|
||||
|
||||
StreamingMessage::StreamingMessage(std::shared_ptr<uint8_t> &&data, uint32_t data_size,
|
||||
uint64_t seq_id, StreamingMessageType message_type)
|
||||
: message_data_(data),
|
||||
data_size_(data_size),
|
||||
message_type_(message_type),
|
||||
message_id_(seq_id) {}
|
||||
|
||||
StreamingMessage::StreamingMessage(const uint8_t *data, uint32_t data_size,
|
||||
uint64_t seq_id, StreamingMessageType message_type)
|
||||
: data_size_(data_size), message_type_(message_type), message_id_(seq_id) {
|
||||
message_data_.reset(new uint8_t[data_size], std::default_delete<uint8_t[]>());
|
||||
std::memcpy(message_data_.get(), data, data_size_);
|
||||
}
|
||||
|
||||
StreamingMessage::StreamingMessage(const StreamingMessage &msg) {
|
||||
data_size_ = msg.data_size_;
|
||||
message_data_ = msg.message_data_;
|
||||
message_id_ = msg.message_id_;
|
||||
message_type_ = msg.message_type_;
|
||||
}
|
||||
|
||||
StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes,
|
||||
bool verifer_check) {
|
||||
uint32_t byte_offset = 0;
|
||||
uint32_t data_size = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(data_size);
|
||||
|
||||
uint64_t seq_id = *reinterpret_cast<const uint64_t *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(seq_id);
|
||||
|
||||
StreamingMessageType msg_type =
|
||||
*reinterpret_cast<const StreamingMessageType *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(msg_type);
|
||||
|
||||
auto buf = new uint8_t[data_size];
|
||||
std::memcpy(buf, bytes + byte_offset, data_size);
|
||||
auto data_ptr = std::shared_ptr<uint8_t>(buf, std::default_delete<uint8_t[]>());
|
||||
return std::make_shared<StreamingMessage>(data_ptr, data_size, seq_id, msg_type);
|
||||
}
|
||||
|
||||
void StreamingMessage::ToBytes(uint8_t *serlizable_data) {
|
||||
uint32_t byte_offset = 0;
|
||||
std::memcpy(serlizable_data + byte_offset, reinterpret_cast<char *>(&data_size_),
|
||||
sizeof(data_size_));
|
||||
byte_offset += sizeof(data_size_);
|
||||
|
||||
std::memcpy(serlizable_data + byte_offset, reinterpret_cast<char *>(&message_id_),
|
||||
sizeof(message_id_));
|
||||
byte_offset += sizeof(message_id_);
|
||||
|
||||
std::memcpy(serlizable_data + byte_offset, reinterpret_cast<char *>(&message_type_),
|
||||
sizeof(message_type_));
|
||||
byte_offset += sizeof(message_type_);
|
||||
|
||||
std::memcpy(serlizable_data + byte_offset,
|
||||
reinterpret_cast<char *>(message_data_.get()), data_size_);
|
||||
|
||||
byte_offset += data_size_;
|
||||
|
||||
STREAMING_CHECK(byte_offset == this->ClassBytesSize());
|
||||
}
|
||||
|
||||
bool StreamingMessage::operator==(const StreamingMessage &message) const {
|
||||
return GetDataSize() == message.GetDataSize() &&
|
||||
GetMessageSeqId() == message.GetMessageSeqId() &&
|
||||
GetMessageType() == message.GetMessageType() &&
|
||||
!std::memcmp(RawData(), message.RawData(), data_size_);
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
93
streaming/src/message/message.h
Normal file
@ -0,0 +1,93 @@
#ifndef RAY_MESSAGE_H
|
||||
#define RAY_MESSAGE_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
class StreamingMessage;
|
||||
|
||||
typedef std::shared_ptr<StreamingMessage> StreamingMessagePtr;
|
||||
|
||||
enum class StreamingMessageType : uint32_t {
|
||||
Barrier = 1,
|
||||
Message = 2,
|
||||
MIN = Barrier,
|
||||
MAX = Message
|
||||
};
|
||||
|
||||
constexpr uint32_t kMessageHeaderSize =
|
||||
sizeof(uint32_t) + sizeof(uint64_t) + sizeof(StreamingMessageType);
|
||||
|
||||
/// All messages should be wrapped by this protocol.
|
||||
/// DataSize means the length of the raw data; the message id increases within [1, +INF).
/// MessageType will be used for barrier transporting and checkpointing.
|
||||
/// +----------------+
|
||||
/// | DataSize=U32 |
|
||||
/// +----------------+
|
||||
/// | MessageId=U64 |
|
||||
/// +----------------+
|
||||
/// | MessageType=U32|
|
||||
/// +----------------+
|
||||
/// | Data=var |
|
||||
/// +----------------+
|
||||
|
||||
class StreamingMessage {
|
||||
private:
|
||||
std::shared_ptr<uint8_t> message_data_;
|
||||
uint32_t data_size_;
|
||||
StreamingMessageType message_type_;
|
||||
uint64_t message_id_;
|
||||
|
||||
public:
|
||||
/// Copy raw data from outside shared buffer.
|
||||
/// \param data raw data from user buffer
|
||||
/// \param data_size raw data size
|
||||
/// \param seq_id message id
|
||||
/// \param message_type
|
||||
StreamingMessage(std::shared_ptr<uint8_t> &data, uint32_t data_size, uint64_t seq_id,
|
||||
StreamingMessageType message_type);
|
||||
|
||||
/// Move outside raw data into message data.
|
||||
/// \param data raw data from user buffer
|
||||
/// \param data_size raw data size
|
||||
/// \param seq_id message id
|
||||
/// \param message_type
|
||||
StreamingMessage(std::shared_ptr<uint8_t> &&data, uint32_t data_size, uint64_t seq_id,
|
||||
StreamingMessageType message_type);
|
||||
|
||||
/// Copy raw data from outside buffer.
|
||||
/// \param data raw data from user buffer
|
||||
/// \param data_size raw data size
|
||||
/// \param seq_id message id
|
||||
/// \param message_type
|
||||
StreamingMessage(const uint8_t *data, uint32_t data_size, uint64_t seq_id,
|
||||
StreamingMessageType message_type);
|
||||
|
||||
StreamingMessage(const StreamingMessage &);
|
||||
|
||||
StreamingMessage operator=(const StreamingMessage &) = delete;
|
||||
|
||||
virtual ~StreamingMessage() = default;
|
||||
|
||||
inline uint8_t *RawData() const { return message_data_.get(); }
|
||||
|
||||
inline uint32_t GetDataSize() const { return data_size_; }
|
||||
inline StreamingMessageType GetMessageType() const { return message_type_; }
|
||||
inline uint64_t GetMessageSeqId() const { return message_id_; }
|
||||
inline bool IsMessage() { return StreamingMessageType::Message == message_type_; }
|
||||
inline bool IsBarrier() { return StreamingMessageType::Barrier == message_type_; }
|
||||
|
||||
bool operator==(const StreamingMessage &) const;
|
||||
|
||||
virtual void ToBytes(uint8_t *data);
|
||||
static StreamingMessagePtr FromBytes(const uint8_t *data, bool verifer_check = true);
|
||||
|
||||
inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + data_size_; }
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_MESSAGE_H
|
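A round-trip sketch for the message layout documented above (illustration only, not part of this diff):

// Sketch only: serialize a StreamingMessage and parse it back.
const uint8_t payload[] = {0xAB, 0xCD};
StreamingMessage msg(payload, sizeof(payload), /*seq_id=*/1,
                     StreamingMessageType::Message);

std::vector<uint8_t> wire(msg.ClassBytesSize());  // 16-byte header + 2 bytes of data
msg.ToBytes(wire.data());

StreamingMessagePtr decoded = StreamingMessage::FromBytes(wire.data());
STREAMING_CHECK(*decoded == msg);  // same size, id, type, and bytes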
236
streaming/src/message/message_bundle.cc
Normal file
@ -0,0 +1,236 @@
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
|
||||
#include "config/streaming_config.h"
|
||||
#include "message_bundle.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
StreamingMessageBundle::StreamingMessageBundle(uint64_t last_offset_seq_id,
|
||||
uint64_t message_bundle_ts)
|
||||
: StreamingMessageBundleMeta(message_bundle_ts, last_offset_seq_id, 0,
|
||||
StreamingMessageBundleType::Empty) {
|
||||
this->raw_bundle_size_ = 0;
|
||||
}
|
||||
|
||||
StreamingMessageBundleMeta::StreamingMessageBundleMeta(
|
||||
uint64_t message_bundle_ts, uint64_t last_offset_seq_id, uint32_t message_list_size,
|
||||
StreamingMessageBundleType bundle_type)
|
||||
: message_bundle_ts_(message_bundle_ts),
|
||||
last_message_id_(last_offset_seq_id),
|
||||
message_list_size_(message_list_size),
|
||||
bundle_type_(bundle_type) {
|
||||
STREAMING_CHECK(message_list_size <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE);
|
||||
}
|
||||
|
||||
void StreamingMessageBundleMeta::ToBytes(uint8_t *bytes) {
|
||||
uint32_t byte_offset = 0;
|
||||
|
||||
uint32_t magicNum = StreamingMessageBundleMeta::StreamingMessageBundleMagicNum;
|
||||
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&magicNum),
|
||||
sizeof(uint32_t));
|
||||
byte_offset += sizeof(uint32_t);
|
||||
|
||||
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&message_bundle_ts_),
|
||||
sizeof(uint64_t));
|
||||
byte_offset += sizeof(uint64_t);
|
||||
|
||||
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&last_message_id_),
|
||||
sizeof(uint64_t));
|
||||
byte_offset += sizeof(uint64_t);
|
||||
|
||||
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&message_list_size_),
|
||||
sizeof(uint32_t));
|
||||
byte_offset += sizeof(uint32_t);
|
||||
|
||||
std::memcpy(bytes + byte_offset, reinterpret_cast<const uint8_t *>(&bundle_type_),
|
||||
sizeof(StreamingMessageBundleType));
|
||||
byte_offset += sizeof(StreamingMessageBundleType);
|
||||
}
|
||||
|
||||
StreamingMessageBundleMetaPtr StreamingMessageBundleMeta::FromBytes(const uint8_t *bytes,
|
||||
bool check) {
|
||||
STREAMING_CHECK(bytes);
|
||||
|
||||
uint32_t byte_offset = 0;
|
||||
const uint32_t magic_num = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
|
||||
|
||||
if (magic_num != StreamingMessageBundleMagicNum) {
|
||||
STREAMING_LOG(INFO) << "Magic Number => " << magic_num;
|
||||
}
|
||||
|
||||
STREAMING_CHECK(magic_num == StreamingMessageBundleMagicNum);
|
||||
byte_offset += sizeof(uint32_t);
|
||||
|
||||
uint64_t message_bundle_ts = *reinterpret_cast<const uint64_t *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(uint64_t);
|
||||
|
||||
uint64_t last_message_id = *reinterpret_cast<const uint64_t *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(uint64_t);
|
||||
|
||||
uint32_t messageListSize = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(uint32_t);
|
||||
STREAMING_LOG(DEBUG) << "ts => " << message_bundle_ts << " last message id => "
|
||||
<< last_message_id << " message size => " << messageListSize;
|
||||
|
||||
STREAMING_CHECK(messageListSize <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE);
|
||||
|
||||
StreamingMessageBundleType messageBundleType =
|
||||
*reinterpret_cast<const StreamingMessageBundleType *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(StreamingMessageBundleType);
|
||||
|
||||
auto result = std::make_shared<StreamingMessageBundleMeta>(
|
||||
message_bundle_ts, last_message_id, messageListSize, messageBundleType);
|
||||
STREAMING_CHECK(byte_offset == result->ClassBytesSize());
|
||||
return result;
|
||||
}
|
||||
|
||||
bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta &meta) const {
|
||||
return this->message_list_size_ == meta.GetMessageListSize() &&
|
||||
this->message_bundle_ts_ == meta.GetMessageBundleTs() &&
|
||||
this->bundle_type_ == meta.GetBundleType() &&
|
||||
this->last_message_id_ == meta.GetLastMessageId();
|
||||
}
|
||||
|
||||
bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta *meta) const {
|
||||
return operator==(*meta);
|
||||
}
|
||||
|
||||
StreamingMessageBundleMeta::StreamingMessageBundleMeta(
|
||||
StreamingMessageBundleMeta *meta_ptr) {
|
||||
bundle_type_ = meta_ptr->bundle_type_;
|
||||
last_message_id_ = meta_ptr->last_message_id_;
|
||||
message_bundle_ts_ = meta_ptr->message_bundle_ts_;
|
||||
message_list_size_ = meta_ptr->message_list_size_;
|
||||
}
|
||||
|
||||
StreamingMessageBundleMeta::StreamingMessageBundleMeta()
|
||||
: bundle_type_(StreamingMessageBundleType::Empty) {}
|
||||
|
||||
StreamingMessageBundle::StreamingMessageBundle(
|
||||
std::list<StreamingMessagePtr> &&message_list, uint64_t message_ts,
|
||||
uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type,
|
||||
uint32_t raw_data_size)
|
||||
: StreamingMessageBundleMeta(message_ts, last_offset_seq_id, message_list.size(),
|
||||
bundle_type),
|
||||
raw_bundle_size_(raw_data_size),
|
||||
message_list_(message_list) {
|
||||
if (bundle_type_ != StreamingMessageBundleType::Empty) {
|
||||
if (!raw_bundle_size_) {
|
||||
raw_bundle_size_ = std::accumulate(
|
||||
message_list_.begin(), message_list_.end(), 0,
|
||||
[](uint32_t x, StreamingMessagePtr &y) { return x + y->ClassBytesSize(); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamingMessageBundle::StreamingMessageBundle(
|
||||
std::list<StreamingMessagePtr> &message_list, uint64_t message_ts,
|
||||
uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type,
|
||||
uint32_t raw_data_size)
|
||||
: StreamingMessageBundle(std::list<StreamingMessagePtr>(message_list), message_ts,
|
||||
last_offset_seq_id, bundle_type, raw_data_size) {}
|
||||
|
||||
StreamingMessageBundle::StreamingMessageBundle(StreamingMessageBundle &bundle) {
|
||||
message_bundle_ts_ = bundle.message_bundle_ts_;
|
||||
message_list_size_ = bundle.message_list_size_;
|
||||
raw_bundle_size_ = bundle.raw_bundle_size_;
|
||||
bundle_type_ = bundle.bundle_type_;
|
||||
last_message_id_ = bundle.last_message_id_;
|
||||
message_list_ = bundle.message_list_;
|
||||
}
|
||||
|
||||
void StreamingMessageBundle::ToBytes(uint8_t *bytes) {
|
||||
uint32_t byte_offset = 0;
|
||||
StreamingMessageBundleMeta::ToBytes(bytes + byte_offset);
|
||||
|
||||
byte_offset += StreamingMessageBundleMeta::ClassBytesSize();
|
||||
|
||||
std::memcpy(bytes + byte_offset, reinterpret_cast<char *>(&raw_bundle_size_),
|
||||
sizeof(uint32_t));
|
||||
byte_offset += sizeof(uint32_t);
|
||||
|
||||
if (raw_bundle_size_ > 0) {
|
||||
ConvertMessageListToRawData(message_list_, raw_bundle_size_, bytes + byte_offset);
|
||||
}
|
||||
}
|
||||
|
||||
StreamingMessageBundlePtr StreamingMessageBundle::FromBytes(const uint8_t *bytes,
|
||||
bool verifer_check) {
|
||||
uint32_t byte_offset = 0;
|
||||
StreamingMessageBundleMetaPtr meta_ptr =
|
||||
StreamingMessageBundleMeta::FromBytes(bytes + byte_offset);
|
||||
byte_offset += meta_ptr->ClassBytesSize();
|
||||
|
||||
uint32_t raw_data_size = *reinterpret_cast<const uint32_t *>(bytes + byte_offset);
|
||||
byte_offset += sizeof(uint32_t);
|
||||
|
||||
std::list<StreamingMessagePtr> message_list;
|
||||
// only message bundle own raw data
|
||||
if (meta_ptr->GetBundleType() != StreamingMessageBundleType::Empty) {
|
||||
GetMessageListFromRawData(bytes + byte_offset, raw_data_size,
|
||||
meta_ptr->GetMessageListSize(), message_list);
|
||||
byte_offset += raw_data_size;
|
||||
}
|
||||
auto result = std::make_shared<StreamingMessageBundle>(
|
||||
message_list, meta_ptr->GetMessageBundleTs(), meta_ptr->GetLastMessageId(),
|
||||
meta_ptr->GetBundleType());
|
||||
STREAMING_CHECK(byte_offset == result->ClassBytesSize());
|
||||
return result;
|
||||
}
|
||||
|
||||
void StreamingMessageBundle::GetMessageListFromRawData(
|
||||
const uint8_t *bytes, uint32_t byte_size, uint32_t message_list_size,
|
||||
std::list<StreamingMessagePtr> &message_list) {
|
||||
uint32_t byte_offset = 0;
|
||||
// only message bundle own raw data
|
||||
for (size_t i = 0; i < message_list_size; ++i) {
|
||||
StreamingMessagePtr item = StreamingMessage::FromBytes(bytes + byte_offset);
|
||||
message_list.push_back(item);
|
||||
byte_offset += item->ClassBytesSize();
|
||||
}
|
||||
STREAMING_CHECK(byte_offset == byte_size);
|
||||
}
|
||||
|
||||
void StreamingMessageBundle::GetMessageList(
|
||||
std::list<StreamingMessagePtr> &message_list) {
|
||||
message_list = message_list_;
|
||||
}
|
||||
|
||||
void StreamingMessageBundle::ConvertMessageListToRawData(
|
||||
const std::list<StreamingMessagePtr> &message_list, uint32_t raw_data_size,
|
||||
uint8_t *raw_data) {
|
||||
uint32_t byte_offset = 0;
|
||||
for (auto &message : message_list) {
|
||||
message->ToBytes(raw_data + byte_offset);
|
||||
byte_offset += message->ClassBytesSize();
|
||||
}
|
||||
STREAMING_CHECK(byte_offset == raw_data_size);
|
||||
}
|
||||
|
||||
bool StreamingMessageBundle::operator==(StreamingMessageBundle &bundle) const {
|
||||
if (!(StreamingMessageBundleMeta::operator==(&bundle) &&
|
||||
this->GetRawBundleSize() == bundle.GetRawBundleSize() &&
|
||||
this->GetMessageListSize() == bundle.GetMessageListSize())) {
|
||||
return false;
|
||||
}
|
||||
auto it1 = message_list_.begin();
|
||||
auto it2 = bundle.message_list_.begin();
|
||||
while (it1 != message_list_.end() && it2 != bundle.message_list_.end()) {
|
||||
if (!((*it1).get()->operator==(*(*it2).get()))) {
|
||||
return false;
|
||||
}
|
||||
it1++;
|
||||
it2++;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StreamingMessageBundle::operator==(StreamingMessageBundle *bundle) const {
|
||||
return this->operator==(*bundle);
|
||||
}
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
164
streaming/src/message/message_bundle.h
Normal file
@ -0,0 +1,164 @@
#ifndef RAY_MESSAGE_BUNDLE_H
|
||||
#define RAY_MESSAGE_BUNDLE_H
|
||||
|
||||
#include <ctime>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
|
||||
#include "message.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
enum class StreamingMessageBundleType : uint32_t {
|
||||
Empty = 1,
|
||||
Barrier = 2,
|
||||
Bundle = 3,
|
||||
MIN = Empty,
|
||||
MAX = Bundle
|
||||
};
|
||||
|
||||
class StreamingMessageBundleMeta;
|
||||
class StreamingMessageBundle;
|
||||
|
||||
typedef std::shared_ptr<StreamingMessageBundle> StreamingMessageBundlePtr;
|
||||
typedef std::shared_ptr<StreamingMessageBundleMeta> StreamingMessageBundleMetaPtr;
|
||||
|
||||
constexpr uint32_t kMessageBundleMetaHeaderSize = sizeof(uint32_t) + sizeof(uint32_t) +
|
||||
sizeof(uint64_t) + sizeof(uint64_t) +
|
||||
sizeof(StreamingMessageBundleType);
|
||||
|
||||
constexpr uint32_t kMessageBundleHeaderSize =
|
||||
kMessageBundleMetaHeaderSize + sizeof(uint32_t);
|
||||
|
||||
class StreamingMessageBundleMeta {
|
||||
public:
|
||||
static const uint32_t StreamingMessageBundleMagicNum = 0xCAFEBABA;
|
||||
|
||||
protected:
|
||||
uint64_t message_bundle_ts_;
|
||||
|
||||
uint64_t last_message_id_;
|
||||
|
||||
uint32_t message_list_size_;
|
||||
|
||||
StreamingMessageBundleType bundle_type_;
|
||||
|
||||
public:
|
||||
explicit StreamingMessageBundleMeta(uint64_t, uint64_t, uint32_t,
|
||||
StreamingMessageBundleType);
|
||||
|
||||
explicit StreamingMessageBundleMeta(StreamingMessageBundleMeta *);
|
||||
|
||||
explicit StreamingMessageBundleMeta();
|
||||
|
||||
virtual ~StreamingMessageBundleMeta(){};
|
||||
|
||||
bool operator==(StreamingMessageBundleMeta &) const;
|
||||
|
||||
bool operator==(StreamingMessageBundleMeta *) const;
|
||||
|
||||
inline uint64_t GetMessageBundleTs() const { return message_bundle_ts_; }
|
||||
|
||||
inline uint64_t GetLastMessageId() const { return last_message_id_; }
|
||||
|
||||
inline uint32_t GetMessageListSize() const { return message_list_size_; }
|
||||
|
||||
inline StreamingMessageBundleType GetBundleType() const { return bundle_type_; }
|
||||
|
||||
inline bool IsBarrier() { return StreamingMessageBundleType::Barrier == bundle_type_; }
|
||||
inline bool IsBundle() { return StreamingMessageBundleType::Bundle == bundle_type_; }
|
||||
|
||||
virtual void ToBytes(uint8_t *data);
|
||||
static StreamingMessageBundleMetaPtr FromBytes(const uint8_t *data,
|
||||
bool verifer_check = true);
|
||||
inline virtual uint32_t ClassBytesSize() { return kMessageBundleMetaHeaderSize; }
|
||||
|
||||
std::string ToString() {
|
||||
return std::to_string(last_message_id_) + "," + std::to_string(message_list_size_) +
|
||||
"," + std::to_string(message_bundle_ts_) + "," +
|
||||
std::to_string(static_cast<uint32_t>(bundle_type_));
|
||||
}
|
||||
};
|
||||
|
||||
/// StreamingMessageBundle inherits from metadata class (StreamingMessageBundleMeta) with
|
||||
/// the following protocol:
|
||||
/// MagicNum = 0xcafebaba
|
||||
/// Timestamp 64bits timestamp (milliseconds from 1970)
|
||||
/// LastMessageId( the last id of bundle) (0,INF]
|
||||
/// MessageListSize(bundle len of message)
|
||||
/// BundleType(a. bundle = 3 , b. barrier =2, c. empty = 1)
|
||||
/// RawBundleSize(binary length of data)
|
||||
/// RawData ( binary data)
|
||||
///
|
||||
/// +--------------------+
|
||||
/// | MagicNum=U32 |
|
||||
/// +--------------------+
|
||||
/// | BundleTs=U64 |
|
||||
/// +--------------------+
|
||||
/// | LastMessageId=U64 |
|
||||
/// +--------------------+
|
||||
/// | MessageListSize=U32|
|
||||
/// +--------------------+
|
||||
/// | BundleType=U32 |
|
||||
/// +--------------------+
|
||||
/// | RawBundleSize=U32 |
|
||||
/// +--------------------+
|
||||
/// | RawData=var(N*Msg) |
|
||||
/// +--------------------+
|
||||
/// It should be noted that StreamingMessageBundle and StreamingMessageBundleMeta share
/// almost the same protocol except for the last two fields (RawBundleSize and RawData).
|
||||
class StreamingMessageBundle : public StreamingMessageBundleMeta {
|
||||
private:
|
||||
uint32_t raw_bundle_size_;
|
||||
|
||||
// Lazy serialization/deserialization.
|
||||
std::list<StreamingMessagePtr> message_list_;
|
||||
|
||||
public:
|
||||
explicit StreamingMessageBundle(std::list<StreamingMessagePtr> &&message_list,
|
||||
uint64_t bundle_ts, uint64_t offset,
|
||||
StreamingMessageBundleType bundle_type,
|
||||
uint32_t raw_data_size = 0);
|
||||
|
||||
// Makes a duplicated copy when an lvalue reference is passed to the constructor.
|
||||
explicit StreamingMessageBundle(std::list<StreamingMessagePtr> &message_list,
|
||||
uint64_t bundle_ts, uint64_t offset,
|
||||
StreamingMessageBundleType bundle_type,
|
||||
uint32_t raw_data_size = 0);
|
||||
|
||||
// Create an empty bundle by passing the last message id and timestamp.
|
||||
explicit StreamingMessageBundle(uint64_t, uint64_t);
|
||||
|
||||
explicit StreamingMessageBundle(StreamingMessageBundle &bundle);
|
||||
|
||||
virtual ~StreamingMessageBundle() = default;
|
||||
|
||||
inline uint32_t GetRawBundleSize() const { return raw_bundle_size_; }
|
||||
|
||||
bool operator==(StreamingMessageBundle &bundle) const;
|
||||
|
||||
bool operator==(StreamingMessageBundle *bundle_ptr) const;
|
||||
|
||||
void GetMessageList(std::list<StreamingMessagePtr> &message_list);
|
||||
|
||||
const std::list<StreamingMessagePtr> &GetMessageList() const { return message_list_; }
|
||||
|
||||
virtual void ToBytes(uint8_t *data);
|
||||
static StreamingMessageBundlePtr FromBytes(const uint8_t *data,
|
||||
bool verifer_check = true);
|
||||
inline virtual uint32_t ClassBytesSize() {
|
||||
return kMessageBundleHeaderSize + raw_bundle_size_;
|
||||
};
|
||||
|
||||
static void GetMessageListFromRawData(const uint8_t *bytes, uint32_t bytes_size,
|
||||
uint32_t message_list_size,
|
||||
std::list<StreamingMessagePtr> &message_list);
|
||||
static void ConvertMessageListToRawData(
|
||||
const std::list<StreamingMessagePtr> &message_list, uint32_t raw_data_size,
|
||||
uint8_t *raw_data);
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_MESSAGE_BUNDLE_H
|
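A round-trip sketch matching the bundle layout documented above (illustration only; current_time_ms() is the streaming util helper used elsewhere in this patch):

// Sketch only: bundle one message, serialize it, and parse it back.
std::list<StreamingMessagePtr> messages;
uint8_t payload[] = {1, 2, 3, 4};
messages.push_back(std::make_shared<StreamingMessage>(
    payload, sizeof(payload), /*seq_id=*/1, StreamingMessageType::Message));

StreamingMessageBundle bundle(std::move(messages), current_time_ms(),
                              /*last_offset_seq_id=*/1,
                              StreamingMessageBundleType::Bundle);

std::vector<uint8_t> wire(bundle.ClassBytesSize());
bundle.ToBytes(wire.data());

StreamingMessageBundlePtr decoded = StreamingMessageBundle::FromBytes(wire.data());
STREAMING_CHECK(*decoded == bundle);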
53
streaming/src/message/priority_queue.h
Normal file
@ -0,0 +1,53 @@
#ifndef RAY_PRIORITY_QUEUE_H
|
||||
#define RAY_PRIORITY_QUEUE_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
template <class T, class C>
|
||||
|
||||
class PriorityQueue {
|
||||
private:
|
||||
std::vector<T> merge_vec_;
|
||||
C comparator_;
|
||||
|
||||
public:
|
||||
PriorityQueue(C &comparator) : comparator_(comparator){};
|
||||
|
||||
inline void push(T &&item) {
|
||||
merge_vec_.push_back(std::forward<T>(item));
|
||||
std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
|
||||
}
|
||||
|
||||
inline void push(const T &item) {
|
||||
merge_vec_.push_back(item);
|
||||
std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
|
||||
}
|
||||
|
||||
inline void pop() {
|
||||
STREAMING_CHECK(!isEmpty());
|
||||
std::pop_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
|
||||
merge_vec_.pop_back();
|
||||
}
|
||||
|
||||
inline void makeHeap() {
|
||||
std::make_heap(merge_vec_.begin(), merge_vec_.end(), comparator_);
|
||||
}
|
||||
|
||||
inline T &top() { return merge_vec_.front(); }
|
||||
|
||||
inline uint32_t size() { return merge_vec_.size(); }
|
||||
|
||||
inline bool isEmpty() { return merge_vec_.empty(); }
|
||||
|
||||
std::vector<T> &getRawVector() { return merge_vec_; }
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_PRIORITY_QUEUE_H
|
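A small usage sketch of the PriorityQueue template above (illustration only). The comparator follows std::push_heap semantics, so a greater-than comparator keeps the smallest element at top():

// Sketch only: a min-heap of integers built on PriorityQueue.
struct GreaterThan {
  bool operator()(const int &a, const int &b) const { return a > b; }
};

GreaterThan cmp;
PriorityQueue<int, GreaterThan> pq(cmp);
pq.push(3);
pq.push(1);
pq.push(2);

int smallest = pq.top();  // 1
pq.pop();                 // removes 1; 2 becomes the new top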
23
streaming/src/protobuf/streaming.proto
Normal file
@ -0,0 +1,23 @@
syntax = "proto3";
|
||||
|
||||
package ray.streaming.proto;
|
||||
|
||||
option java_package = "org.ray.streaming.runtime.generated";
|
||||
|
||||
enum OperatorType {
|
||||
UNKNOWN = 0;
|
||||
TRANSFORM = 1;
|
||||
SOURCE = 2;
|
||||
SINK = 3;
|
||||
}
|
||||
|
||||
// all string in this message is ASCII string
|
||||
message StreamingConfig {
|
||||
string job_name = 1;
|
||||
string task_job_id = 2;
|
||||
string worker_name = 3;
|
||||
string op_name = 4;
|
||||
OperatorType role = 5;
|
||||
uint32 ring_buffer_capacity = 6;
|
||||
uint32 empty_message_interval = 7;
|
||||
}
|
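A sketch of filling the generated StreamingConfig message from C++ (illustration only; field values are placeholders and the accessor names follow standard protobuf conventions):

// Sketch only: build and serialize a StreamingConfig.
ray::streaming::proto::StreamingConfig config;
config.set_job_name("demo_job");
config.set_op_name("demo_source");
config.set_role(ray::streaming::proto::SOURCE);
config.set_ring_buffer_capacity(1024);
config.set_empty_message_interval(20);

std::string serialized;
config.SerializeToString(&serialized);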
70
streaming/src/protobuf/streaming_queue.proto
Normal file
@ -0,0 +1,70 @@
syntax = "proto3";
|
||||
|
||||
package ray.streaming.queue.protobuf;
|
||||
|
||||
enum StreamingQueueMessageType {
|
||||
StreamingQueueDataMsgType = 0;
|
||||
StreamingQueueCheckMsgType = 1;
|
||||
StreamingQueueCheckRspMsgType = 2;
|
||||
StreamingQueueNotificationMsgType = 3;
|
||||
StreamingQueueTestInitMsgType = 4;
|
||||
StreamingQueueTestCheckStatusRspMsgType = 5;
|
||||
}
|
||||
|
||||
enum StreamingQueueError {
|
||||
OK = 0;
|
||||
QUEUE_NOT_EXIST = 1;
|
||||
NO_VALID_DATA_TO_PULL = 2;
|
||||
}
|
||||
|
||||
message StreamingQueueDataMsg {
|
||||
bytes src_actor_id = 1;
|
||||
bytes dst_actor_id = 2;
|
||||
bytes queue_id = 3;
|
||||
uint64 seq_id = 4;
|
||||
uint64 length = 5;
|
||||
bool raw = 6;
|
||||
}
|
||||
|
||||
message StreamingQueueCheckMsg {
|
||||
bytes src_actor_id = 1;
|
||||
bytes dst_actor_id = 2;
|
||||
bytes queue_id = 3;
|
||||
}
|
||||
|
||||
message StreamingQueueCheckRspMsg {
|
||||
bytes src_actor_id = 1;
|
||||
bytes dst_actor_id = 2;
|
||||
bytes queue_id = 3;
|
||||
StreamingQueueError err_code = 4;
|
||||
}
|
||||
|
||||
message StreamingQueueNotificationMsg {
|
||||
bytes src_actor_id = 1;
|
||||
bytes dst_actor_id = 2;
|
||||
bytes queue_id = 3;
|
||||
uint64 seq_id = 4;
|
||||
}
|
||||
|
||||
// for test
|
||||
enum StreamingQueueTestRole {
|
||||
WRITER = 0;
|
||||
READER = 1;
|
||||
}
|
||||
|
||||
message StreamingQueueTestInitMsg {
|
||||
StreamingQueueTestRole role = 1;
|
||||
bytes src_actor_id = 2;
|
||||
bytes dst_actor_id = 3;
|
||||
bytes actor_handle = 4;
|
||||
repeated bytes queue_ids = 5;
|
||||
repeated bytes rescale_queue_ids = 6;
|
||||
string test_suite_name = 7;
|
||||
string test_name = 8;
|
||||
uint64 param = 9;
|
||||
}
|
||||
|
||||
message StreamingQueueTestCheckStatusRspMsg {
|
||||
string test_name = 1;
|
||||
bool status = 2;
|
||||
}
|
240
streaming/src/queue/message.cc
Normal file
@ -0,0 +1,240 @@
#include "message.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
const uint32_t Message::MagicNum = 0xBABA0510;
|
||||
|
||||
std::unique_ptr<LocalMemoryBuffer> Message::ToBytes() {
|
||||
uint8_t *bytes = nullptr;
|
||||
|
||||
std::string pboutput;
|
||||
ToProtobuf(&pboutput);
|
||||
int64_t fbs_length = pboutput.length();
|
||||
|
||||
queue::protobuf::StreamingQueueMessageType type = Type();
|
||||
size_t total_len =
|
||||
sizeof(Message::MagicNum) + sizeof(type) + sizeof(fbs_length) + fbs_length;
|
||||
if (buffer_ != nullptr) {
|
||||
total_len += buffer_->Size();
|
||||
}
|
||||
bytes = new uint8_t[total_len];
|
||||
STREAMING_CHECK(bytes != nullptr) << "allocate bytes fail.";
|
||||
|
||||
uint8_t *p_cur = bytes;
|
||||
memcpy(p_cur, &Message::MagicNum, sizeof(Message::MagicNum));
|
||||
|
||||
p_cur += sizeof(Message::MagicNum);
|
||||
memcpy(p_cur, &type, sizeof(type));
|
||||
|
||||
p_cur += sizeof(type);
|
||||
memcpy(p_cur, &fbs_length, sizeof(fbs_length));
|
||||
|
||||
p_cur += sizeof(fbs_length);
|
||||
uint8_t *fbs_bytes = (uint8_t *)pboutput.data();
|
||||
memcpy(p_cur, fbs_bytes, fbs_length);
|
||||
p_cur += fbs_length;
|
||||
|
||||
if (buffer_ != nullptr) {
|
||||
memcpy(p_cur, buffer_->Data(), buffer_->Size());
|
||||
}
|
||||
|
||||
// COPY
|
||||
std::unique_ptr<LocalMemoryBuffer> buffer =
|
||||
std::unique_ptr<LocalMemoryBuffer>(new LocalMemoryBuffer(bytes, total_len, true));
|
||||
delete[] bytes;
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void DataMessage::ToProtobuf(std::string *output) {
|
||||
queue::protobuf::StreamingQueueDataMsg msg;
|
||||
msg.set_src_actor_id(actor_id_.Binary());
|
||||
msg.set_dst_actor_id(peer_actor_id_.Binary());
|
||||
msg.set_queue_id(queue_id_.Binary());
|
||||
msg.set_seq_id(seq_id_);
|
||||
msg.set_length(buffer_->Size());
|
||||
msg.set_raw(raw_);
|
||||
msg.SerializeToString(output);
|
||||
}
|
||||
|
||||
std::shared_ptr<DataMessage> DataMessage::FromBytes(uint8_t *bytes) {
|
||||
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
|
||||
uint64_t *fbs_length = (uint64_t *)bytes;
|
||||
bytes += sizeof(uint64_t);
|
||||
|
||||
std::string inputpb(reinterpret_cast<char const *>(bytes), *fbs_length);
|
||||
queue::protobuf::StreamingQueueDataMsg message;
|
||||
message.ParseFromString(inputpb);
|
||||
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
|
||||
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
|
||||
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
|
||||
uint64_t seq_id = message.seq_id();
|
||||
uint64_t length = message.length();
|
||||
bool raw = message.raw();
|
||||
bytes += *fbs_length;
|
||||
|
||||
/// Copy data and create a new buffer for streaming queue.
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(bytes, (size_t)length, true);
|
||||
std::shared_ptr<DataMessage> data_msg = std::make_shared<DataMessage>(
|
||||
src_actor_id, dst_actor_id, queue_id, seq_id, buffer, raw);
|
||||
|
||||
return data_msg;
|
||||
}
|
||||
|
||||
void NotificationMessage::ToProtobuf(std::string *output) {
|
||||
queue::protobuf::StreamingQueueNotificationMsg msg;
|
||||
msg.set_src_actor_id(actor_id_.Binary());
|
||||
msg.set_dst_actor_id(peer_actor_id_.Binary());
|
||||
msg.set_queue_id(queue_id_.Binary());
|
||||
msg.set_seq_id(seq_id_);
|
||||
msg.SerializeToString(output);
|
||||
}
|
||||
|
||||
std::shared_ptr<NotificationMessage> NotificationMessage::FromBytes(uint8_t *bytes) {
|
||||
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
|
||||
uint64_t *length = (uint64_t *)bytes;
|
||||
bytes += sizeof(uint64_t);
|
||||
|
||||
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
|
||||
queue::protobuf::StreamingQueueNotificationMsg message;
|
||||
message.ParseFromString(inputpb);
|
||||
STREAMING_LOG(INFO) << "message.src_actor_id: " << message.src_actor_id();
|
||||
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
|
||||
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
|
||||
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
|
||||
uint64_t seq_id = message.seq_id();
|
||||
|
||||
std::shared_ptr<NotificationMessage> notify_msg =
|
||||
std::make_shared<NotificationMessage>(src_actor_id, dst_actor_id, queue_id, seq_id);
|
||||
|
||||
return notify_msg;
|
||||
}
|
||||
|
||||
void CheckMessage::ToProtobuf(std::string *output) {
|
||||
queue::protobuf::StreamingQueueCheckMsg msg;
|
||||
msg.set_src_actor_id(actor_id_.Binary());
|
||||
msg.set_dst_actor_id(peer_actor_id_.Binary());
|
||||
msg.set_queue_id(queue_id_.Binary());
|
||||
msg.SerializeToString(output);
|
||||
}
|
||||
|
||||
std::shared_ptr<CheckMessage> CheckMessage::FromBytes(uint8_t *bytes) {
|
||||
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
|
||||
uint64_t *length = (uint64_t *)bytes;
|
||||
bytes += sizeof(uint64_t);
|
||||
|
||||
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
|
||||
queue::protobuf::StreamingQueueCheckMsg message;
|
||||
message.ParseFromString(inputpb);
|
||||
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
|
||||
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
|
||||
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
|
||||
|
||||
std::shared_ptr<CheckMessage> check_msg =
|
||||
std::make_shared<CheckMessage>(src_actor_id, dst_actor_id, queue_id);
|
||||
|
||||
return check_msg;
|
||||
}
|
||||
|
||||
void CheckRspMessage::ToProtobuf(std::string *output) {
|
||||
queue::protobuf::StreamingQueueCheckRspMsg msg;
|
||||
msg.set_src_actor_id(actor_id_.Binary());
|
||||
msg.set_dst_actor_id(peer_actor_id_.Binary());
|
||||
msg.set_queue_id(queue_id_.Binary());
|
||||
msg.set_err_code(err_code_);
|
||||
msg.SerializeToString(output);
|
||||
}
|
||||
|
||||
std::shared_ptr<CheckRspMessage> CheckRspMessage::FromBytes(uint8_t *bytes) {
|
||||
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
|
||||
uint64_t *length = (uint64_t *)bytes;
|
||||
bytes += sizeof(uint64_t);
|
||||
|
||||
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
|
||||
queue::protobuf::StreamingQueueCheckRspMsg message;
|
||||
message.ParseFromString(inputpb);
|
||||
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
|
||||
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
|
||||
ObjectID queue_id = ObjectID::FromBinary(message.queue_id());
|
||||
queue::protobuf::StreamingQueueError err_code = message.err_code();
|
||||
|
||||
std::shared_ptr<CheckRspMessage> check_rsp_msg =
|
||||
std::make_shared<CheckRspMessage>(src_actor_id, dst_actor_id, queue_id, err_code);
|
||||
|
||||
return check_rsp_msg;
|
||||
}
|
||||
|
||||
void TestInitMessage::ToProtobuf(std::string *output) {
|
||||
queue::protobuf::StreamingQueueTestInitMsg msg;
|
||||
msg.set_role(role_);
|
||||
msg.set_src_actor_id(actor_id_.Binary());
|
||||
msg.set_dst_actor_id(peer_actor_id_.Binary());
|
||||
msg.set_actor_handle(actor_handle_serialized_);
|
||||
for (auto &queue_id : queue_ids_) {
|
||||
msg.add_queue_ids(queue_id.Binary());
|
||||
}
|
||||
for (auto &queue_id : rescale_queue_ids_) {
|
||||
msg.add_rescale_queue_ids(queue_id.Binary());
|
||||
}
|
||||
msg.set_test_suite_name(test_suite_name_);
|
||||
msg.set_test_name(test_name_);
|
||||
msg.set_param(param_);
|
||||
msg.SerializeToString(output);
|
||||
}
|
||||
|
||||
std::shared_ptr<TestInitMessage> TestInitMessage::FromBytes(uint8_t *bytes) {
|
||||
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
|
||||
uint64_t *length = (uint64_t *)bytes;
|
||||
bytes += sizeof(uint64_t);
|
||||
|
||||
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
|
||||
queue::protobuf::StreamingQueueTestInitMsg message;
|
||||
message.ParseFromString(inputpb);
|
||||
queue::protobuf::StreamingQueueTestRole role = message.role();
|
||||
ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id());
|
||||
ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id());
|
||||
std::string actor_handle_serialized = message.actor_handle();
|
||||
std::vector<ObjectID> queue_ids;
|
||||
for (int i = 0; i < message.queue_ids_size(); i++) {
|
||||
queue_ids.push_back(ObjectID::FromBinary(message.queue_ids(i)));
|
||||
}
|
||||
std::vector<ObjectID> rescale_queue_ids;
|
||||
for (int i = 0; i < message.rescale_queue_ids_size(); i++) {
|
||||
rescale_queue_ids.push_back(ObjectID::FromBinary(message.rescale_queue_ids(i)));
|
||||
}
|
||||
std::string test_suite_name = message.test_suite_name();
|
||||
std::string test_name = message.test_name();
|
||||
uint64_t param = message.param();
|
||||
|
||||
std::shared_ptr<TestInitMessage> test_init_msg = std::make_shared<TestInitMessage>(
|
||||
role, src_actor_id, dst_actor_id, actor_handle_serialized, queue_ids,
|
||||
rescale_queue_ids, test_suite_name, test_name, param);
|
||||
|
||||
return test_init_msg;
|
||||
}
|
||||
|
||||
void TestCheckStatusRspMsg::ToProtobuf(std::string *output) {
|
||||
queue::protobuf::StreamingQueueTestCheckStatusRspMsg msg;
|
||||
msg.set_test_name(test_name_);
|
||||
msg.set_status(status_);
|
||||
msg.SerializeToString(output);
|
||||
}
|
||||
|
||||
std::shared_ptr<TestCheckStatusRspMsg> TestCheckStatusRspMsg::FromBytes(uint8_t *bytes) {
|
||||
bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType);
|
||||
uint64_t *length = (uint64_t *)bytes;
|
||||
bytes += sizeof(uint64_t);
|
||||
|
||||
std::string inputpb(reinterpret_cast<char const *>(bytes), *length);
|
||||
queue::protobuf::StreamingQueueTestCheckStatusRspMsg message;
|
||||
message.ParseFromString(inputpb);
|
||||
std::string test_name = message.test_name();
|
||||
bool status = message.status();
|
||||
|
||||
std::shared_ptr<TestCheckStatusRspMsg> test_check_msg =
|
||||
std::make_shared<TestCheckStatusRspMsg>(test_name, status);
|
||||
|
||||
return test_check_msg;
|
||||
}
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
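A dispatch sketch for the wire layout produced by Message::ToBytes above (MagicNum | type | length | protobuf bytes | optional payload); illustration only, not part of this diff:

// Sketch only: peek at the common header and route to the right FromBytes().
void DispatchQueueMessage(uint8_t *bytes) {
  uint32_t magic = *reinterpret_cast<uint32_t *>(bytes);
  STREAMING_CHECK(magic == Message::MagicNum);

  auto type = *reinterpret_cast<queue::protobuf::StreamingQueueMessageType *>(
      bytes + sizeof(uint32_t));
  switch (type) {
  case queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType:
    DataMessage::FromBytes(bytes);  // FromBytes skips the common header itself
    break;
  case queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType:
    NotificationMessage::FromBytes(bytes);
    break;
  case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType:
    CheckMessage::FromBytes(bytes);
    break;
  default:
    break;
  }
}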
235
streaming/src/queue/message.h
Normal file
@ -0,0 +1,235 @@
#ifndef _STREAMING_QUEUE_MESSAGE_H_
|
||||
#define _STREAMING_QUEUE_MESSAGE_H_
|
||||
|
||||
#include "protobuf/streaming_queue.pb.h"
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// Base class of all message classes.
|
||||
/// All payloads transferred through direct actor calls are packed into a unified
/// package, consisting of protobuf-formatted metadata and optional data; this covers
/// both data and control messages. These message classes wrap the messages defined
/// in protobuf/streaming_queue.proto respectively.
|
||||
class Message {
|
||||
public:
|
||||
/// Construct a Message instance.
|
||||
/// \param[in] actor_id ActorID of message sender.
|
||||
/// \param[in] peer_actor_id ActorID of message receiver.
|
||||
/// \param[in] queue_id queue id to identify which queue the message is sent to.
|
||||
/// \param[in] buffer an optional param, a chunk of data to send.
|
||||
Message(const ActorID &actor_id, const ActorID &peer_actor_id, const ObjectID &queue_id,
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer = nullptr)
|
||||
: actor_id_(actor_id),
|
||||
peer_actor_id_(peer_actor_id),
|
||||
queue_id_(queue_id),
|
||||
buffer_(buffer) {}
|
||||
Message() {}
|
||||
virtual ~Message() {}
|
||||
ActorID ActorId() { return actor_id_; }
|
||||
ActorID PeerActorId() { return peer_actor_id_; }
|
||||
ObjectID QueueId() { return queue_id_; }
|
||||
std::shared_ptr<LocalMemoryBuffer> Buffer() { return buffer_; }
|
||||
|
||||
/// Serialize all metadata and data to a LocalMemoryBuffer, which can be sent through
/// a direct actor call.
/// \return serialized buffer.
|
||||
std::unique_ptr<LocalMemoryBuffer> ToBytes();
|
||||
|
||||
/// Get message type.
|
||||
/// \return message type.
|
||||
virtual queue::protobuf::StreamingQueueMessageType Type() = 0;
|
||||
|
||||
/// All subclasses should implement `ToProtobuf` to serialize its own protobuf data.
|
||||
virtual void ToProtobuf(std::string *output) = 0;
|
||||
|
||||
protected:
|
||||
ActorID actor_id_;
|
||||
ActorID peer_actor_id_;
|
||||
ObjectID queue_id_;
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer_;
|
||||
|
||||
public:
|
||||
/// A magic number to identify a valid message.
|
||||
static const uint32_t MagicNum;
|
||||
};
|
||||
|
||||
/// Wrap StreamingQueueDataMsg in streaming_queue.proto.
|
||||
/// DataMessage encapsulates the memory buffer of QueueItem, a one-to-one relationship
|
||||
/// exists between DataMessage and QueueItem.
|
||||
class DataMessage : public Message {
|
||||
public:
|
||||
DataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id,
|
||||
uint64_t seq_id, std::shared_ptr<LocalMemoryBuffer> buffer, bool raw)
|
||||
: Message(actor_id, peer_actor_id, queue_id, buffer), seq_id_(seq_id), raw_(raw) {}
|
||||
virtual ~DataMessage() {}
|
||||
|
||||
static std::shared_ptr<DataMessage> FromBytes(uint8_t *bytes);
|
||||
virtual void ToProtobuf(std::string *output);
|
||||
uint64_t SeqId() { return seq_id_; }
|
||||
bool IsRaw() { return raw_; }
|
||||
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
|
||||
|
||||
private:
|
||||
uint64_t seq_id_;
|
||||
bool raw_;
|
||||
|
||||
const queue::protobuf::StreamingQueueMessageType type_ =
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType;
|
||||
};
|
||||
|
||||
/// Wrap StreamingQueueNotificationMsg in streaming_queue.proto.
|
||||
/// NotificationMessage is sent from downstream queues to upstream queues, for the data
/// reader to inform the data writer of the consumed offset.
|
||||
class NotificationMessage : public Message {
|
||||
public:
|
||||
NotificationMessage(const ActorID &actor_id, const ActorID &peer_actor_id,
|
||||
const ObjectID &queue_id, uint64_t seq_id)
|
||||
: Message(actor_id, peer_actor_id, queue_id), seq_id_(seq_id) {}
|
||||
|
||||
virtual ~NotificationMessage() {}
|
||||
|
||||
static std::shared_ptr<NotificationMessage> FromBytes(uint8_t *bytes);
|
||||
virtual void ToProtobuf(std::string *output);
|
||||
|
||||
uint64_t SeqId() { return seq_id_; }
|
||||
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
|
||||
|
||||
private:
|
||||
uint64_t seq_id_;
|
||||
const queue::protobuf::StreamingQueueMessageType type_ =
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType;
|
||||
};
|
||||
|
||||
/// Wrap StreamingQueueCheckMsg in streaming_queue.proto.
|
||||
/// CheckMessage is sent from upstream queues to downstream queues, for the data writer
/// to check whether the corresponding downstream queue is ready or not.
|
||||
class CheckMessage : public Message {
|
||||
public:
|
||||
CheckMessage(const ActorID &actor_id, const ActorID &peer_actor_id,
|
||||
const ObjectID &queue_id)
|
||||
: Message(actor_id, peer_actor_id, queue_id) {}
|
||||
virtual ~CheckMessage() {}
|
||||
|
||||
static std::shared_ptr<CheckMessage> FromBytes(uint8_t *bytes);
|
||||
virtual void ToProtobuf(std::string *output);
|
||||
|
||||
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
|
||||
|
||||
private:
|
||||
const queue::protobuf::StreamingQueueMessageType type_ =
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType;
|
||||
};
|
||||
|
||||
/// Wrap StreamingQueueCheckRspMsg in streaming_queue.proto.
|
||||
/// CheckRspMessage is sent from downstream queues to upstream queues as the response to
/// CheckMessage, to indicate whether the downstream queue is ready or not.
|
||||
class CheckRspMessage : public Message {
|
||||
public:
|
||||
CheckRspMessage(const ActorID &actor_id, const ActorID &peer_actor_id,
|
||||
const ObjectID &queue_id, queue::protobuf::StreamingQueueError err_code)
|
||||
: Message(actor_id, peer_actor_id, queue_id), err_code_(err_code) {}
|
||||
virtual ~CheckRspMessage() {}
|
||||
|
||||
static std::shared_ptr<CheckRspMessage> FromBytes(uint8_t *bytes);
|
||||
virtual void ToProtobuf(std::string *output);
|
||||
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
|
||||
queue::protobuf::StreamingQueueError Error() { return err_code_; }
|
||||
|
||||
private:
|
||||
queue::protobuf::StreamingQueueError err_code_;
|
||||
const queue::protobuf::StreamingQueueMessageType type_ =
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType;
|
||||
};
|
||||
|
||||
/// Wrap StreamingQueueTestInitMsg in streaming_queue.proto.
|
||||
/// TestInitMessage, used for test, driver sends to test workers to init test suite.
|
||||
class TestInitMessage : public Message {
|
||||
public:
|
||||
TestInitMessage(const queue::protobuf::StreamingQueueTestRole role,
|
||||
const ActorID &actor_id, const ActorID &peer_actor_id,
|
||||
const std::string actor_handle_serialized,
|
||||
const std::vector<ObjectID> &queue_ids,
|
||||
const std::vector<ObjectID> &rescale_queue_ids,
|
||||
std::string test_suite_name, std::string test_name, uint64_t param)
|
||||
: Message(actor_id, peer_actor_id, queue_ids[0]),
|
||||
actor_handle_serialized_(actor_handle_serialized),
|
||||
queue_ids_(queue_ids),
|
||||
rescale_queue_ids_(rescale_queue_ids),
|
||||
role_(role),
|
||||
test_suite_name_(test_suite_name),
|
||||
test_name_(test_name),
|
||||
param_(param) {}
|
||||
virtual ~TestInitMessage() {}
|
||||
|
||||
static std::shared_ptr<TestInitMessage> FromBytes(uint8_t *bytes);
|
||||
virtual void ToProtobuf(std::string *output);
|
||||
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
|
||||
std::string ActorHandleSerialized() { return actor_handle_serialized_; }
|
||||
queue::protobuf::StreamingQueueTestRole Role() { return role_; }
|
||||
std::vector<ObjectID> QueueIds() { return queue_ids_; }
|
||||
std::vector<ObjectID> RescaleQueueIds() { return rescale_queue_ids_; }
|
||||
std::string TestSuiteName() { return test_suite_name_; }
|
||||
std::string TestName() { return test_name_; }
|
||||
uint64_t Param() { return param_; }
|
||||
|
||||
std::string ToString() {
|
||||
std::ostringstream os;
|
||||
os << "actor_handle_serialized: " << actor_handle_serialized_;
|
||||
os << " actor_id: " << ActorId();
|
||||
os << " peer_actor_id: " << PeerActorId();
|
||||
os << " queue_ids:[";
|
||||
for (auto &qid : queue_ids_) {
|
||||
os << qid << ",";
|
||||
}
|
||||
os << "], rescale_queue_ids:[";
|
||||
for (auto &qid : rescale_queue_ids_) {
|
||||
os << qid << ",";
|
||||
}
|
||||
os << "],";
|
||||
os << " role:" << queue::protobuf::StreamingQueueTestRole_Name(role_);
|
||||
os << " suite_name: " << test_suite_name_;
|
||||
os << " test_name: " << test_name_;
|
||||
os << " param: " << param_;
|
||||
return os.str();
|
||||
}
|
||||
|
||||
private:
|
||||
const queue::protobuf::StreamingQueueMessageType type_ =
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType;
|
||||
std::string actor_handle_serialized_;
|
||||
std::vector<ObjectID> queue_ids_;
|
||||
std::vector<ObjectID> rescale_queue_ids_;
|
||||
queue::protobuf::StreamingQueueTestRole role_;
|
||||
std::string test_suite_name_;
|
||||
std::string test_name_;
|
||||
uint64_t param_;
|
||||
};
|
||||
|
||||
/// Wrap StreamingQueueTestCheckStatusRspMsg in streaming_queue.proto.
|
||||
/// TestCheckStatusRspMsg, used for test, driver sends to test workers to check
|
||||
/// whether test has completed or failed.
|
||||
class TestCheckStatusRspMsg : public Message {
|
||||
public:
|
||||
TestCheckStatusRspMsg(const std::string test_name, bool status)
|
||||
: test_name_(test_name), status_(status) {}
|
||||
virtual ~TestCheckStatusRspMsg() {}
|
||||
|
||||
static std::shared_ptr<TestCheckStatusRspMsg> FromBytes(uint8_t *bytes);
|
||||
virtual void ToProtobuf(std::string *output);
|
||||
queue::protobuf::StreamingQueueMessageType Type() { return type_; }
|
||||
std::string TestName() { return test_name_; }
|
||||
bool Status() { return status_; }
|
||||
|
||||
private:
|
||||
const queue::protobuf::StreamingQueueMessageType type_ =
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestCheckStatusRspMsgType;
|
||||
std::string test_name_;
|
||||
bool status_;
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
|
211
streaming/src/queue/queue.cc
Normal file
@ -0,0 +1,211 @@
#include "queue.h"
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "queue_handler.h"
|
||||
#include "util/streaming_util.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
bool Queue::Push(QueueItem item) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (max_data_size_ < item.DataSize() + data_size_) return false;
|
||||
|
||||
buffer_queue_.push_back(item);
|
||||
data_size_ += item.DataSize();
|
||||
readable_cv_.notify_one();
|
||||
return true;
|
||||
}
|
||||
|
||||
QueueItem Queue::FrontProcessed() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail";
|
||||
|
||||
if (watershed_iter_ == buffer_queue_.begin()) {
|
||||
return InvalidQueueItem();
|
||||
}
|
||||
|
||||
QueueItem item = buffer_queue_.front();
|
||||
return item;
|
||||
}
|
||||
|
||||
QueueItem Queue::PopProcessed() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail";
|
||||
|
||||
if (watershed_iter_ == buffer_queue_.begin()) {
|
||||
return InvalidQueueItem();
|
||||
}
|
||||
|
||||
QueueItem item = buffer_queue_.front();
|
||||
buffer_queue_.pop_front();
|
||||
data_size_ -= item.DataSize();
|
||||
data_size_sent_ -= item.DataSize();
|
||||
return item;
|
||||
}
|
||||
|
||||
QueueItem Queue::PopPending() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto it = std::next(watershed_iter_);
|
||||
QueueItem item = *it;
|
||||
data_size_sent_ += it->DataSize();
|
||||
buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it));
|
||||
return item;
|
||||
}
|
||||
|
||||
QueueItem Queue::PopPendingBlockTimeout(uint64_t timeout_us) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
std::chrono::system_clock::time_point point =
|
||||
std::chrono::system_clock::now() + std::chrono::microseconds(timeout_us);
|
||||
if (readable_cv_.wait_until(lock, point, [this] {
|
||||
return std::next(watershed_iter_) != buffer_queue_.end();
|
||||
})) {
|
||||
auto it = std::next(watershed_iter_);
|
||||
QueueItem item = *it;
|
||||
data_size_sent_ += it->DataSize();
|
||||
buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it));
|
||||
return item;
|
||||
|
||||
} else {
|
||||
uint8_t data[1];
|
||||
return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true);
|
||||
}
|
||||
}
|
||||
|
||||
QueueItem Queue::BackPending() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (std::next(watershed_iter_) == buffer_queue_.end()) {
|
||||
uint8_t data[1];
|
||||
return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true);
|
||||
}
|
||||
return buffer_queue_.back();
|
||||
}
|
||||
|
||||
bool Queue::IsPendingEmpty() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
return std::next(watershed_iter_) == buffer_queue_.end();
|
||||
}
|
||||
|
||||
bool Queue::IsPendingFull(uint64_t data_size) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
return max_data_size_ < data_size + data_size_;
|
||||
}
|
||||
|
||||
size_t Queue::ProcessedCount() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (watershed_iter_ == buffer_queue_.begin()) return 0;
|
||||
|
||||
auto begin = buffer_queue_.begin();
|
||||
auto end = std::prev(watershed_iter_);
|
||||
|
||||
return end->SeqId() + 1 - begin->SeqId();
|
||||
}
|
||||
|
||||
size_t Queue::PendingCount() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (std::next(watershed_iter_) == buffer_queue_.end()) return 0;
|
||||
|
||||
auto begin = std::next(watershed_iter_);
|
||||
auto end = std::prev(buffer_queue_.end());
|
||||
|
||||
return end->SeqId() + 1 - begin->SeqId();
|
||||
}
|
||||
|
||||
Status WriterQueue::Push(uint64_t seq_id, uint8_t *data, uint32_t data_size,
|
||||
uint64_t timestamp, bool raw) {
|
||||
if (IsPendingFull(data_size)) {
|
||||
return Status::OutOfMemory("Queue Push OutOfMemory");
|
||||
}
|
||||
|
||||
while (is_pulling_) {
|
||||
STREAMING_LOG(INFO) << "This queue is sending pull data, wait.";
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
|
||||
QueueItem item(seq_id, data, data_size, timestamp, raw);
|
||||
Queue::Push(item);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void WriterQueue::Send() {
|
||||
while (!IsPendingEmpty()) {
|
||||
// FIXME: front -> send -> pop
|
||||
QueueItem item = PopPending();
|
||||
DataMessage msg(actor_id_, peer_actor_id_, queue_id_, item.SeqId(), item.Buffer(),
|
||||
item.IsRaw());
|
||||
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
|
||||
STREAMING_CHECK(transport_ != nullptr);
|
||||
transport_->Send(std::move(buffer),
|
||||
DownstreamQueueMessageHandler::peer_async_function_);
|
||||
}
|
||||
}
|
||||
|
||||
Status WriterQueue::TryEvictItems() {
|
||||
STREAMING_LOG(INFO) << "TryEvictItems";
|
||||
QueueItem item = FrontProcessed();
|
||||
uint64_t first_seq_id = item.SeqId();
|
||||
STREAMING_LOG(INFO) << "TryEvictItems first_seq_id: " << first_seq_id
|
||||
<< " min_consumed_id_: " << min_consumed_id_
|
||||
<< " eviction_limit_: " << eviction_limit_;
|
||||
if (min_consumed_id_ == QUEUE_INVALID_SEQ_ID || first_seq_id > min_consumed_id_) {
|
||||
return Status::OutOfMemory("The queue is full and some reader doesn't consume");
|
||||
}
|
||||
|
||||
if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || first_seq_id > eviction_limit_) {
|
||||
return Status::OutOfMemory("The queue is full and eviction limit block evict");
|
||||
}
|
||||
|
||||
uint64_t evict_target_seq_id = std::min(min_consumed_id_, eviction_limit_);
|
||||
|
||||
while (item.SeqId() <= evict_target_seq_id) {
|
||||
PopProcessed();
|
||||
STREAMING_LOG(INFO) << "TryEvictItems directly " << item.SeqId();
|
||||
item = FrontProcessed();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void WriterQueue::OnNotify(std::shared_ptr<NotificationMessage> notify_msg) {
|
||||
STREAMING_LOG(INFO) << "OnNotify target seq_id: " << notify_msg->SeqId();
|
||||
min_consumed_id_ = notify_msg->SeqId();
|
||||
}
|
||||
|
||||
void ReaderQueue::OnConsumed(uint64_t seq_id) {
|
||||
STREAMING_LOG(INFO) << "OnConsumed: " << seq_id;
|
||||
QueueItem item = FrontProcessed();
|
||||
while (item.SeqId() <= seq_id) {
|
||||
PopProcessed();
|
||||
item = FrontProcessed();
|
||||
}
|
||||
Notify(seq_id);
|
||||
}
|
||||
|
||||
void ReaderQueue::Notify(uint64_t seq_id) {
|
||||
std::vector<TaskArg> task_args;
|
||||
CreateNotifyTask(seq_id, task_args);
|
||||
// SubmitActorTask
|
||||
|
||||
NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id);
|
||||
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
|
||||
|
||||
transport_->Send(std::move(buffer), UpstreamQueueMessageHandler::peer_async_function_);
|
||||
}
|
||||
|
||||
void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector<TaskArg> &task_args) {}
|
||||
|
||||
void ReaderQueue::OnData(QueueItem &item) {
|
||||
if (item.SeqId() != expect_seq_id_) {
|
||||
STREAMING_LOG(WARNING) << "OnData ignore seq_id: " << item.SeqId()
|
||||
<< " expect_seq_id_: " << expect_seq_id_;
|
||||
return;
|
||||
}
|
||||
|
||||
last_recv_seq_id_ = item.SeqId();
|
||||
STREAMING_LOG(DEBUG) << "ReaderQueue::OnData seq_id: " << last_recv_seq_id_;
|
||||
|
||||
Push(item);
|
||||
expect_seq_id_++;
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
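// Illustrative sketch, standard library only, of the watershed-list mechanism used by
// FrontProcessed/PopProcessed/PopPending above: one std::list holds processed and
// pending items, and a sentinel "watershed" iterator separates them; PopPending is a
// splice across the watershed. Names here are illustrative, not the streaming queue API.
#include <cassert>
#include <iterator>
#include <list>

int main() {
  std::list<int> items;
  items.push_back(-1);                 // sentinel, analogous to InvalidQueueItem()
  auto watershed = items.begin();

  items.push_back(1);                  // pending items live after the watershed
  items.push_back(2);

  // "PopPending": move the first pending item in front of the watershed,
  // turning it into a processed item without destroying it.
  auto first_pending = std::next(watershed);
  items.splice(watershed, items, first_pending, std::next(first_pending));

  assert(items.front() == 1);          // now processed
  assert(*std::next(watershed) == 2);  // still pending
  return 0;
}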
|
213
streaming/src/queue/queue.h
Normal file
213
streaming/src/queue/queue.h
Normal file
|
@ -0,0 +1,213 @@
|
|||
#ifndef _STREAMING_QUEUE_H_
|
||||
#define _STREAMING_QUEUE_H_
|
||||
#include <iterator>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/util/util.h"
|
||||
|
||||
#include "queue_item.h"
|
||||
#include "transport.h"
|
||||
#include "util/streaming_logging.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
using ray::ObjectID;
|
||||
|
||||
enum QueueType { UPSTREAM = 0, DOWNSTREAM };
|
||||
|
||||
/// A queue-like data structure, which does not delete its items after they are popped.
/// The lifecycle of each item is:
/// - Pending, an item has been pushed into the queue, but has not been processed
///   (sent out or consumed) yet.
/// - Processed, an item has been handled by the user, but should not be deleted yet.
/// - Evicted, an item is useless to the user and should be popped and destroyed.
/// At present, this data structure is implemented with a single std::list,
/// using a watershed iterator to divide the processed and pending parts.
|
||||
class Queue {
|
||||
public:
|
||||
/// \param[in] queue_id the unique identification of a pair of queues (upstream and
/// downstream).
/// \param[in] size max size of the queue in bytes.
/// \param[in] transport the transport used to send items to the peer.
|
||||
Queue(ObjectID queue_id, uint64_t size, std::shared_ptr<Transport> transport)
|
||||
: queue_id_(queue_id), max_data_size_(size), data_size_(0), data_size_sent_(0) {
|
||||
buffer_queue_.push_back(InvalidQueueItem());
|
||||
watershed_iter_ = buffer_queue_.begin();
|
||||
}
|
||||
|
||||
virtual ~Queue() {}
|
||||
|
||||
/// Push an item into the queue.
/// \param[in] item the QueueItem object to be sent to the peer.
/// \return false if the queue is full.
bool Push(QueueItem item);

/// Get the front item in processed state.
QueueItem FrontProcessed();

/// Pop the front item in processed state.
QueueItem PopProcessed();

/// Pop the front item in pending state; the item will not be
/// evicted at this moment, its state turns to processed.
QueueItem PopPending();
|
||||
|
||||
/// PopPending with timeout in microseconds.
|
||||
QueueItem PopPendingBlockTimeout(uint64_t timeout_us);
|
||||
|
||||
/// Return the last item in pending state.
|
||||
QueueItem BackPending();
|
||||
|
||||
bool IsPendingEmpty();
|
||||
bool IsPendingFull(uint64_t data_size = 0);
|
||||
|
||||
/// Return the size in bytes of all items in queue.
|
||||
uint64_t QueueSize() { return data_size_; }
|
||||
|
||||
/// Return the size in bytes of all items in pending state.
|
||||
uint64_t PendingDataSize() { return data_size_ - data_size_sent_; }
|
||||
|
||||
/// Return the size in bytes of all items in processed state.
|
||||
uint64_t ProcessedDataSize() { return data_size_sent_; }
|
||||
|
||||
/// Return item count of the queue.
|
||||
size_t Count() { return buffer_queue_.size(); }
|
||||
|
||||
/// Return item count in pending state.
|
||||
size_t PendingCount();
|
||||
|
||||
/// Return item count in processed state.
|
||||
size_t ProcessedCount();
|
||||
|
||||
protected:
|
||||
ObjectID queue_id_;
|
||||
std::list<QueueItem> buffer_queue_;
|
||||
std::list<QueueItem>::iterator watershed_iter_;
|
||||
|
||||
/// max data size in bytes
|
||||
uint64_t max_data_size_;
|
||||
uint64_t data_size_;
|
||||
uint64_t data_size_sent_;
|
||||
|
||||
std::mutex mutex_;
|
||||
std::condition_variable readable_cv_;
|
||||
};
|
||||
|
||||
/// Queue in upstream.
|
||||
class WriterQueue : public Queue {
|
||||
public:
|
||||
/// \param queue_id, the unique ObjectID to identify a queue
|
||||
/// \param actor_id, the actor id of upstream worker
|
||||
/// \param peer_actor_id, the actor id of downstream worker
|
||||
/// \param size, max data size in bytes
|
||||
/// \param transport, transport
|
||||
WriterQueue(const ObjectID &queue_id, const ActorID &actor_id,
|
||||
const ActorID &peer_actor_id, uint64_t size,
|
||||
std::shared_ptr<Transport> transport)
|
||||
: Queue(queue_id, size, transport),
|
||||
actor_id_(actor_id),
|
||||
peer_actor_id_(peer_actor_id),
|
||||
eviction_limit_(QUEUE_INVALID_SEQ_ID),
|
||||
min_consumed_id_(QUEUE_INVALID_SEQ_ID),
|
||||
peer_last_msg_id_(0),
|
||||
peer_last_seq_id_(QUEUE_INVALID_SEQ_ID),
|
||||
transport_(transport),
|
||||
is_pulling_(false) {}
|
||||
|
||||
/// Push a continuous buffer into queue.
|
||||
/// NOTE: the buffer should be copied.
|
||||
Status Push(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp,
|
||||
bool raw = false);
|
||||
|
||||
/// Callback function, will be called when downstream queue notifies
|
||||
/// it has consumed some items.
|
||||
/// NOTE: this callback function is called in queue thread.
|
||||
void OnNotify(std::shared_ptr<NotificationMessage> notify_msg);
|
||||
|
||||
/// Send items through direct call.
|
||||
void Send();
|
||||
|
||||
/// Called when the user pushes an item into the queue. The number of items
/// that can be evicted is determined by eviction_limit_ and min_consumed_id_.
|
||||
Status TryEvictItems();
|
||||
|
||||
void SetQueueEvictionLimit(uint64_t eviction_limit) {
|
||||
eviction_limit_ = eviction_limit;
|
||||
}
|
||||
|
||||
uint64_t EvictionLimit() { return eviction_limit_; }
|
||||
|
||||
uint64_t GetMinConsumedSeqID() { return min_consumed_id_; }
|
||||
|
||||
void SetPeerLastIds(uint64_t msg_id, uint64_t seq_id) {
|
||||
peer_last_msg_id_ = msg_id;
|
||||
peer_last_seq_id_ = seq_id;
|
||||
}
|
||||
|
||||
uint64_t GetPeerLastMsgId() { return peer_last_msg_id_; }
|
||||
|
||||
uint64_t GetPeerLastSeqId() { return peer_last_seq_id_; }
|
||||
|
||||
private:
|
||||
ActorID actor_id_;
|
||||
ActorID peer_actor_id_;
|
||||
uint64_t eviction_limit_;
|
||||
uint64_t min_consumed_id_;
|
||||
uint64_t peer_last_msg_id_;
|
||||
uint64_t peer_last_seq_id_;
|
||||
std::shared_ptr<Transport> transport_;
|
||||
|
||||
std::atomic<bool> is_pulling_;
|
||||
};
|
||||
|
||||
/// Queue in downstream.
|
||||
class ReaderQueue : public Queue {
|
||||
public:
|
||||
/// \param queue_id, the unique ObjectID to identify a queue
|
||||
/// \param actor_id, the actor id of upstream worker
|
||||
/// \param peer_actor_id, the actor id of downstream worker
|
||||
/// \param transport, transport
|
||||
/// NOTE: we do not restrict queue size of ReaderQueue
|
||||
ReaderQueue(const ObjectID &queue_id, const ActorID &actor_id,
|
||||
const ActorID &peer_actor_id, std::shared_ptr<Transport> transport)
|
||||
: Queue(queue_id, std::numeric_limits<uint64_t>::max(), transport),
|
||||
actor_id_(actor_id),
|
||||
peer_actor_id_(peer_actor_id),
|
||||
min_consumed_id_(QUEUE_INVALID_SEQ_ID),
|
||||
last_recv_seq_id_(QUEUE_INVALID_SEQ_ID),
|
||||
expect_seq_id_(1),
|
||||
transport_(transport) {}
|
||||
|
||||
/// Delete processed items whose seq id <= seq_id,
|
||||
/// then notify upstream queue.
|
||||
void OnConsumed(uint64_t seq_id);
|
||||
|
||||
void OnData(QueueItem &item);
|
||||
|
||||
uint64_t GetMinConsumedSeqID() { return min_consumed_id_; }
|
||||
|
||||
uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; }
|
||||
|
||||
void SetExpectSeqId(uint64_t expect) { expect_seq_id_ = expect; }
|
||||
|
||||
private:
|
||||
void Notify(uint64_t seq_id);
|
||||
void CreateNotifyTask(uint64_t seq_id, std::vector<TaskArg> &task_args);
|
||||
|
||||
private:
|
||||
ActorID actor_id_;
|
||||
ActorID peer_actor_id_;
|
||||
uint64_t min_consumed_id_;
|
||||
uint64_t last_recv_seq_id_;
|
||||
uint64_t expect_seq_id_;
|
||||
std::shared_ptr<PromiseWrapper> promise_for_pull_;
|
||||
std::shared_ptr<Transport> transport_;
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
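// A hedged usage sketch (not standalone-runnable) of the WriterQueue API declared above:
// push one serialized bundle, evict already-consumed items if the queue reports it is
// full, then push pending items to the peer through the transport. The helper name
// WriteBundle and the single-retry policy are illustrative assumptions only.
#include "queue.h"

ray::Status WriteBundle(ray::streaming::WriterQueue &queue, uint64_t seq_id,
                        uint8_t *data, uint32_t data_size, uint64_t timestamp) {
  ray::Status status = queue.Push(seq_id, data, data_size, timestamp);
  if (!status.ok()) {
    // Queue full: drop items the downstream reader has already consumed, then retry once.
    status = queue.TryEvictItems();
    if (!status.ok()) {
      return status;  // still full, the caller should back off and try again later
    }
    status = queue.Push(seq_id, data, data_size, timestamp);
  }
  queue.Send();  // drain pending items to the peer actor via direct actor calls
  return status;
}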
|
25
streaming/src/queue/queue_client.cc
Normal file
25
streaming/src/queue/queue_client.cc
Normal file
|
@ -0,0 +1,25 @@
|
|||
#include "queue_client.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
void WriterClient::OnWriterMessage(std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
upstream_handler_->DispatchMessageAsync(buffer);
|
||||
}
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> WriterClient::OnWriterMessageSync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
return upstream_handler_->DispatchMessageSync(buffer);
|
||||
}
|
||||
|
||||
void ReaderClient::OnReaderMessage(std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
downstream_handler_->DispatchMessageAsync(buffer);
|
||||
}
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> ReaderClient::OnReaderMessageSync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
return downstream_handler_->DispatchMessageSync(buffer);
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
62
streaming/src/queue/queue_client.h
Normal file
62
streaming/src/queue/queue_client.h
Normal file
|
@ -0,0 +1,62 @@
|
|||
#ifndef _STREAMING_QUEUE_CLIENT_H_
|
||||
#define _STREAMING_QUEUE_CLIENT_H_
|
||||
#include "queue_handler.h"
|
||||
#include "transport.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// The interface of the streaming queue for DataReader.
|
||||
/// A ReaderClient should be created before the DataReader is created in Cython/JNI, and
/// is held by the JobWorker. When the DataReader receives a buffer from an upstream
/// DataWriter (the DataReader's ray call function is invoked), it calls `OnReaderMessage`
/// to pass the buffer to its own downstream queue, or `OnReaderMessageSync` to wait for
/// the handle result.
|
||||
class ReaderClient {
|
||||
public:
|
||||
/// Construct a ReaderClient object.
|
||||
/// \param[in] core_worker CoreWorker C++ pointer of current actor
|
||||
/// \param[in] async_func DataReader's ray call function descriptor to be called by
/// DataWriter, with asynchronous semantics.
/// \param[in] sync_func DataReader's ray call function descriptor to be called by
/// DataWriter, with synchronous semantics.
|
||||
ReaderClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
|
||||
: core_worker_(core_worker) {
|
||||
DownstreamQueueMessageHandler::peer_async_function_ = async_func;
|
||||
DownstreamQueueMessageHandler::peer_sync_function_ = sync_func;
|
||||
downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService(
|
||||
core_worker_, core_worker_->GetWorkerContext().GetCurrentActorID());
|
||||
}
|
||||
|
||||
/// Post buffer to downstream queue service, asynchronously.
|
||||
void OnReaderMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
/// Post buffer to downstream queue service, synchronously.
|
||||
/// \return handle result.
|
||||
std::shared_ptr<LocalMemoryBuffer> OnReaderMessageSync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
private:
|
||||
CoreWorker *core_worker_;
|
||||
std::shared_ptr<DownstreamQueueMessageHandler> downstream_handler_;
|
||||
};
|
||||
|
||||
/// Interface of streaming queue for DataWriter. Similar to ReaderClient.
|
||||
class WriterClient {
|
||||
public:
|
||||
WriterClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func)
|
||||
: core_worker_(core_worker) {
|
||||
UpstreamQueueMessageHandler::peer_async_function_ = async_func;
|
||||
UpstreamQueueMessageHandler::peer_sync_function_ = sync_func;
|
||||
upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService(
|
||||
core_worker, core_worker_->GetWorkerContext().GetCurrentActorID());
|
||||
}
|
||||
|
||||
void OnWriterMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
std::shared_ptr<LocalMemoryBuffer> OnWriterMessageSync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
private:
|
||||
CoreWorker *core_worker_;
|
||||
std::shared_ptr<UpstreamQueueMessageHandler> upstream_handler_;
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
|
358
streaming/src/queue/queue_handler.cc
Normal file
358
streaming/src/queue/queue_handler.cc
Normal file
|
@ -0,0 +1,358 @@
|
|||
#include "queue_handler.h"
|
||||
#include "util/streaming_util.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
constexpr uint64_t COMMON_SYNC_CALL_TIMEOUT_MS = 5 * 1000;
|
||||
|
||||
std::shared_ptr<UpstreamQueueMessageHandler>
|
||||
UpstreamQueueMessageHandler::upstream_handler_ = nullptr;
|
||||
std::shared_ptr<DownstreamQueueMessageHandler>
|
||||
DownstreamQueueMessageHandler::downstream_handler_ = nullptr;
|
||||
|
||||
RayFunction UpstreamQueueMessageHandler::peer_sync_function_;
|
||||
RayFunction UpstreamQueueMessageHandler::peer_async_function_;
|
||||
RayFunction DownstreamQueueMessageHandler::peer_sync_function_;
|
||||
RayFunction DownstreamQueueMessageHandler::peer_async_function_;
|
||||
|
||||
std::shared_ptr<Message> QueueMessageHandler::ParseMessage(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
uint8_t *bytes = buffer->Data();
|
||||
uint8_t *p_cur = bytes;
|
||||
uint32_t *magic_num = (uint32_t *)p_cur;
|
||||
STREAMING_CHECK(*magic_num == Message::MagicNum)
|
||||
<< *magic_num << " " << Message::MagicNum;
|
||||
|
||||
p_cur += sizeof(Message::MagicNum);
|
||||
queue::protobuf::StreamingQueueMessageType *type =
|
||||
(queue::protobuf::StreamingQueueMessageType *)p_cur;
|
||||
|
||||
std::shared_ptr<Message> message = nullptr;
|
||||
switch (*type) {
|
||||
case queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType:
|
||||
message = NotificationMessage::FromBytes(bytes);
|
||||
break;
|
||||
case queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType:
|
||||
message = DataMessage::FromBytes(bytes);
|
||||
break;
|
||||
case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType:
|
||||
message = CheckMessage::FromBytes(bytes);
|
||||
break;
|
||||
case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType:
|
||||
message = CheckRspMessage::FromBytes(bytes);
|
||||
break;
|
||||
default:
|
||||
STREAMING_CHECK(false) << "nonsupport message type: "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(*type);
|
||||
break;
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
void QueueMessageHandler::DispatchMessageAsync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
queue_service_.post(
|
||||
boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer, nullptr));
|
||||
}
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> QueueMessageHandler::DispatchMessageSync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
std::shared_ptr<LocalMemoryBuffer> result = nullptr;
|
||||
std::shared_ptr<PromiseWrapper> promise = std::make_shared<PromiseWrapper>();
|
||||
queue_service_.post(
|
||||
boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer,
|
||||
[&promise, &result](std::shared_ptr<LocalMemoryBuffer> rst) {
|
||||
result = rst;
|
||||
promise->Notify(ray::Status::OK());
|
||||
}));
|
||||
Status st = promise->Wait();
|
||||
STREAMING_CHECK(st.ok());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::shared_ptr<Transport> QueueMessageHandler::GetOutTransport(
|
||||
const ObjectID &queue_id) {
|
||||
auto it = out_transports_.find(queue_id);
|
||||
if (it == out_transports_.end()) return nullptr;
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id,
|
||||
const ActorID &actor_id) {
|
||||
actors_.emplace(queue_id, actor_id);
|
||||
out_transports_.emplace(
|
||||
queue_id, std::make_shared<ray::streaming::Transport>(core_worker_, actor_id));
|
||||
}
|
||||
|
||||
ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) {
|
||||
auto it = actors_.find(queue_id);
|
||||
STREAMING_CHECK(it != actors_.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void QueueMessageHandler::Release() {
|
||||
actors_.clear();
|
||||
out_transports_.clear();
|
||||
}
|
||||
|
||||
void QueueMessageHandler::Start() {
|
||||
queue_thread_ = std::thread(&QueueMessageHandler::QueueThreadCallback, this);
|
||||
}
|
||||
|
||||
void QueueMessageHandler::Stop() {
|
||||
STREAMING_LOG(INFO) << "QueueMessageHandler Stop.";
|
||||
queue_service_.stop();
|
||||
if (queue_thread_.joinable()) {
|
||||
queue_thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<UpstreamQueueMessageHandler> UpstreamQueueMessageHandler::CreateService(
|
||||
CoreWorker *core_worker, const ActorID &actor_id) {
|
||||
if (nullptr == upstream_handler_) {
|
||||
upstream_handler_ =
|
||||
std::make_shared<UpstreamQueueMessageHandler>(core_worker, actor_id);
|
||||
}
|
||||
return upstream_handler_;
|
||||
}
|
||||
|
||||
std::shared_ptr<UpstreamQueueMessageHandler> UpstreamQueueMessageHandler::GetService() {
|
||||
return upstream_handler_;
|
||||
}
|
||||
|
||||
std::shared_ptr<WriterQueue> UpstreamQueueMessageHandler::CreateUpstreamQueue(
|
||||
const ObjectID &queue_id, const ActorID &peer_actor_id, uint64_t size) {
|
||||
STREAMING_LOG(INFO) << "CreateUpstreamQueue: " << queue_id << " " << actor_id_ << "->"
|
||||
<< peer_actor_id;
|
||||
std::shared_ptr<WriterQueue> queue = GetUpQueue(queue_id);
|
||||
if (queue != nullptr) {
|
||||
STREAMING_LOG(WARNING) << "Duplicate to create up queue." << queue_id;
|
||||
return queue;
|
||||
}
|
||||
|
||||
queue = std::make_shared<streaming::WriterQueue>(
queue_id, actor_id_, peer_actor_id, size, GetOutTransport(queue_id));
|
||||
upstream_queues_[queue_id] = queue;
|
||||
|
||||
return queue;
|
||||
}
|
||||
|
||||
bool UpstreamQueueMessageHandler::UpstreamQueueExists(const ObjectID &queue_id) {
|
||||
return nullptr != GetUpQueue(queue_id);
|
||||
}
|
||||
|
||||
std::shared_ptr<streaming::WriterQueue> UpstreamQueueMessageHandler::GetUpQueue(
|
||||
const ObjectID &queue_id) {
|
||||
auto it = upstream_queues_.find(queue_id);
|
||||
if (it == upstream_queues_.end()) return nullptr;
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool UpstreamQueueMessageHandler::CheckQueueSync(const ObjectID &queue_id) {
|
||||
ActorID peer_actor_id = GetPeerActorID(queue_id);
|
||||
STREAMING_LOG(INFO) << "CheckQueueSync queue_id: " << queue_id
|
||||
<< " peer_actor_id: " << peer_actor_id;
|
||||
|
||||
CheckMessage msg(actor_id_, peer_actor_id, queue_id);
|
||||
std::unique_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
|
||||
|
||||
auto transport_it = GetOutTransport(queue_id);
|
||||
STREAMING_CHECK(transport_it != nullptr);
|
||||
std::shared_ptr<LocalMemoryBuffer> result_buffer = transport_it->SendForResultWithRetry(
|
||||
std::move(buffer), DownstreamQueueMessageHandler::peer_sync_function_, 10,
|
||||
COMMON_SYNC_CALL_TIMEOUT_MS);
|
||||
if (result_buffer == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<Message> result_msg = ParseMessage(result_buffer);
|
||||
STREAMING_CHECK(
|
||||
result_msg->Type() ==
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType);
|
||||
std::shared_ptr<CheckRspMessage> check_rsp_msg =
|
||||
std::dynamic_pointer_cast<CheckRspMessage>(result_msg);
|
||||
STREAMING_LOG(INFO) << "CheckQueueSync return queue_id: " << check_rsp_msg->QueueId();
|
||||
STREAMING_CHECK(check_rsp_msg->PeerActorId() == actor_id_);
|
||||
|
||||
return queue::protobuf::StreamingQueueError::OK == check_rsp_msg->Error();
|
||||
}
|
||||
|
||||
void UpstreamQueueMessageHandler::WaitQueues(const std::vector<ObjectID> &queue_ids,
|
||||
int64_t timeout_ms,
|
||||
std::vector<ObjectID> &failed_queues) {
|
||||
failed_queues.insert(failed_queues.begin(), queue_ids.begin(), queue_ids.end());
uint64_t start_time_ms = current_time_ms();
uint64_t current_time = start_time_ms;
while (!failed_queues.empty() && current_time < start_time_ms + static_cast<uint64_t>(timeout_ms)) {
|
||||
for (auto it = failed_queues.begin(); it != failed_queues.end();) {
|
||||
if (CheckQueueSync(*it)) {
|
||||
STREAMING_LOG(INFO) << "Check queue: " << *it << " return, ready.";
|
||||
it = failed_queues.erase(it);
|
||||
} else {
|
||||
STREAMING_LOG(INFO) << "Check queue: " << *it << " return, not ready.";
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||
it++;
|
||||
}
|
||||
}
|
||||
current_time = current_time_ms();
|
||||
}
|
||||
}
|
||||
|
||||
void UpstreamQueueMessageHandler::DispatchMessageInternal(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) {
|
||||
std::shared_ptr<Message> msg = ParseMessage(buffer);
|
||||
STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: "
|
||||
<< " qid: " << msg->QueueId() << " actorid " << msg->ActorId()
|
||||
<< " peer actorid: " << msg->PeerActorId() << " type: "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(msg->Type());
|
||||
|
||||
if (msg->Type() ==
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType) {
|
||||
OnNotify(std::dynamic_pointer_cast<NotificationMessage>(msg));
|
||||
} else if (msg->Type() ==
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType) {
|
||||
STREAMING_CHECK(false) << "Should not receive StreamingQueueCheckRspMsg";
|
||||
} else {
|
||||
STREAMING_CHECK(false) << "message type should be added: "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(
|
||||
msg->Type());
|
||||
}
|
||||
}
|
||||
|
||||
void UpstreamQueueMessageHandler::OnNotify(
|
||||
std::shared_ptr<NotificationMessage> notify_msg) {
|
||||
auto queue = GetUpQueue(notify_msg->QueueId());
|
||||
if (queue == nullptr) {
|
||||
STREAMING_LOG(WARNING) << "Can not find queue for "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(
|
||||
notify_msg->Type())
|
||||
<< ", maybe queue has been destroyed, ignore it."
|
||||
<< " seq id: " << notify_msg->SeqId();
|
||||
return;
|
||||
}
|
||||
queue->OnNotify(notify_msg);
|
||||
}
|
||||
|
||||
void UpstreamQueueMessageHandler::ReleaseAllUpQueues() {
|
||||
STREAMING_LOG(INFO) << "ReleaseAllUpQueues";
|
||||
upstream_queues_.clear();
|
||||
Release();
|
||||
}
|
||||
|
||||
std::shared_ptr<DownstreamQueueMessageHandler>
|
||||
DownstreamQueueMessageHandler::CreateService(CoreWorker *core_worker,
|
||||
const ActorID &actor_id) {
|
||||
if (nullptr == downstream_handler_) {
|
||||
downstream_handler_ =
|
||||
std::make_shared<DownstreamQueueMessageHandler>(core_worker, actor_id);
|
||||
}
|
||||
return downstream_handler_;
|
||||
}
|
||||
|
||||
std::shared_ptr<DownstreamQueueMessageHandler>
|
||||
DownstreamQueueMessageHandler::GetService() {
|
||||
return downstream_handler_;
|
||||
}
|
||||
|
||||
bool DownstreamQueueMessageHandler::DownstreamQueueExists(const ObjectID &queue_id) {
|
||||
return nullptr != GetDownQueue(queue_id);
|
||||
}
|
||||
|
||||
std::shared_ptr<ReaderQueue> DownstreamQueueMessageHandler::CreateDownstreamQueue(
|
||||
const ObjectID &queue_id, const ActorID &peer_actor_id) {
|
||||
STREAMING_LOG(INFO) << "CreateDownstreamQueue: " << queue_id << " " << peer_actor_id
|
||||
<< "->" << actor_id_;
|
||||
auto it = downstream_queues_.find(queue_id);
|
||||
if (it != downstream_queues_.end()) {
|
||||
STREAMING_LOG(WARNING) << "Duplicate to create down queue!!!! " << queue_id;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::shared_ptr<streaming::ReaderQueue> queue = std::make_shared<streaming::ReaderQueue>(
queue_id, actor_id_, peer_actor_id, GetOutTransport(queue_id));
|
||||
downstream_queues_[queue_id] = queue;
|
||||
return queue;
|
||||
}
|
||||
|
||||
std::shared_ptr<streaming::ReaderQueue> DownstreamQueueMessageHandler::GetDownQueue(
|
||||
const ObjectID &queue_id) {
|
||||
auto it = downstream_queues_.find(queue_id);
|
||||
if (it == downstream_queues_.end()) return nullptr;
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> DownstreamQueueMessageHandler::OnCheckQueue(
|
||||
std::shared_ptr<CheckMessage> check_msg) {
|
||||
queue::protobuf::StreamingQueueError err_code =
|
||||
queue::protobuf::StreamingQueueError::OK;
|
||||
|
||||
auto down_queue = downstream_queues_.find(check_msg->QueueId());
|
||||
if (down_queue == downstream_queues_.end()) {
|
||||
STREAMING_LOG(WARNING) << "OnCheckQueue " << check_msg->QueueId() << " not found.";
|
||||
err_code = queue::protobuf::StreamingQueueError::QUEUE_NOT_EXIST;
|
||||
}
|
||||
|
||||
CheckRspMessage msg(check_msg->PeerActorId(), check_msg->ActorId(),
|
||||
check_msg->QueueId(), err_code);
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer = msg.ToBytes();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
void DownstreamQueueMessageHandler::ReleaseAllDownQueues() {
|
||||
STREAMING_LOG(INFO) << "ReleaseAllDownQueues size: " << downstream_queues_.size();
|
||||
downstream_queues_.clear();
|
||||
Release();
|
||||
}
|
||||
|
||||
void DownstreamQueueMessageHandler::DispatchMessageInternal(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) {
|
||||
std::shared_ptr<Message> msg = ParseMessage(buffer);
|
||||
STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: "
|
||||
<< " qid: " << msg->QueueId() << " actorid " << msg->ActorId()
|
||||
<< " peer actorid: " << msg->PeerActorId() << " type: "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(msg->Type());
|
||||
|
||||
if (msg->Type() ==
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType) {
|
||||
OnData(std::dynamic_pointer_cast<DataMessage>(msg));
|
||||
} else if (msg->Type() ==
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType) {
|
||||
std::shared_ptr<LocalMemoryBuffer> check_result =
|
||||
this->OnCheckQueue(std::dynamic_pointer_cast<CheckMessage>(msg));
|
||||
if (callback != nullptr) {
|
||||
callback(check_result);
|
||||
}
|
||||
} else {
|
||||
STREAMING_CHECK(false) << "message type should be added: "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(
|
||||
msg->Type());
|
||||
}
|
||||
}
|
||||
|
||||
void DownstreamQueueMessageHandler::OnData(std::shared_ptr<DataMessage> msg) {
|
||||
auto queue = GetDownQueue(msg->QueueId());
|
||||
if (queue == nullptr) {
|
||||
STREAMING_LOG(WARNING) << "Can not find queue for "
|
||||
<< queue::protobuf::StreamingQueueMessageType_Name(msg->Type())
|
||||
<< ", maybe queue has been destroyed, ignore it."
|
||||
<< " seq id: " << msg->SeqId();
|
||||
return;
|
||||
}
|
||||
|
||||
QueueItem item(msg);
|
||||
queue->OnData(item);
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
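// A minimal sketch of the dispatch pattern implemented by DispatchMessageAsync and
// DispatchMessageSync above, assuming only Boost.Asio (already a dependency of this
// module): an io_service runs on a dedicated thread, kept alive by a work guard;
// asynchronous dispatch just posts a handler, while synchronous dispatch posts a handler
// and blocks on a promise until it completes. Illustrative only.
#include <boost/asio.hpp>
#include <future>
#include <iostream>
#include <thread>

int main() {
  boost::asio::io_service service;
  boost::asio::io_service::work dummy_work(service);  // keeps run() from returning
  std::thread worker([&service] { service.run(); });

  // Asynchronous dispatch: fire and forget.
  service.post([] { std::cout << "async message handled\n"; });

  // Synchronous dispatch: wait for the handler's result, as DispatchMessageSync does.
  std::promise<int> promise;
  service.post([&promise] { promise.set_value(42); });
  std::cout << "sync result: " << promise.get_future().get() << "\n";

  service.stop();
  worker.join();
  return 0;
}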
|
194
streaming/src/queue/queue_handler.h
Normal file
194
streaming/src/queue/queue_handler.h
Normal file
|
@ -0,0 +1,194 @@
|
|||
#ifndef _QUEUE_SERVICE_H_
|
||||
#define _QUEUE_SERVICE_H_
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
#include <boost/thread.hpp>
|
||||
#include <thread>
|
||||
|
||||
#include "queue.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// Base class of UpstreamQueueMessageHandler and DownstreamQueueMessageHandler.
|
||||
/// A queue service manages a group of queues, the upstream or downstream queues of
/// the current actor. Each queue service holds a boost.asio io_service to handle
/// messages asynchronously. When a message is received by the Writer/Reader in the ray
/// call thread, it is delivered to the
/// UpstreamQueueMessageHandler/DownstreamQueueMessageHandler, and the ray call thread
/// returns immediately. The queue service parses meta information from the message,
/// including queue_id, actor_id, etc., and dispatches the message to the right queue
/// according to its queue_id.
|
||||
class QueueMessageHandler {
|
||||
public:
|
||||
/// Construct a QueueMessageHandler instance.
|
||||
/// \param[in] core_worker CoreWorker C++ pointer of current actor, used to call Core
|
||||
/// Worker's api.
|
||||
/// For Python worker, the pointer can be obtained from
|
||||
/// ray.worker.global_worker.core_worker; For Java worker, obtained from
|
||||
/// RayNativeRuntime object through java reflection.
|
||||
/// \param[in] actor_id actor id of current actor.
|
||||
QueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
|
||||
: core_worker_(core_worker),
|
||||
actor_id_(actor_id),
|
||||
queue_dummy_work_(queue_service_) {
|
||||
Start();
|
||||
}
|
||||
|
||||
virtual ~QueueMessageHandler() { Stop(); }
|
||||
|
||||
/// Dispatch message buffer to asio service.
|
||||
/// \param[in] buffer serialized message received from peer actor.
|
||||
void DispatchMessageAsync(std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
/// Dispatch message buffer to asio service synchronously, and wait for handle result.
|
||||
/// \param[in] buffer serialized message received from peer actor.
|
||||
/// \return handle result.
|
||||
std::shared_ptr<LocalMemoryBuffer> DispatchMessageSync(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
/// Get the transport to the peer actor of the queue specified by queue_id.
/// \param[in] queue_id queue id of the queue.
/// \return transport
std::shared_ptr<Transport> GetOutTransport(const ObjectID &queue_id);
|
||||
|
||||
/// The actual function where message being dispatched, called by DispatchMessageAsync
|
||||
/// and DispatchMessageSync.
|
||||
/// \param[in] buffer serialized message received from peer actor.
|
||||
/// \param[in] callback the callback function used by DispatchMessageSync, called
|
||||
/// after message processed complete. The std::shared_ptr<LocalMemoryBuffer>
|
||||
/// parameter is the return value.
|
||||
virtual void DispatchMessageInternal(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) = 0;
|
||||
|
||||
/// Save the actor_id of the peer actor specified by queue_id. For an upstream queue, the
/// peer actor refers to the actor in the current ray cluster that has a
/// downstream queue with the same queue_id, and vice versa.
/// \param[in] queue_id queue id of the current queue.
/// \param[in] actor_id actor id of the corresponding peer actor.
|
||||
void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id);
|
||||
|
||||
/// Obtain the actor id of the peer actor specified by queue_id.
|
||||
/// \return actor id
|
||||
ActorID GetPeerActorID(const ObjectID &queue_id);
|
||||
|
||||
/// Release all queues in current queue service.
|
||||
void Release();
|
||||
|
||||
private:
|
||||
/// Start asio service
|
||||
void Start();
|
||||
/// Stop asio service
|
||||
void Stop();
|
||||
/// The callback function of internal thread.
|
||||
void QueueThreadCallback() { queue_service_.run(); }
|
||||
|
||||
protected:
|
||||
/// CoreWorker C++ pointer of current actor
|
||||
CoreWorker *core_worker_;
|
||||
/// actor_id actor id of current actor
|
||||
ActorID actor_id_;
|
||||
/// Helper function, parse message buffer to Message object.
|
||||
std::shared_ptr<Message> ParseMessage(std::shared_ptr<LocalMemoryBuffer> buffer);
|
||||
|
||||
private:
|
||||
/// Map from queue id to a actor id of the queue's peer actor.
|
||||
std::unordered_map<ObjectID, ActorID> actors_;
|
||||
/// Map from queue id to a transport of the queue's peer actor.
|
||||
std::unordered_map<ObjectID, std::shared_ptr<Transport>> out_transports_;
|
||||
/// The internal thread which asio service run with.
|
||||
std::thread queue_thread_;
|
||||
/// The internal asio service.
|
||||
boost::asio::io_service queue_service_;
|
||||
/// The asio work which keeps queue_service_ alive.
|
||||
boost::asio::io_service::work queue_dummy_work_;
|
||||
};
|
||||
|
||||
/// UpstreamQueueMessageHandler holds and manages all upstream queues of current actor.
|
||||
class UpstreamQueueMessageHandler : public QueueMessageHandler {
|
||||
public:
|
||||
/// Construct an UpstreamQueueMessageHandler instance.
|
||||
UpstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
|
||||
: QueueMessageHandler(core_worker, actor_id) {}
|
||||
/// Create an upstream queue.
|
||||
/// \param[in] queue_id queue id of the queue to be created.
|
||||
/// \param[in] peer_actor_id actor id of peer actor.
|
||||
/// \param[in] size the max memory size of the queue.
|
||||
std::shared_ptr<WriterQueue> CreateUpstreamQueue(const ObjectID &queue_id,
|
||||
const ActorID &peer_actor_id,
|
||||
uint64_t size);
|
||||
/// Check whether the upstream queue specified by queue_id exists or not.
|
||||
bool UpstreamQueueExists(const ObjectID &queue_id);
|
||||
/// Wait for all queues in the queue_ids vector to become ready, or until timeout.
|
||||
/// \param[in] queue_ids a group of queues.
|
||||
/// \param[in] timeout_ms max timeout time interval for wait all queues.
|
||||
/// \param[out] failed_queues a group of queues which are not ready when timeout.
|
||||
void WaitQueues(const std::vector<ObjectID> &queue_ids, int64_t timeout_ms,
|
||||
std::vector<ObjectID> &failed_queues);
|
||||
/// Handle notify message from corresponded downstream queue.
|
||||
void OnNotify(std::shared_ptr<NotificationMessage> notify_msg);
|
||||
/// Obtain upstream queue specified by queue_id.
|
||||
std::shared_ptr<streaming::WriterQueue> GetUpQueue(const ObjectID &queue_id);
|
||||
/// Release all upstream queues
|
||||
void ReleaseAllUpQueues();
|
||||
|
||||
virtual void DispatchMessageInternal(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback) override;
|
||||
|
||||
static std::shared_ptr<UpstreamQueueMessageHandler> CreateService(
|
||||
CoreWorker *core_worker, const ActorID &actor_id);
|
||||
static std::shared_ptr<UpstreamQueueMessageHandler> GetService();
|
||||
|
||||
static RayFunction peer_sync_function_;
|
||||
static RayFunction peer_async_function_;
|
||||
|
||||
private:
|
||||
bool CheckQueueSync(const ObjectID &queue_id);
|
||||
|
||||
private:
|
||||
std::unordered_map<ObjectID, std::shared_ptr<streaming::WriterQueue>> upstream_queues_;
|
||||
static std::shared_ptr<UpstreamQueueMessageHandler> upstream_handler_;
|
||||
};
|
||||
|
||||
/// DownstreamQueueMessageHandler holds and manages all downstream queues of the current actor.
|
||||
class DownstreamQueueMessageHandler : public QueueMessageHandler {
|
||||
public:
|
||||
DownstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id)
|
||||
: QueueMessageHandler(core_worker, actor_id) {}
|
||||
std::shared_ptr<ReaderQueue> CreateDownstreamQueue(const ObjectID &queue_id,
|
||||
const ActorID &peer_actor_id);
|
||||
bool DownstreamQueueExists(const ObjectID &queue_id);
|
||||
|
||||
void UpdateDownActor(const ObjectID &queue_id, const ActorID &actor_id);
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> OnCheckQueue(
|
||||
std::shared_ptr<CheckMessage> check_msg);
|
||||
|
||||
std::shared_ptr<streaming::ReaderQueue> GetDownQueue(const ObjectID &queue_id);
|
||||
|
||||
void ReleaseAllDownQueues();
|
||||
|
||||
void OnData(std::shared_ptr<DataMessage> msg);
|
||||
virtual void DispatchMessageInternal(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
std::function<void(std::shared_ptr<LocalMemoryBuffer>)> callback);
|
||||
|
||||
static std::shared_ptr<DownstreamQueueMessageHandler> CreateService(
|
||||
CoreWorker *core_worker, const ActorID &actor_id);
|
||||
static std::shared_ptr<DownstreamQueueMessageHandler> GetService();
|
||||
static RayFunction peer_sync_function_;
|
||||
static RayFunction peer_async_function_;
|
||||
|
||||
private:
|
||||
std::unordered_map<ObjectID, std::shared_ptr<streaming::ReaderQueue>>
|
||||
downstream_queues_;
|
||||
static std::shared_ptr<DownstreamQueueMessageHandler> downstream_handler_;
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
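// A hedged wiring sketch (not standalone-runnable) for the upstream side declared above:
// create the singleton handler, register the peer actor for a queue id, create the
// upstream queue, then block until the peer's downstream queue answers the check
// message. `core_worker`, `actor_id`, `peer_actor_id` and `queue_id` are assumed to be
// supplied by the embedding worker; the queue size and timeout are arbitrary.
#include "queue_handler.h"

std::shared_ptr<ray::streaming::WriterQueue> SetUpUpstreamQueue(
    ray::CoreWorker *core_worker, const ray::ActorID &actor_id,
    const ray::ActorID &peer_actor_id, const ray::ObjectID &queue_id) {
  auto handler =
      ray::streaming::UpstreamQueueMessageHandler::CreateService(core_worker, actor_id);
  handler->SetPeerActorID(queue_id, peer_actor_id);
  auto queue = handler->CreateUpstreamQueue(queue_id, peer_actor_id, 1024 * 1024);

  std::vector<ray::ObjectID> failed;
  handler->WaitQueues({queue_id}, /*timeout_ms=*/10 * 1000, failed);
  // Any queue id left in `failed` did not become ready before the timeout.
  return queue;
}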
|
109
streaming/src/queue/queue_item.h
Normal file
109
streaming/src/queue/queue_item.h
Normal file
|
@ -0,0 +1,109 @@
|
|||
#ifndef _STREAMING_QUEUE_ITEM_H_
|
||||
#define _STREAMING_QUEUE_ITEM_H_
|
||||
#include <iterator>
|
||||
#include <list>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "ray/common/id.h"
|
||||
|
||||
#include "message.h"
|
||||
#include "message/message_bundle.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
using ray::ObjectID;
|
||||
const uint64_t QUEUE_INVALID_SEQ_ID = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
/// QueueItem is the element stored in `Queue`. Actually, when DataWriter pushes a message
|
||||
/// bundle into a queue, the bundle is packed into one QueueItem, so a one-to-one
|
||||
/// relationship exists between message bundle and QueueItem. Meanwhile, the QueueItem is
|
||||
/// also the minimum unit to send through direct actor call. Each QueueItem holds a
|
||||
/// LocalMemoryBuffer shared_ptr, which will be sent out by Transport.
|
||||
class QueueItem {
|
||||
public:
|
||||
/// Construct a QueueItem object.
|
||||
/// \param[in] seq_id the sequential id assigned by DataWriter for a message bundle and
|
||||
/// QueueItem.
|
||||
/// \param[in] data the data buffer to be stored in this QueueItem.
|
||||
/// \param[in] data_size the data size in bytes.
|
||||
/// \param[in] timestamp the time when this QueueItem created.
|
||||
/// \param[in] raw whether the data content is raw bytes, only used in some tests.
|
||||
QueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp,
|
||||
bool raw = false)
|
||||
: seq_id_(seq_id),
|
||||
timestamp_(timestamp),
|
||||
raw_(raw),
|
||||
/*COPY*/ buffer_(std::make_shared<LocalMemoryBuffer>(data, data_size, true)) {}
|
||||
|
||||
QueueItem(uint64_t seq_id, std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
uint64_t timestamp, bool raw = false)
|
||||
: seq_id_(seq_id), timestamp_(timestamp), raw_(raw), buffer_(buffer) {}
|
||||
|
||||
QueueItem(std::shared_ptr<DataMessage> data_msg)
|
||||
: seq_id_(data_msg->SeqId()),
|
||||
raw_(data_msg->IsRaw()),
|
||||
buffer_(data_msg->Buffer()) {}
|
||||
|
||||
QueueItem(const QueueItem &&item) {
|
||||
buffer_ = item.buffer_;
|
||||
seq_id_ = item.seq_id_;
|
||||
timestamp_ = item.timestamp_;
|
||||
raw_ = item.raw_;
|
||||
}
|
||||
|
||||
QueueItem(const QueueItem &item) {
|
||||
buffer_ = item.buffer_;
|
||||
seq_id_ = item.seq_id_;
|
||||
timestamp_ = item.timestamp_;
|
||||
raw_ = item.raw_;
|
||||
}
|
||||
|
||||
QueueItem &operator=(const QueueItem &item) {
|
||||
buffer_ = item.buffer_;
|
||||
seq_id_ = item.seq_id_;
|
||||
timestamp_ = item.timestamp_;
|
||||
raw_ = item.raw_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
virtual ~QueueItem() = default;
|
||||
|
||||
uint64_t SeqId() { return seq_id_; }
|
||||
bool IsRaw() { return raw_; }
|
||||
uint64_t TimeStamp() { return timestamp_; }
|
||||
size_t DataSize() { return buffer_->Size(); }
|
||||
std::shared_ptr<LocalMemoryBuffer> Buffer() { return buffer_; }
|
||||
|
||||
/// Get max message id in this item.
|
||||
/// \return max message id.
|
||||
uint64_t MaxMsgId() {
|
||||
if (raw_) {
|
||||
return 0;
|
||||
}
|
||||
auto message_bundle = StreamingMessageBundleMeta::FromBytes(buffer_->Data());
|
||||
return message_bundle->GetLastMessageId();
|
||||
}
|
||||
|
||||
protected:
|
||||
uint64_t seq_id_;
|
||||
uint64_t timestamp_;
|
||||
bool raw_;
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer_;
|
||||
};
|
||||
|
||||
class InvalidQueueItem : public QueueItem {
|
||||
public:
|
||||
InvalidQueueItem() : QueueItem(QUEUE_INVALID_SEQ_ID, data_, 1, 0) {}
|
||||
|
||||
private:
|
||||
uint8_t data_[1];
|
||||
};
|
||||
typedef std::shared_ptr<QueueItem> QueueItemPtr;
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
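// Minimal sketch, standard library only, of the ownership model QueueItem relies on
// above: copying an item copies a shared_ptr to the underlying buffer, so copies are
// cheap and all of them reference the same bytes. Names are illustrative, not the
// streaming API.
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct Item {
  uint64_t seq_id;
  std::shared_ptr<std::vector<uint8_t>> buffer;  // analogous to LocalMemoryBuffer
};

int main() {
  Item a{1, std::make_shared<std::vector<uint8_t>>(128, 0)};
  Item b = a;  // shallow copy: same buffer, use count goes to 2
  assert(a.buffer.get() == b.buffer.get());
  assert(a.buffer.use_count() == 2);
  return 0;
}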
|
94
streaming/src/queue/transport.cc
Normal file
94
streaming/src/queue/transport.cc
Normal file
|
@ -0,0 +1,94 @@
|
|||
#include "transport.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
static constexpr int TASK_OPTION_RETURN_NUM_0 = 0;
|
||||
static constexpr int TASK_OPTION_RETURN_NUM_1 = 1;
|
||||
|
||||
void Transport::SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
RayFunction &function, int return_num,
|
||||
std::vector<ObjectID> &return_ids) {
|
||||
std::unordered_map<std::string, double> resources;
|
||||
TaskOptions options{return_num, true, resources};
|
||||
|
||||
char meta_data[3] = {'R', 'A', 'W'};
|
||||
std::shared_ptr<LocalMemoryBuffer> meta =
|
||||
std::make_shared<LocalMemoryBuffer>((uint8_t *)meta_data, 3, true);
|
||||
|
||||
std::vector<TaskArg> args;
|
||||
if (function.GetLanguage() == Language::PYTHON) {
|
||||
auto dummy = "__RAY_DUMMY__";
|
||||
std::shared_ptr<LocalMemoryBuffer> dummyBuffer =
|
||||
std::make_shared<LocalMemoryBuffer>((uint8_t *)dummy, 13, true);
|
||||
args.emplace_back(TaskArg::PassByValue(
|
||||
std::make_shared<RayObject>(std::move(dummyBuffer), meta, true)));
|
||||
}
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(std::move(buffer), meta, true)));
|
||||
|
||||
STREAMING_CHECK(core_worker_ != nullptr);
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
ray::Status st =
|
||||
core_worker_->SubmitActorTask(peer_actor_id_, function, args, options, &return_ids);
|
||||
if (!st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "SubmitActorTask failed. " << st;
|
||||
}
|
||||
}
|
||||
|
||||
void Transport::Send(std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function) {
|
||||
STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size();
|
||||
std::vector<ObjectID> return_ids;
|
||||
SendInternal(std::move(buffer), function, TASK_OPTION_RETURN_NUM_0, return_ids);
|
||||
}
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> Transport::SendForResult(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function,
|
||||
int64_t timeout_ms) {
|
||||
std::vector<ObjectID> return_ids;
|
||||
SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids);
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
Status get_st = core_worker_->Get(return_ids, timeout_ms, &results);
|
||||
if (!get_st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "Get fail.";
|
||||
return nullptr;
|
||||
}
|
||||
STREAMING_CHECK(results.size() >= 1);
|
||||
if (results[0]->IsException()) {
|
||||
STREAMING_LOG(ERROR) << "peer actor may has exceptions, should retry.";
|
||||
return nullptr;
|
||||
}
|
||||
STREAMING_CHECK(results[0]->HasData());
|
||||
if (results[0]->GetData()->Size() == 4) {
|
||||
STREAMING_LOG(WARNING) << "peer actor may not ready yet, should retry.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<Buffer> result_buffer = results[0]->GetData();
|
||||
std::shared_ptr<LocalMemoryBuffer> return_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
result_buffer->Data(), result_buffer->Size(), true);
|
||||
return return_buffer;
|
||||
}
|
||||
|
||||
std::shared_ptr<LocalMemoryBuffer> Transport::SendForResultWithRetry(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function, int retry_cnt,
|
||||
int64_t timeout_ms) {
|
||||
STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt
|
||||
<< " timeout_ms: " << timeout_ms
|
||||
<< " function: " << function.GetFunctionDescriptor()[0];
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer_shared = std::move(buffer);
|
||||
for (int cnt = 0; cnt < retry_cnt; cnt++) {
|
||||
auto result = SendForResult(buffer_shared, function, timeout_ms);
|
||||
if (result != nullptr) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
STREAMING_LOG(WARNING) << "SendForResultWithRetry fail after retry.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
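// Self-contained sketch, standard library only, of the retry pattern that
// Transport::SendForResultWithRetry implements above: try a call up to retry_cnt times
// and return the first non-null result. The callable stands in for SendForResult and is
// purely illustrative.
#include <functional>
#include <iostream>
#include <memory>

std::shared_ptr<int> SendWithRetry(const std::function<std::shared_ptr<int>()> &send,
                                   int retry_cnt) {
  for (int cnt = 0; cnt < retry_cnt; cnt++) {
    if (auto result = send()) {
      return result;  // success on this attempt
    }
  }
  return nullptr;  // every attempt failed; the caller decides how to surface the error
}

int main() {
  int calls = 0;
  auto flaky = [&calls]() -> std::shared_ptr<int> {
    if (++calls < 3) return nullptr;   // fail twice,
    return std::make_shared<int>(7);   // then succeed on the third try
  };
  auto result = SendWithRetry(flaky, 5);
  std::cout << (result ? *result : -1) << " after " << calls << " calls\n";
  return 0;
}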
|
63
streaming/src/queue/transport.h
Normal file
63
streaming/src/queue/transport.h
Normal file
|
@ -0,0 +1,63 @@
|
|||
#ifndef _STREAMING_QUEUE_TRANSPORT_H_
|
||||
#define _STREAMING_QUEUE_TRANSPORT_H_
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// Transport is the transfer endpoint to a specific actor; buffers can be sent to the
/// peer through direct actor calls.
|
||||
class Transport {
|
||||
public:
|
||||
/// Construct a Transport object.
|
||||
/// \param[in] core_worker CoreWorker C++ pointer of current actor, which we call direct
|
||||
/// actor call interface with.
|
||||
/// \param[in] peer_actor_id actor id of peer actor.
|
||||
Transport(CoreWorker *core_worker, const ActorID &peer_actor_id)
|
||||
: core_worker_(core_worker), peer_actor_id_(peer_actor_id) {}
|
||||
virtual ~Transport() = default;
|
||||
|
||||
/// Send buffer asynchronously, peer's `function` will be called.
|
||||
/// \param[in] buffer buffer to be sent.
|
||||
/// \param[in] function the function descriptor of peer's function.
|
||||
virtual void Send(std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function);
|
||||
/// Send buffer synchronously, peer's `function` will be called, and return the peer
|
||||
/// function's return value.
|
||||
/// \param[in] buffer buffer to be sent.
|
||||
/// \param[in] function the function descriptor of peer's function.
|
||||
/// \param[in] timeout_ms max time to wait for result.
|
||||
/// \return peer function's result.
|
||||
virtual std::shared_ptr<LocalMemoryBuffer> SendForResult(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function,
|
||||
int64_t timeout_ms);
|
||||
/// Send buffer synchronously with retry, and return the peer function's
/// return value.
/// \param[in] buffer buffer to be sent.
/// \param[in] function the function descriptor of peer's function.
/// \param[in] retry_cnt max retry count.
|
||||
/// \param[in] timeout_ms max time to wait for result.
|
||||
/// \return peer function's result.
|
||||
std::shared_ptr<LocalMemoryBuffer> SendForResultWithRetry(
|
||||
std::shared_ptr<LocalMemoryBuffer> buffer, RayFunction &function, int retry_cnt,
|
||||
int64_t timeout_ms);
|
||||
|
||||
private:
|
||||
/// Send buffer internal
|
||||
/// \param[in] buffer buffer to be sent.
|
||||
/// \param[in] function the function descriptor of peer's function.
|
||||
/// \param[in] return_num return value number of the call.
|
||||
/// \param[out] return_ids return ids from SubmitActorTask.
|
||||
virtual void SendInternal(std::shared_ptr<LocalMemoryBuffer> buffer,
|
||||
RayFunction &function, int return_num,
|
||||
std::vector<ObjectID> &return_ids);
|
||||
|
||||
private:
|
||||
CoreWorker *core_worker_;
|
||||
ActorID peer_actor_id_;
|
||||
};
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
|
50
streaming/src/queue/utils.h
Normal file
50
streaming/src/queue/utils.h
Normal file
|
@ -0,0 +1,50 @@
|
|||
#ifndef _STREAMING_QUEUE_UTILS_H_
|
||||
#define _STREAMING_QUEUE_UTILS_H_
|
||||
#include <chrono>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
#include "ray/util/util.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// Helper class that encapsulates std::future to help with multithreaded async waiting.
|
||||
class PromiseWrapper {
|
||||
public:
|
||||
Status Wait() {
|
||||
std::future<bool> fut = promise_.get_future();
|
||||
fut.get();
|
||||
return status_;
|
||||
}
|
||||
|
||||
Status WaitFor(uint64_t timeout_ms) {
|
||||
std::future<bool> fut = promise_.get_future();
|
||||
std::future_status status;
|
||||
do {
|
||||
status = fut.wait_for(std::chrono::milliseconds(timeout_ms));
|
||||
if (status == std::future_status::deferred) {
|
||||
} else if (status == std::future_status::timeout) {
|
||||
return Status::Invalid("timeout");
|
||||
} else if (status == std::future_status::ready) {
|
||||
return status_;
|
||||
}
|
||||
} while (status == std::future_status::deferred);
|
||||
|
||||
return status_;
|
||||
}
|
||||
|
||||
void Notify(Status status) {
|
||||
status_ = status;
|
||||
promise_.set_value(true);
|
||||
}
|
||||
|
||||
Status GetResultStatus() { return status_; }
|
||||
|
||||
private:
|
||||
std::promise<bool> promise_;
|
||||
Status status_;
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
#endif
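// Standard-library-only sketch of the wait/notify idiom PromiseWrapper wraps above: one
// thread blocks with a timeout on a future while another thread fulfils the promise.
// Purely illustrative; it does not use the ray::Status type.
#include <chrono>
#include <future>
#include <iostream>
#include <thread>

int main() {
  std::promise<bool> promise;
  std::future<bool> fut = promise.get_future();

  std::thread producer([&promise] {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    promise.set_value(true);  // analogous to PromiseWrapper::Notify(Status::OK())
  });

  // Analogous to PromiseWrapper::WaitFor(timeout_ms): a bounded wait on the future.
  if (fut.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
    std::cout << "notified: " << fut.get() << "\n";
  } else {
    std::cout << "timeout\n";
  }
  producer.join();
  return 0;
}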
|
82
streaming/src/ring_buffer.cc
Normal file
82
streaming/src/ring_buffer.cc
Normal file
|
@ -0,0 +1,82 @@
|
|||
#include "ring_buffer.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
StreamingRingBuffer::StreamingRingBuffer(size_t buf_size,
|
||||
StreamingRingBufferType buffer_type) {
|
||||
switch (buffer_type) {
|
||||
case StreamingRingBufferType::SPSC:
|
||||
message_buffer_ =
|
||||
std::make_shared<RingBufferImplLockFree<StreamingMessagePtr>>(buf_size);
|
||||
break;
|
||||
case StreamingRingBufferType::SPSC_LOCK:
|
||||
default:
|
||||
message_buffer_ =
|
||||
std::make_shared<RingBufferImplThreadSafe<StreamingMessagePtr>>(buf_size);
|
||||
}
|
||||
}
|
||||
|
||||
bool StreamingRingBuffer::Push(const StreamingMessagePtr &msg) {
|
||||
message_buffer_->Push(msg);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StreamingRingBuffer::Push(StreamingMessagePtr &&msg) {
|
||||
message_buffer_->Push(std::forward<StreamingMessagePtr>(msg));
|
||||
return true;
|
||||
}
|
||||
|
||||
StreamingMessagePtr &StreamingRingBuffer::Front() {
|
||||
STREAMING_CHECK(!message_buffer_->Empty());
|
||||
return message_buffer_->Front();
|
||||
}
|
||||
|
||||
void StreamingRingBuffer::Pop() {
|
||||
STREAMING_CHECK(!message_buffer_->Empty());
|
||||
message_buffer_->Pop();
|
||||
}
|
||||
|
||||
bool StreamingRingBuffer::IsFull() { return message_buffer_->Full(); }
|
||||
|
||||
bool StreamingRingBuffer::IsEmpty() { return message_buffer_->Empty(); }
|
||||
|
||||
size_t StreamingRingBuffer::Size() { return message_buffer_->Size(); };
|
||||
|
||||
size_t StreamingRingBuffer::Capacity() const { return message_buffer_->Capacity(); }
|
||||
|
||||
size_t StreamingRingBuffer::GetTransientBufferSize() {
|
||||
return transient_buffer_.GetTransientBufferSize();
|
||||
};
|
||||
|
||||
void StreamingRingBuffer::SetTransientBufferSize(uint32_t new_transient_buffer_size) {
|
||||
return transient_buffer_.SetTransientBufferSize(new_transient_buffer_size);
|
||||
}
|
||||
|
||||
size_t StreamingRingBuffer::GetMaxTransientBufferSize() const {
|
||||
return transient_buffer_.GetMaxTransientBufferSize();
|
||||
}
|
||||
|
||||
const uint8_t *StreamingRingBuffer::GetTransientBuffer() const {
|
||||
return transient_buffer_.GetTransientBuffer();
|
||||
}
|
||||
|
||||
uint8_t *StreamingRingBuffer::GetTransientBufferMutable() const {
|
||||
return transient_buffer_.GetTransientBufferMutable();
|
||||
}
|
||||
|
||||
void StreamingRingBuffer::ReallocTransientBuffer(uint32_t size) {
|
||||
transient_buffer_.ReallocTransientBuffer(size);
|
||||
}
|
||||
|
||||
bool StreamingRingBuffer::IsTransientAvaliable() {
|
||||
return transient_buffer_.IsTransientAvaliable();
|
||||
}
|
||||
|
||||
void StreamingRingBuffer::FreeTransientBuffer(bool is_force) {
|
||||
transient_buffer_.FreeTransientBuffer(is_force);
|
||||
}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
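// StreamingRingBuffer above selects between a lock-free SPSC implementation and a
// lock-based one. This is a compact, standard-library-only sketch of the SPSC index
// arithmetic (the real RingBufferImplLockFree is declared later in ring_buffer.h);
// names and the one-slot-wasted convention are illustrative assumptions.
#include <atomic>
#include <cassert>
#include <vector>

template <class T>
class SpscRing {
 public:
  explicit SpscRing(size_t capacity)
      : buffer_(capacity), capacity_(capacity), read_(0), write_(0) {}
  bool Full() const { return (write_ + 1) % capacity_ == read_; }
  bool Empty() const { return read_ == write_; }
  void Push(const T &t) {  // called by the single producer thread
    assert(!Full());
    buffer_[write_] = t;
    write_ = (write_ + 1) % capacity_;
  }
  T Pop() {  // called by the single consumer thread
    assert(!Empty());
    T t = buffer_[read_];
    read_ = (read_ + 1) % capacity_;
    return t;
  }

 private:
  std::vector<T> buffer_;
  const size_t capacity_;
  std::atomic<size_t> read_;
  std::atomic<size_t> write_;
};

int main() {
  SpscRing<int> ring(4);  // holds at most capacity - 1 elements
  ring.Push(1);
  ring.Push(2);
  assert(ring.Pop() == 1);
  assert(ring.Pop() == 2);
  assert(ring.Empty());
  return 0;
}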
|
233
streaming/src/ring_buffer.h
Normal file
233
streaming/src/ring_buffer.h
Normal file
|
@ -0,0 +1,233 @@
|
|||
#ifndef RAY_RING_BUFFER_H
|
||||
#define RAY_RING_BUFFER_H
|
||||
|
||||
#include <atomic>
|
||||
#include <boost/circular_buffer.hpp>
|
||||
#include <boost/thread/locks.hpp>
|
||||
#include <boost/thread/shared_mutex.hpp>
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
|
||||
#include "message/message.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
/// Because the data cannot always be written to the channel successfully, and in
/// order not to serialize the message repeatedly, we designed a temporary buffer
/// area so that when the downstream is backpressured or the channel is blocked
/// due to memory limitations, the serialized data can be cached first and reused
/// on the next attempt.
|
||||
class StreamingTransientBuffer {
|
||||
private:
|
||||
std::shared_ptr<uint8_t> transient_buffer_;
|
||||
// BufferSize is length of last serialization data.
|
||||
uint32_t transient_buffer_size_ = 0;
|
||||
uint32_t max_transient_buffer_size_ = 0;
|
||||
bool transient_flag_ = false;
|
||||
|
||||
public:
|
||||
inline size_t GetTransientBufferSize() const { return transient_buffer_size_; }
|
||||
|
||||
inline void SetTransientBufferSize(uint32_t new_transient_buffer_size) {
|
||||
transient_buffer_size_ = new_transient_buffer_size;
|
||||
}
|
||||
|
||||
inline size_t GetMaxTransientBufferSize() const { return max_transient_buffer_size_; }
|
||||
|
||||
inline const uint8_t *GetTransientBuffer() const { return transient_buffer_.get(); }
|
||||
|
||||
inline uint8_t *GetTransientBufferMutable() const { return transient_buffer_.get(); }
|
||||
|
||||
/// To reuse the transient buffer, the buffer memory is reallocated only if the size of
/// the needed message bundle raw data is greater than the original buffer size.
|
||||
/// \param size buffer size
|
||||
///
|
||||
inline void ReallocTransientBuffer(uint32_t size) {
|
||||
transient_buffer_size_ = size;
|
||||
transient_flag_ = true;
|
||||
if (max_transient_buffer_size_ > size) {
|
||||
return;
|
||||
}
|
||||
max_transient_buffer_size_ = size;
|
||||
transient_buffer_.reset(new uint8_t[size], std::default_delete<uint8_t[]>());
|
||||
}
|
||||
|
||||
inline bool IsTransientAvaliable() { return transient_flag_; }
|
||||
|
||||
inline void FreeTransientBuffer(bool is_force = false) {
|
||||
transient_buffer_size_ = 0;
|
||||
transient_flag_ = false;
|
||||
|
||||
// The transient buffer always keeps the largest buffer seen so far, which can be
// wasteful. An expiration time would be a reasonable way to release a large buffer
// that has been held for a long time.
|
||||
|
||||
if (is_force) {
|
||||
max_transient_buffer_size_ = 0;
|
||||
transient_buffer_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~StreamingTransientBuffer() = default;
|
||||
};
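// Illustrative usage sketch: serialize a message bundle into the transient buffer
// once, then retry the channel write until it succeeds. WriteToChannel is a
// hypothetical helper; only the StreamingTransientBuffer calls are real API.
//
//   StreamingTransientBuffer buffer;
//   buffer.ReallocTransientBuffer(bundle.ClassBytesSize());  // reuse or grow storage
//   bundle.ToBytes(buffer.GetTransientBufferMutable());      // serialize exactly once
//   while (WriteToChannel(buffer.GetTransientBuffer(),
//                         buffer.GetTransientBufferSize()) ==
//          StreamingStatus::FullChannel) {
//     // Back off and retry; the cached bytes are reused instead of re-serializing.
//   }
//   buffer.FreeTransientBuffer();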
|
||||
|
||||
template <class T>
|
||||
class AbstractRingBufferImpl {
|
||||
public:
|
||||
virtual void Push(T &&) = 0;
|
||||
virtual void Push(const T &) = 0;
|
||||
virtual void Pop() = 0;
|
||||
virtual T &Front() = 0;
|
||||
virtual bool Empty() = 0;
|
||||
virtual bool Full() = 0;
|
||||
virtual size_t Size() = 0;
|
||||
virtual size_t Capacity() = 0;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class RingBufferImplThreadSafe : public AbstractRingBufferImpl<T> {
|
||||
private:
|
||||
boost::shared_mutex ring_buffer_mutex_;
|
||||
boost::circular_buffer<T> buffer_;
|
||||
|
||||
public:
|
||||
RingBufferImplThreadSafe(size_t size) : buffer_(size) {}
|
||||
virtual ~RingBufferImplThreadSafe() = default;
|
||||
void Push(T &&t) {
|
||||
boost::unique_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
buffer_.push_back(t);
|
||||
}
|
||||
void Push(const T &t) {
|
||||
boost::unique_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
buffer_.push_back(t);
|
||||
}
|
||||
void Pop() {
|
||||
boost::unique_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
buffer_.pop_front();
|
||||
}
|
||||
T &Front() {
|
||||
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
return buffer_.front();
|
||||
}
|
||||
bool Empty() {
|
||||
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
return buffer_.empty();
|
||||
}
|
||||
bool Full() {
|
||||
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
return buffer_.full();
|
||||
}
|
||||
size_t Size() {
|
||||
boost::shared_lock<boost::shared_mutex> lock(ring_buffer_mutex_);
|
||||
return buffer_.size();
|
||||
}
|
||||
size_t Capacity() { return buffer_.capacity(); }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class RingBufferImplLockFree : public AbstractRingBufferImpl<T> {
|
||||
private:
|
||||
std::vector<T> buffer_;
|
||||
std::atomic<size_t> capacity_;
|
||||
std::atomic<size_t> read_index_;
|
||||
std::atomic<size_t> write_index_;
|
||||
|
||||
public:
|
||||
RingBufferImplLockFree(size_t size)
|
||||
: buffer_(size, nullptr), capacity_(size), read_index_(0), write_index_(0) {}
|
||||
virtual ~RingBufferImplLockFree() = default;
|
||||
|
||||
void Push(T &&t) {
|
||||
STREAMING_CHECK(!Full());
|
||||
buffer_[write_index_] = t;
|
||||
write_index_ = IncreaseIndex(write_index_);
|
||||
}
|
||||
|
||||
void Push(const T &t) {
|
||||
STREAMING_CHECK(!Full());
|
||||
buffer_[write_index_] = t;
|
||||
write_index_ = IncreaseIndex(write_index_);
|
||||
}
|
||||
|
||||
void Pop() {
|
||||
STREAMING_CHECK(!Empty());
|
||||
read_index_ = IncreaseIndex(read_index_);
|
||||
}
|
||||
|
||||
T &Front() {
|
||||
STREAMING_CHECK(!Empty());
|
||||
return buffer_[read_index_];
|
||||
}
|
||||
|
||||
bool Empty() { return write_index_ == read_index_; }
|
||||
|
||||
bool Full() { return IncreaseIndex(write_index_) == read_index_; }
|
||||
|
||||
size_t Size() { return (write_index_ + capacity_ - read_index_) % capacity_; }
|
||||
|
||||
size_t Capacity() { return capacity_; }
|
||||
|
||||
private:
|
||||
size_t IncreaseIndex(size_t index) const { return (index + 1) % capacity_; }
|
||||
};
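// Reading aid for the index arithmetic above: only the producer advances
// write_index_ and only the consumer advances read_index_, so a single producer and
// a single consumer never update the same index. Full() returns true when advancing
// write_index_ would land on read_index_, which means a buffer constructed with
// size N holds at most N - 1 items at any time.
//
//   RingBufferImplLockFree<int *> ring(4);  // Capacity() == 4, but only 3 usable slots
//   // After three Push() calls without a Pop(), ring.Full() becomes true.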
|
||||
|
||||
enum class StreamingRingBufferType : uint8_t { SPSC_LOCK, SPSC };
|
||||
|
||||
/// StreamingRingBuffer is a factory that creates one of two buffer implementations.
/// The data writer uses the lock-free single-producer single-consumer (SPSC) ring
/// buffer to hold messages from the user thread, because it performs much better
/// than the locking variant. The SPSC_LOCK variant is still useful for our
/// event-driven model (we plan to use it to optimize the threading model later),
/// so it cannot be removed yet.
|
||||
class StreamingRingBuffer {
|
||||
private:
|
||||
std::shared_ptr<AbstractRingBufferImpl<StreamingMessagePtr>> message_buffer_;
|
||||
|
||||
StreamingTransientBuffer transient_buffer_;
|
||||
|
||||
public:
|
||||
explicit StreamingRingBuffer(size_t buf_size, StreamingRingBufferType buffer_type =
|
||||
StreamingRingBufferType::SPSC_LOCK);
|
||||
|
||||
bool Push(StreamingMessagePtr &&msg);
|
||||
|
||||
bool Push(const StreamingMessagePtr &msg);
|
||||
|
||||
StreamingMessagePtr &Front();
|
||||
|
||||
void Pop();
|
||||
|
||||
bool IsFull();
|
||||
|
||||
bool IsEmpty();
|
||||
|
||||
size_t Size();
|
||||
|
||||
size_t Capacity() const;
|
||||
|
||||
size_t GetTransientBufferSize();
|
||||
|
||||
void SetTransientBufferSize(uint32_t new_transient_buffer_size);
|
||||
|
||||
size_t GetMaxTransientBufferSize() const;
|
||||
|
||||
const uint8_t *GetTransientBuffer() const;
|
||||
|
||||
uint8_t *GetTransientBufferMutable() const;
|
||||
|
||||
void ReallocTransientBuffer(uint32_t size);
|
||||
|
||||
bool IsTransientAvaliable();
|
||||
|
||||
void FreeTransientBuffer(bool is_force = false);
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<StreamingRingBuffer> StreamingRingBufferPtr;
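// Minimal producer/consumer sketch (illustrative): a writer thread pushes serialized
// messages while another thread drains them, mirroring how the data writer uses the
// lock-free SPSC variant. The payload bytes below are made up.
//
//   StreamingRingBufferPtr ring =
//       std::make_shared<StreamingRingBuffer>(1024, StreamingRingBufferType::SPSC);
//   uint8_t payload[] = {1, 2, 3};
//   auto msg = std::make_shared<StreamingMessage>(payload, 3, /*seq_id=*/1,
//                                                 StreamingMessageType::Message);
//   while (ring->IsFull()) {
//   }  // producer spins until a slot frees up
//   ring->Push(msg);
//   if (!ring->IsEmpty()) {  // consumer side
//     StreamingMessagePtr &front = ring->Front();
//     ring->Pop();
//   }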
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RING_BUFFER_H
|
32
streaming/src/runtime_context.cc
Normal file
@@ -0,0 +1,32 @@
#include "ray/common/id.h"
|
||||
#include "ray/protobuf/common.pb.h"
|
||||
#include "ray/util/util.h"
|
||||
|
||||
#include "runtime_context.h"
|
||||
#include "util/streaming_logging.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
void RuntimeContext::SetConfig(const StreamingConfig &streaming_config) {
|
||||
STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init)
|
||||
<< "set config must be at beginning";
|
||||
config_ = streaming_config;
|
||||
}
|
||||
|
||||
void RuntimeContext::SetConfig(const uint8_t *data, uint32_t size) {
|
||||
STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init)
|
||||
<< "set config must be at beginning";
|
||||
if (!data) {
|
||||
STREAMING_LOG(WARNING) << "buffer pointer is null, but len is => " << size;
|
||||
return;
|
||||
}
|
||||
config_.FromProto(data, size);
|
||||
}
|
||||
|
||||
RuntimeContext::~RuntimeContext() {}
|
||||
|
||||
RuntimeContext::RuntimeContext() : runtime_status_(RuntimeStatus::Init) {}
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
42
streaming/src/runtime_context.h
Normal file
@@ -0,0 +1,42 @@
#ifndef RAY_STREAMING_H
|
||||
#define RAY_STREAMING_H
|
||||
#include <string>
|
||||
|
||||
#include "config/streaming_config.h"
|
||||
#include "status.h"
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
enum class RuntimeStatus : uint8_t { Init = 0, Running = 1, Interrupted = 2 };
|
||||
|
||||
#define RETURN_IF_NOT_OK(STATUS_EXP) \
|
||||
{ \
|
||||
StreamingStatus state = STATUS_EXP; \
|
||||
if (StreamingStatus::OK != state) { \
|
||||
return state; \
|
||||
} \
|
||||
}
|
||||
|
||||
class RuntimeContext {
|
||||
public:
|
||||
RuntimeContext();
|
||||
virtual ~RuntimeContext();
|
||||
inline const StreamingConfig &GetConfig() const { return config_; };
|
||||
void SetConfig(const StreamingConfig &config);
|
||||
void SetConfig(const uint8_t *data, uint32_t buffer_len);
|
||||
inline RuntimeStatus GetRuntimeStatus() { return runtime_status_; }
|
||||
inline void SetRuntimeStatus(RuntimeStatus status) { runtime_status_ = status; }
|
||||
inline void MarkMockTest() { is_mock_test_ = true; }
|
||||
inline bool IsMockTest() { return is_mock_test_; }
|
||||
|
||||
private:
|
||||
StreamingConfig config_;
|
||||
RuntimeStatus runtime_status_;
|
||||
bool is_mock_test_ = false;
|
||||
};
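// Expected call order (a sketch inferred from the checks in runtime_context.cc):
// configuration must happen while the context is still in the Init state, before
// the status is switched to Running.
//
//   auto runtime_context = std::make_shared<RuntimeContext>();
//   StreamingConfig config;
//   runtime_context->SetConfig(config);                         // only valid in Init
//   runtime_context->SetRuntimeStatus(RuntimeStatus::Running);  // then start running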
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_STREAMING_H
|
47
streaming/src/status.h
Normal file
@@ -0,0 +1,47 @@
#ifndef RAY_STREAMING_STATUS_H
|
||||
#define RAY_STREAMING_STATUS_H
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
enum class StreamingStatus : uint32_t {
|
||||
OK = 0,
|
||||
ReconstructTimeOut = 1,
|
||||
QueueIdNotFound = 3,
|
||||
ResubscribeFailed = 4,
|
||||
EmptyRingBuffer = 5,
|
||||
FullChannel = 6,
|
||||
NoSuchItem = 7,
|
||||
InitQueueFailed = 8,
|
||||
GetBundleTimeOut = 9,
|
||||
SkipSendEmptyMessage = 10,
|
||||
Interrupted = 11,
|
||||
WaitQueueTimeOut = 12,
|
||||
OutOfMemory = 13,
|
||||
Invalid = 14,
|
||||
UnknownError = 15,
|
||||
TailStatus = 999,
|
||||
MIN = OK,
|
||||
MAX = TailStatus
|
||||
};
|
||||
|
||||
static inline std::ostream &operator<<(std::ostream &os, const StreamingStatus &status) {
|
||||
os << static_cast<std::underlying_type<StreamingStatus>::type>(status);
|
||||
return os;
|
||||
}
|
||||
|
||||
#define RETURN_IF_NOT_OK(STATUS_EXP) \
|
||||
{ \
|
||||
StreamingStatus state = STATUS_EXP; \
|
||||
if (StreamingStatus::OK != state) { \
|
||||
return state; \
|
||||
} \
|
||||
}
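// Usage sketch: RETURN_IF_NOT_OK propagates the first non-OK status out of the
// calling function. DoStep1 and DoStep2 are hypothetical helpers that return
// StreamingStatus.
//
//   StreamingStatus RunPipeline() {
//     RETURN_IF_NOT_OK(DoStep1());
//     RETURN_IF_NOT_OK(DoStep2());
//     return StreamingStatus::OK;
//   }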
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_STREAMING_STATUS_H
|
176
streaming/src/test/message_serialization_tests.cc
Normal file
@@ -0,0 +1,176 @@
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "message/message.h"
|
||||
#include "message/message_bundle.h"
|
||||
|
||||
using namespace ray;
|
||||
using namespace ray::streaming;
|
||||
|
||||
TEST(StreamingSerializationTest, streaming_message_serialization_test) {
|
||||
uint8_t data[] = {9, 1, 3};
|
||||
StreamingMessagePtr message =
|
||||
std::make_shared<StreamingMessage>(data, 3, 7, StreamingMessageType::Message);
|
||||
uint32_t message_length = message->ClassBytesSize();
|
||||
uint8_t *bytes = new uint8_t[message_length];
|
||||
message->ToBytes(bytes);
|
||||
StreamingMessagePtr new_message = StreamingMessage::FromBytes(bytes);
|
||||
EXPECT_EQ(std::memcmp(new_message->RawData(), data, 3), 0);
|
||||
delete[] bytes;
|
||||
}
|
||||
|
||||
TEST(StreamingSerializationTest, streaming_message_empty_bundle_serialization_test) {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
StreamingMessageBundle bundle(i, i);
|
||||
uint64_t bundle_size = bundle.ClassBytesSize();
|
||||
uint8_t *bundle_bytes = new uint8_t[bundle_size];
|
||||
bundle.ToBytes(bundle_bytes);
|
||||
StreamingMessageBundlePtr bundle_ptr =
|
||||
StreamingMessageBundle::FromBytes(bundle_bytes);
|
||||
|
||||
EXPECT_EQ(bundle.ClassBytesSize(), bundle_ptr->ClassBytesSize());
|
||||
EXPECT_EQ(bundle.GetMessageListSize(), bundle_ptr->GetMessageListSize());
|
||||
EXPECT_EQ(bundle.GetBundleType(), bundle_ptr->GetBundleType());
|
||||
EXPECT_EQ(bundle.GetLastMessageId(), bundle_ptr->GetLastMessageId());
|
||||
std::list<StreamingMessagePtr> s_message_list;
|
||||
bundle_ptr->GetMessageList(s_message_list);
|
||||
std::list<StreamingMessagePtr> b_message_list;
|
||||
bundle.GetMessageList(b_message_list);
|
||||
EXPECT_EQ(b_message_list.size(), 0);
|
||||
EXPECT_EQ(s_message_list.size(), 0);
|
||||
|
||||
delete[] bundle_bytes;
|
||||
}
|
||||
}
|
||||
TEST(StreamingSerializationTest, streaming_message_barrier_bundle_serialization_test) {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
uint8_t data[] = {1, 2, 3, 4};
|
||||
uint32_t data_size = 4;
|
||||
uint32_t head_size = sizeof(uint64_t);
|
||||
uint64_t checkpoint_id = 777;
|
||||
std::shared_ptr<uint8_t> ptr(new uint8_t[data_size + head_size],
|
||||
std::default_delete<uint8_t[]>());
|
||||
// Copy checkpoint_id into the head of the barrier data.
|
||||
std::memcpy(ptr.get(), &checkpoint_id, head_size);
|
||||
std::memcpy(ptr.get() + head_size, data, data_size);
|
||||
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
|
||||
data, head_size + data_size, i, StreamingMessageType::Barrier);
|
||||
std::list<StreamingMessagePtr> message_list;
|
||||
message_list.push_back(message);
|
||||
// The message list passed to the bundle constructor is moved, so construct from a copy.
|
||||
std::list<StreamingMessagePtr> message_list_cpy(message_list);
|
||||
|
||||
StreamingMessageBundle bundle(message_list_cpy, i, i,
|
||||
StreamingMessageBundleType::Barrier);
|
||||
uint64_t bundle_size = bundle.ClassBytesSize();
|
||||
uint8_t *bundle_bytes = new uint8_t[bundle_size];
|
||||
bundle.ToBytes(bundle_bytes);
|
||||
StreamingMessageBundlePtr bundle_ptr =
|
||||
StreamingMessageBundle::FromBytes(bundle_bytes);
|
||||
|
||||
EXPECT_TRUE(bundle.ClassBytesSize() == bundle_ptr->ClassBytesSize());
|
||||
EXPECT_TRUE(bundle.GetMessageListSize() == bundle_ptr->GetMessageListSize());
|
||||
EXPECT_TRUE(bundle.GetBundleType() == bundle_ptr->GetBundleType());
|
||||
EXPECT_TRUE(bundle.GetLastMessageId() == bundle_ptr->GetLastMessageId());
|
||||
std::list<StreamingMessagePtr> s_message_list;
|
||||
bundle_ptr->GetMessageList(s_message_list);
|
||||
EXPECT_TRUE(s_message_list.size() == message_list.size());
|
||||
auto m_item = message_list.back();
|
||||
auto s_item = s_message_list.back();
|
||||
EXPECT_TRUE(s_item->ClassBytesSize() == m_item->ClassBytesSize());
|
||||
EXPECT_TRUE(s_item->GetMessageType() == m_item->GetMessageType());
|
||||
EXPECT_TRUE(s_item->GetMessageSeqId() == m_item->GetMessageSeqId());
|
||||
EXPECT_TRUE(s_item->GetDataSize() == m_item->GetDataSize());
|
||||
EXPECT_TRUE(
|
||||
std::memcmp(s_item->RawData(), m_item->RawData(), m_item->GetDataSize()) == 0);
|
||||
EXPECT_TRUE(*(s_item.get()) == (*(m_item.get())));
|
||||
|
||||
delete[] bundle_bytes;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamingSerializationTest, streaming_message_bundle_serialization_test) {
|
||||
for (int k = 0; k <= 1000; k++) {
|
||||
std::list<StreamingMessagePtr> message_list;
|
||||
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
uint8_t *data = new uint8_t[i + 1];
|
||||
data[0] = i;
|
||||
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
|
||||
data, i + 1, i + 1, StreamingMessageType::Message);
|
||||
message_list.push_back(message);
|
||||
delete[] data;
|
||||
}
|
||||
StreamingMessageBundle messageBundle(message_list, 0, 1,
|
||||
StreamingMessageBundleType::Bundle);
|
||||
size_t message_length = messageBundle.ClassBytesSize();
|
||||
uint8_t *bytes = new uint8_t[message_length];
|
||||
messageBundle.ToBytes(bytes);
|
||||
|
||||
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes);
|
||||
EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length);
|
||||
std::list<StreamingMessagePtr> s_message_list;
|
||||
bundle_ptr->GetMessageList(s_message_list);
|
||||
EXPECT_TRUE(bundle_ptr->operator==(messageBundle));
|
||||
StreamingMessageBundleMetaPtr bundle_meta_ptr =
|
||||
StreamingMessageBundleMeta::FromBytes(bytes);
|
||||
|
||||
EXPECT_EQ(bundle_meta_ptr->GetBundleType(), bundle_ptr->GetBundleType());
|
||||
EXPECT_EQ(bundle_meta_ptr->GetLastMessageId(), bundle_ptr->GetLastMessageId());
|
||||
EXPECT_EQ(bundle_meta_ptr->GetMessageBundleTs(), bundle_ptr->GetMessageBundleTs());
|
||||
EXPECT_EQ(bundle_meta_ptr->GetMessageListSize(), bundle_ptr->GetMessageListSize());
|
||||
delete[] bytes;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamingSerializationTest, streaming_message_bundle_equal_test) {
|
||||
std::list<StreamingMessagePtr> message_list;
|
||||
std::list<StreamingMessagePtr> message_list_same;
|
||||
std::list<StreamingMessagePtr> message_list_cpy;
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
uint8_t *data = new uint8_t[i + 1];
|
||||
for (int j = 0; j < i + 1; ++j) {
|
||||
data[j] = i;
|
||||
}
|
||||
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
|
||||
data, i + 1, i + 1, StreamingMessageType::Message);
|
||||
message_list.push_back(message);
|
||||
message_list_cpy.push_front(message);
|
||||
delete[] data;
|
||||
}
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
uint8_t *data = new uint8_t[i + 1];
|
||||
for (int j = 0; j < i + 1; ++j) {
|
||||
data[j] = i;
|
||||
}
|
||||
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
|
||||
data, i + 1, i + 1, StreamingMessageType::Message);
|
||||
message_list_same.push_back(message);
|
||||
delete[] data;
|
||||
}
|
||||
StreamingMessageBundle message_bundle(message_list, 0, 1,
|
||||
StreamingMessageBundleType::Bundle);
|
||||
StreamingMessageBundle message_bundle_same(message_list_same, 0, 1,
|
||||
StreamingMessageBundleType::Bundle);
|
||||
StreamingMessageBundle message_bundle_reverse(message_list_cpy, 0, 1,
|
||||
StreamingMessageBundleType::Bundle);
|
||||
EXPECT_TRUE(message_bundle_same == message_bundle);
|
||||
EXPECT_FALSE(message_bundle_reverse == message_bundle);
|
||||
size_t message_length = message_bundle.ClassBytesSize();
|
||||
uint8_t *bytes = new uint8_t[message_length];
|
||||
message_bundle.ToBytes(bytes);
|
||||
|
||||
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes);
|
||||
EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length);
|
||||
std::list<StreamingMessagePtr> s_message_list;
|
||||
bundle_ptr->GetMessageList(s_message_list);
|
||||
EXPECT_TRUE(bundle_ptr->operator==(message_bundle));
|
||||
delete[] bytes;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
439
streaming/src/test/mock_actor.cc
Normal file
@@ -0,0 +1,439 @@
#define BOOST_BIND_NO_PLACEHOLDERS
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
#include "data_reader.h"
|
||||
#include "data_writer.h"
|
||||
#include "message/message.h"
|
||||
#include "message/message_bundle.h"
|
||||
#include "queue/queue_client.h"
|
||||
#include "ring_buffer.h"
|
||||
#include "status.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
using namespace std::placeholders;
|
||||
|
||||
const uint32_t MESSAGE_BOUND_SIZE = 10000;
|
||||
const uint32_t DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE = 1000;
|
||||
|
||||
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
class StreamingQueueTestSuite {
|
||||
public:
|
||||
StreamingQueueTestSuite(std::shared_ptr<CoreWorker> core_worker, ActorID &peer_actor_id,
|
||||
std::vector<ObjectID> queue_ids,
|
||||
std::vector<ObjectID> rescale_queue_ids)
|
||||
: core_worker_(core_worker),
|
||||
peer_actor_id_(peer_actor_id),
|
||||
queue_ids_(queue_ids),
|
||||
rescale_queue_ids_(rescale_queue_ids) {}
|
||||
|
||||
virtual void ExecuteTest(std::string test_name) {
|
||||
auto it = test_func_map_.find(test_name);
|
||||
STREAMING_CHECK(it != test_func_map_.end());
|
||||
current_test_ = test_name;
|
||||
status_ = false;
|
||||
auto func = it->second;
|
||||
executor_thread_ = std::make_shared<std::thread>(func);
|
||||
executor_thread_->detach();
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<LocalMemoryBuffer> CheckCurTestStatus() {
|
||||
TestCheckStatusRspMsg msg(current_test_, status_);
|
||||
return msg.ToBytes();
|
||||
}
|
||||
|
||||
virtual bool TestDone() { return status_; }
|
||||
|
||||
virtual ~StreamingQueueTestSuite() {}
|
||||
|
||||
protected:
|
||||
std::unordered_map<std::string, std::function<void()>> test_func_map_;
|
||||
std::string current_test_;
|
||||
bool status_;
|
||||
std::shared_ptr<std::thread> executor_thread_;
|
||||
std::shared_ptr<CoreWorker> core_worker_;
|
||||
ActorID peer_actor_id_;
|
||||
std::vector<ObjectID> queue_ids_;
|
||||
std::vector<ObjectID> rescale_queue_ids_;
|
||||
};
|
||||
|
||||
class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite {
|
||||
public:
|
||||
StreamingQueueWriterTestSuite(std::shared_ptr<CoreWorker> core_worker,
|
||||
ActorID &peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
std::vector<ObjectID> rescale_queue_ids)
|
||||
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
|
||||
rescale_queue_ids) {
|
||||
test_func_map_ = {
|
||||
{"streaming_writer_exactly_once_test",
|
||||
std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest,
|
||||
this)}};
|
||||
}
|
||||
|
||||
private:
|
||||
void TestWriteMessageToBufferRing(std::shared_ptr<DataWriter> writer_client,
|
||||
std::vector<ray::ObjectID> &q_list) {
|
||||
// const uint8_t temp_data[] = {1, 2, 4, 5};
|
||||
|
||||
uint32_t i = 1;
|
||||
while (i <= MESSAGE_BOUND_SIZE) {
|
||||
for (auto &q_id : q_list) {
|
||||
uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE);
|
||||
uint8_t *data = new uint8_t[buffer_len];
|
||||
for (uint32_t j = 0; j < buffer_len; ++j) {
|
||||
data[j] = j % 128;
|
||||
}
|
||||
|
||||
writer_client->WriteMessageToBufferRing(q_id, data, buffer_len,
|
||||
StreamingMessageType::Message);
|
||||
}
|
||||
++i;
|
||||
}
|
||||
|
||||
// Wait a while
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5000));
|
||||
}
|
||||
|
||||
void StreamingWriterStrategyTest(StreamingConfig &config) {
|
||||
for (auto &queue_id : queue_ids_) {
|
||||
STREAMING_LOG(INFO) << "queue_id: " << queue_id;
|
||||
}
|
||||
std::vector<ActorID> actor_ids(queue_ids_.size(), peer_actor_id_);
|
||||
STREAMING_LOG(INFO) << "writer actor_ids size: " << actor_ids.size()
|
||||
<< " actor_id: " << peer_actor_id_;
|
||||
|
||||
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
|
||||
runtime_context->SetConfig(config);
|
||||
|
||||
std::shared_ptr<DataWriter> streaming_writer_client(new DataWriter(runtime_context));
|
||||
uint64_t queue_size = 10 * 1000 * 1000;
|
||||
std::vector<uint64_t> channel_seq_id_vec(queue_ids_.size(), 0);
|
||||
streaming_writer_client->Init(queue_ids_, actor_ids, channel_seq_id_vec,
|
||||
std::vector<uint64_t>(queue_ids_.size(), queue_size));
|
||||
STREAMING_LOG(INFO) << "streaming_writer_client Init done";
|
||||
|
||||
streaming_writer_client->Run();
|
||||
std::thread test_loop_thread(
|
||||
&StreamingQueueWriterTestSuite::TestWriteMessageToBufferRing, this,
|
||||
streaming_writer_client, std::ref(queue_ids_));
|
||||
// test_loop_thread.detach();
|
||||
if (test_loop_thread.joinable()) {
|
||||
test_loop_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
void StreamingWriterExactlyOnceTest() {
|
||||
StreamingConfig config;
|
||||
StreamingWriterStrategyTest(config);
|
||||
|
||||
STREAMING_LOG(INFO)
|
||||
<< "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest";
|
||||
status_ = true;
|
||||
}
|
||||
};
|
||||
|
||||
class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite {
|
||||
public:
|
||||
StreamingQueueReaderTestSuite(std::shared_ptr<CoreWorker> core_worker,
|
||||
ActorID peer_actor_id, std::vector<ObjectID> queue_ids,
|
||||
std::vector<ObjectID> rescale_queue_ids)
|
||||
: StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids,
|
||||
rescale_queue_ids) {
|
||||
test_func_map_ = {
|
||||
{"streaming_writer_exactly_once_test",
|
||||
std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest,
|
||||
this)}};
|
||||
}
|
||||
|
||||
private:
|
||||
void ReaderLoopForward(std::shared_ptr<DataReader> reader_client,
|
||||
std::shared_ptr<DataWriter> writer_client,
|
||||
std::vector<ray::ObjectID> &queue_id_vec) {
|
||||
uint64_t recevied_message_cnt = 0;
|
||||
std::unordered_map<ray::ObjectID, uint64_t> queue_last_cp_id;
|
||||
|
||||
for (auto &q_id : queue_id_vec) {
|
||||
queue_last_cp_id[q_id] = 0;
|
||||
}
|
||||
STREAMING_LOG(INFO) << "Start read message bundle";
|
||||
while (true) {
|
||||
std::shared_ptr<DataBundle> msg;
|
||||
StreamingStatus st = reader_client->GetBundle(100, msg);
|
||||
|
||||
if (st != StreamingStatus::OK || !msg->data) {
|
||||
STREAMING_LOG(DEBUG) << "read bundle timeout, status = " << (int)st;
|
||||
continue;
|
||||
}
|
||||
|
||||
STREAMING_CHECK(msg.get() && msg->meta.get())
|
||||
<< "read null pointer message, queue id => " << msg->from.Hex();
|
||||
|
||||
if (msg->meta->GetBundleType() == StreamingMessageBundleType::Barrier) {
|
||||
STREAMING_LOG(DEBUG) << "barrier message recevied => "
|
||||
<< msg->meta->GetMessageBundleTs();
|
||||
std::unordered_map<ray::ObjectID, ConsumerChannelInfo> *offset_map;
|
||||
reader_client->GetOffsetInfo(offset_map);
|
||||
|
||||
for (auto &q_id : queue_id_vec) {
|
||||
reader_client->NotifyConsumedItem((*offset_map)[q_id],
|
||||
(*offset_map)[q_id].current_seq_id);
|
||||
}
|
||||
// writer_client->ClearCheckpoint(msg->last_barrier_id);
|
||||
|
||||
continue;
|
||||
} else if (msg->meta->GetBundleType() == StreamingMessageBundleType::Empty) {
|
||||
STREAMING_LOG(DEBUG) << "empty message recevied => "
|
||||
<< msg->meta->GetMessageBundleTs();
|
||||
continue;
|
||||
}
|
||||
|
||||
StreamingMessageBundlePtr bundlePtr;
|
||||
bundlePtr = StreamingMessageBundle::FromBytes(msg->data);
|
||||
std::list<StreamingMessagePtr> message_list;
|
||||
bundlePtr->GetMessageList(message_list);
|
||||
STREAMING_LOG(INFO) << "message size => " << message_list.size()
|
||||
<< " from queue id => " << msg->from.Hex()
|
||||
<< " last message id => " << msg->meta->GetLastMessageId();
|
||||
|
||||
recevied_message_cnt += message_list.size();
|
||||
for (auto &item : message_list) {
|
||||
uint64_t i = item->GetMessageSeqId();
|
||||
|
||||
uint32_t buff_len = i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE;
|
||||
if (i > MESSAGE_BOUND_SIZE) break;
|
||||
|
||||
EXPECT_EQ(buff_len, item->GetDataSize());
|
||||
uint8_t *compared_data = new uint8_t[buff_len];
|
||||
for (uint32_t j = 0; j < item->GetDataSize(); ++j) {
|
||||
compared_data[j] = j % 128;
|
||||
}
|
||||
EXPECT_EQ(std::memcmp(compared_data, item->RawData(), item->GetDataSize()), 0);
|
||||
delete[] compared_data;
|
||||
}
|
||||
STREAMING_LOG(DEBUG) << "Received message count => " << recevied_message_cnt;
|
||||
if (recevied_message_cnt == queue_id_vec.size() * MESSAGE_BOUND_SIZE) {
|
||||
STREAMING_LOG(INFO) << "recevied message count => " << recevied_message_cnt
|
||||
<< ", break";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StreamingReaderStrategyTest(StreamingConfig &config) {
|
||||
std::vector<ActorID> actor_ids(queue_ids_.size(), peer_actor_id_);
|
||||
STREAMING_LOG(INFO) << "reader actor_ids size: " << actor_ids.size()
|
||||
<< " actor_id: " << peer_actor_id_;
|
||||
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
|
||||
runtime_context->SetConfig(config);
|
||||
std::shared_ptr<DataReader> reader(new DataReader(runtime_context));
|
||||
|
||||
reader->Init(queue_ids_, actor_ids, -1);
|
||||
ReaderLoopForward(reader, nullptr, queue_ids_);
|
||||
|
||||
STREAMING_LOG(INFO) << "Reader exit";
|
||||
}
|
||||
|
||||
void StreamingWriterExactlyOnceTest() {
|
||||
STREAMING_LOG(INFO)
|
||||
<< "StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest";
|
||||
StreamingConfig config;
|
||||
|
||||
StreamingReaderStrategyTest(config);
|
||||
status_ = true;
|
||||
}
|
||||
};
|
||||
|
||||
class TestSuiteFactory {
|
||||
public:
|
||||
static std::shared_ptr<StreamingQueueTestSuite> CreateTestSuite(
|
||||
std::shared_ptr<CoreWorker> worker, std::shared_ptr<TestInitMessage> message) {
|
||||
std::shared_ptr<StreamingQueueTestSuite> test_suite = nullptr;
|
||||
std::string suite_name = message->TestSuiteName();
|
||||
queue::protobuf::StreamingQueueTestRole role = message->Role();
|
||||
const std::vector<ObjectID> &queue_ids = message->QueueIds();
|
||||
const std::vector<ObjectID> &rescale_queue_ids = message->RescaleQueueIds();
|
||||
ActorID peer_actor_id = message->PeerActorId();
|
||||
|
||||
if (role == queue::protobuf::StreamingQueueTestRole::WRITER) {
|
||||
if (suite_name == "StreamingWriterTest") {
|
||||
test_suite = std::make_shared<StreamingQueueWriterTestSuite>(
|
||||
worker, peer_actor_id, queue_ids, rescale_queue_ids);
|
||||
} else {
|
||||
STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name;
|
||||
}
|
||||
} else {
|
||||
if (suite_name == "StreamingWriterTest") {
|
||||
test_suite = std::make_shared<StreamingQueueReaderTestSuite>(
|
||||
worker, peer_actor_id, queue_ids, rescale_queue_ids);
|
||||
} else {
|
||||
STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name;
|
||||
}
|
||||
}
|
||||
|
||||
return test_suite;
|
||||
}
|
||||
};
|
||||
|
||||
class StreamingWorker {
|
||||
public:
|
||||
StreamingWorker(const std::string &store_socket, const std::string &raylet_socket,
|
||||
int node_manager_port, const gcs::GcsClientOptions &gcs_options)
|
||||
: test_suite_(nullptr), peer_actor_handle_(nullptr) {
|
||||
worker_ = std::make_shared<CoreWorker>(
|
||||
WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket,
|
||||
JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port,
|
||||
std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7));
|
||||
|
||||
RayFunction reader_async_call_func{ray::Language::PYTHON, {"reader_async_call_func"}};
|
||||
RayFunction reader_sync_call_func{ray::Language::PYTHON, {"reader_sync_call_func"}};
|
||||
RayFunction writer_async_call_func{ray::Language::PYTHON, {"writer_async_call_func"}};
|
||||
RayFunction writer_sync_call_func{ray::Language::PYTHON, {"writer_sync_call_func"}};
|
||||
|
||||
reader_client_ = std::make_shared<ReaderClient>(worker_.get(), reader_async_call_func,
|
||||
reader_sync_call_func);
|
||||
writer_client_ = std::make_shared<WriterClient>(worker_.get(), writer_async_call_func,
|
||||
writer_sync_call_func);
|
||||
STREAMING_LOG(INFO) << "StreamingWorker constructor";
|
||||
}
|
||||
|
||||
void StartExecutingTasks() {
|
||||
// Start executing tasks.
|
||||
worker_->StartExecutingTasks();
|
||||
}
|
||||
|
||||
private:
|
||||
Status ExecuteTask(TaskType task_type, const RayFunction &ray_function,
|
||||
const std::unordered_map<std::string, double> &required_resources,
|
||||
const std::vector<std::shared_ptr<RayObject>> &args,
|
||||
const std::vector<ObjectID> &arg_reference_ids,
|
||||
const std::vector<ObjectID> &return_ids,
|
||||
std::vector<std::shared_ptr<RayObject>> *results) {
|
||||
// Only one argument is used in streaming.
|
||||
STREAMING_CHECK(args.size() >= 1) << "args.size() = " << args.size();
|
||||
|
||||
std::vector<std::string> function_descriptor = ray_function.GetFunctionDescriptor();
|
||||
STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " << function_descriptor[0];
|
||||
|
||||
std::string func_name = function_descriptor[0];
|
||||
if (func_name == "init") {
|
||||
std::shared_ptr<LocalMemoryBuffer> local_buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(args[0]->GetData()->Data(),
|
||||
args[0]->GetData()->Size(), true);
|
||||
HandleInitTask(local_buffer);
|
||||
} else if (func_name == "execute_test") {
|
||||
STREAMING_LOG(INFO) << "Test name: " << function_descriptor[1];
|
||||
test_suite_->ExecuteTest(function_descriptor[1]);
|
||||
} else if (func_name == "check_current_test_status") {
|
||||
results->push_back(
|
||||
std::make_shared<RayObject>(test_suite_->CheckCurTestStatus(), nullptr));
|
||||
} else if (func_name == "reader_sync_call_func") {
|
||||
if (test_suite_->TestDone()) {
|
||||
STREAMING_LOG(WARNING) << "Test has done!!";
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> local_buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
|
||||
args[1]->GetData()->Size(), true);
|
||||
auto result_buffer = reader_client_->OnReaderMessageSync(local_buffer);
|
||||
results->push_back(std::make_shared<RayObject>(result_buffer, nullptr));
|
||||
} else if (func_name == "reader_async_call_func") {
|
||||
if (test_suite_->TestDone()) {
|
||||
STREAMING_LOG(WARNING) << "Test has done!!";
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> local_buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
|
||||
args[1]->GetData()->Size(), true);
|
||||
reader_client_->OnReaderMessage(local_buffer);
|
||||
} else if (func_name == "writer_sync_call_func") {
|
||||
if (test_suite_->TestDone()) {
|
||||
STREAMING_LOG(WARNING) << "Test has done!!";
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> local_buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
|
||||
args[1]->GetData()->Size(), true);
|
||||
auto result_buffer = writer_client_->OnWriterMessageSync(local_buffer);
|
||||
results->push_back(std::make_shared<RayObject>(result_buffer, nullptr));
|
||||
} else if (func_name == "writer_async_call_func") {
|
||||
if (test_suite_->TestDone()) {
|
||||
STREAMING_LOG(WARNING) << "Test has done!!";
|
||||
return Status::OK();
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> local_buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(args[1]->GetData()->Data(),
|
||||
args[1]->GetData()->Size(), true);
|
||||
writer_client_->OnWriterMessage(local_buffer);
|
||||
} else {
|
||||
STREAMING_LOG(WARNING) << "Invalid function name " << func_name;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
void HandleInitTask(std::shared_ptr<LocalMemoryBuffer> buffer) {
|
||||
uint8_t *bytes = buffer->Data();
|
||||
uint8_t *p_cur = bytes;
|
||||
uint32_t *magic_num = (uint32_t *)p_cur;
|
||||
STREAMING_CHECK(*magic_num == Message::MagicNum);
|
||||
|
||||
p_cur += sizeof(Message::MagicNum);
|
||||
queue::protobuf::StreamingQueueMessageType *type =
|
||||
(queue::protobuf::StreamingQueueMessageType *)p_cur;
|
||||
STREAMING_CHECK(
|
||||
*type ==
|
||||
queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType);
|
||||
std::shared_ptr<TestInitMessage> message = TestInitMessage::FromBytes(bytes);
|
||||
|
||||
STREAMING_LOG(INFO) << "Init message: " << message->ToString();
|
||||
std::string actor_handle_serialized = message->ActorHandleSerialized();
|
||||
worker_->DeserializeAndRegisterActorHandle(actor_handle_serialized);
|
||||
std::shared_ptr<ActorHandle> actor_handle(new ActorHandle(actor_handle_serialized));
|
||||
STREAMING_CHECK(actor_handle != nullptr);
|
||||
STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID();
|
||||
|
||||
|
||||
// STREAMING_LOG(INFO) << "actor_handle_serialized: " << actor_handle_serialized;
|
||||
// peer_actor_handle_ =
|
||||
// std::make_shared<ActorHandle>(actor_handle_serialized);
|
||||
|
||||
STREAMING_LOG(INFO) << "HandleInitTask queues:";
|
||||
for (auto qid : message->QueueIds()) {
|
||||
STREAMING_LOG(INFO) << "queue: " << qid;
|
||||
}
|
||||
for (auto qid : message->RescaleQueueIds()) {
|
||||
STREAMING_LOG(INFO) << "rescale queue: " << qid;
|
||||
}
|
||||
|
||||
test_suite_ = TestSuiteFactory::CreateTestSuite(worker_, message);
|
||||
STREAMING_CHECK(test_suite_ != nullptr);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<CoreWorker> worker_;
|
||||
std::shared_ptr<ReaderClient> reader_client_;
|
||||
std::shared_ptr<WriterClient> writer_client_;
|
||||
std::shared_ptr<std::thread> test_thread_;
|
||||
std::shared_ptr<StreamingQueueTestSuite> test_suite_;
|
||||
std::shared_ptr<ActorHandle> peer_actor_handle_;
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
RAY_CHECK(argc == 4);
|
||||
auto store_socket = std::string(argv[1]);
|
||||
auto raylet_socket = std::string(argv[2]);
|
||||
auto node_manager_port = std::stoi(std::string(argv[3]));
|
||||
|
||||
ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, "");
|
||||
ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port,
|
||||
gcs_options);
|
||||
worker.StartExecutingTasks();
|
||||
return 0;
|
||||
}
|
136
streaming/src/test/mock_transfer_tests.cc
Normal file
@@ -0,0 +1,136 @@
#include "data_reader.h"
|
||||
#include "data_writer.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace ray;
|
||||
using namespace ray::streaming;
|
||||
|
||||
TEST(StreamingMockTransfer, mock_produce_consume) {
|
||||
std::shared_ptr<Config> transfer_config;
|
||||
ObjectID channel_id = ObjectID::FromRandom();
|
||||
ProducerChannelInfo producer_channel_info;
|
||||
producer_channel_info.channel_id = channel_id;
|
||||
producer_channel_info.current_seq_id = 0;
|
||||
MockProducer producer(transfer_config, producer_channel_info);
|
||||
|
||||
ConsumerChannelInfo consumer_channel_info;
|
||||
consumer_channel_info.channel_id = channel_id;
|
||||
MockConsumer consumer(transfer_config, consumer_channel_info);
|
||||
|
||||
producer.CreateTransferChannel();
|
||||
uint8_t data[3] = {1, 2, 3};
|
||||
producer.ProduceItemToChannel(data, 3);
|
||||
uint8_t *data_consumed;
|
||||
uint32_t data_size_consumed;
|
||||
uint64_t data_seq_id;
|
||||
consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1);
|
||||
EXPECT_EQ(data_size_consumed, 3);
|
||||
EXPECT_EQ(data_seq_id, 1);
|
||||
EXPECT_EQ(std::memcmp(data_consumed, data, 3), 0);
|
||||
consumer.NotifyChannelConsumed(1);
|
||||
|
||||
auto status =
|
||||
consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1);
|
||||
EXPECT_EQ(status, StreamingStatus::NoSuchItem);
|
||||
}
|
||||
|
||||
class StreamingTransferTest : public ::testing::Test {
|
||||
public:
|
||||
StreamingTransferTest() {
|
||||
std::shared_ptr<RuntimeContext> runtime_context(new RuntimeContext());
|
||||
runtime_context->MarkMockTest();
|
||||
writer = std::make_shared<DataWriter>(runtime_context);
|
||||
reader = std::make_shared<DataReader>(runtime_context);
|
||||
}
|
||||
virtual ~StreamingTransferTest() = default;
|
||||
void InitTransfer(int channel_num = 1) {
|
||||
for (int i = 0; i < channel_num; ++i) {
|
||||
queue_vec.push_back(ObjectID::FromRandom());
|
||||
}
|
||||
std::vector<uint64_t> channel_id_vec(queue_vec.size(), 0);
|
||||
std::vector<uint64_t> queue_size_vec(queue_vec.size(), 10000);
|
||||
// actor ids are not used in this test, so we can just use Nil.
|
||||
std::vector<ActorID> actor_id_vec(queue_vec.size(),
|
||||
ActorID::NilFromJob(JobID::FromInt(0)));
|
||||
writer->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec);
|
||||
reader->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec, -1);
|
||||
}
|
||||
void DestroyTransfer() {
|
||||
writer.reset();
|
||||
reader.reset();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<DataWriter> writer;
|
||||
std::shared_ptr<DataReader> reader;
|
||||
std::vector<ObjectID> queue_vec;
|
||||
};
|
||||
|
||||
TEST_F(StreamingTransferTest, exchange_single_channel_test) {
|
||||
InitTransfer();
|
||||
writer->Run();
|
||||
uint8_t data[4] = {1, 2, 3, 0xff};
|
||||
uint32_t data_size = 4;
|
||||
writer->WriteMessageToBufferRing(queue_vec[0], data, data_size);
|
||||
std::shared_ptr<DataBundle> msg;
|
||||
reader->GetBundle(5000, msg);
|
||||
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data);
|
||||
auto &message_list = bundle_ptr->GetMessageList();
|
||||
auto &message = message_list.front();
|
||||
EXPECT_EQ(std::memcmp(message->RawData(), data, data_size), 0);
|
||||
}
|
||||
|
||||
TEST_F(StreamingTransferTest, exchange_multichannel_test) {
|
||||
int channel_num = 4;
|
||||
InitTransfer(4);
|
||||
writer->Run();
|
||||
for (int i = 0; i < channel_num; ++i) {
|
||||
uint8_t data[4] = {1, 2, 3, (uint8_t)i};
|
||||
uint32_t data_size = 4;
|
||||
writer->WriteMessageToBufferRing(queue_vec[i], data, data_size);
|
||||
std::shared_ptr<DataBundle> msg;
|
||||
reader->GetBundle(5000, msg);
|
||||
EXPECT_EQ(msg->from, queue_vec[i]);
|
||||
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data);
|
||||
auto &message_list = bundle_ptr->GetMessageList();
|
||||
auto &message = message_list.front();
|
||||
EXPECT_EQ(std::memcmp(message->RawData(), data, data_size), 0);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(StreamingTransferTest, exchange_consumed_test) {
|
||||
InitTransfer();
|
||||
writer->Run();
|
||||
uint32_t data_size = 8196;
|
||||
std::shared_ptr<uint8_t> data(new uint8_t[data_size], std::default_delete<uint8_t[]>());
|
||||
auto func = [data, data_size](int index) { std::fill_n(data.get(), data_size, index); };
|
||||
|
||||
int num = 10000;
|
||||
std::thread write_thread([this, data, data_size, &func, num]() {
|
||||
for (uint32_t i = 0; i < num; ++i) {
|
||||
func(i);
|
||||
writer->WriteMessageToBufferRing(queue_vec[0], data.get(), data_size);
|
||||
}
|
||||
});
|
||||
|
||||
std::list<StreamingMessagePtr> read_message_list;
|
||||
while (read_message_list.size() < num) {
|
||||
std::shared_ptr<DataBundle> msg;
|
||||
reader->GetBundle(5000, msg);
|
||||
StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data);
|
||||
auto &message_list = bundle_ptr->GetMessageList();
|
||||
std::copy(message_list.begin(), message_list.end(),
|
||||
std::back_inserter(read_message_list));
|
||||
}
|
||||
int index = 0;
|
||||
for (auto &message : read_message_list) {
|
||||
func(index++);
|
||||
EXPECT_EQ(std::memcmp(message->RawData(), data.get(), data_size), 0);
|
||||
}
|
||||
write_thread.join();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
313
streaming/src/test/queue_tests_base.h
Normal file
@@ -0,0 +1,313 @@
namespace ray {
|
||||
namespace streaming {
|
||||
|
||||
ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); }
|
||||
|
||||
static void flushall_redis(void) {
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
freeReplyObject(redisCommand(context, "FLUSHALL"));
|
||||
freeReplyObject(redisCommand(context, "SET NumRedisShards 1"));
|
||||
freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380"));
|
||||
redisFree(context);
|
||||
}
|
||||
/// Base class for real-world tests with streaming queue
|
||||
class StreamingQueueTestBase : public ::testing::TestWithParam<uint64_t> {
|
||||
public:
|
||||
StreamingQueueTestBase(int num_nodes, std::string raylet_exe, std::string store_exe,
|
||||
int port, std::string actor_exe)
|
||||
: gcs_options_("127.0.0.1", 6379, ""),
|
||||
raylet_executable_(raylet_exe),
|
||||
store_executable_(store_exe),
|
||||
actor_executable_(actor_exe),
|
||||
node_manager_port_(port) {
|
||||
// flush redis first.
|
||||
flushall_redis();
|
||||
|
||||
RAY_CHECK(num_nodes >= 0);
|
||||
if (num_nodes > 0) {
|
||||
raylet_socket_names_.resize(num_nodes);
|
||||
raylet_store_socket_names_.resize(num_nodes);
|
||||
}
|
||||
|
||||
// start plasma store.
|
||||
for (auto &store_socket : raylet_store_socket_names_) {
|
||||
store_socket = StartStore();
|
||||
}
|
||||
|
||||
// start raylet on each node. Assign each node with different resources so that
|
||||
// a task can be scheduled to the desired node.
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
raylet_socket_names_[i] =
|
||||
StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", node_manager_port_ + i,
|
||||
"127.0.0.1", "\"CPU,4.0,resource" + std::to_string(i) + ",10\"");
|
||||
}
|
||||
}
|
||||
|
||||
~StreamingQueueTestBase() {
|
||||
STREAMING_LOG(INFO) << "Stop raylet store and actors";
|
||||
for (const auto &raylet_socket : raylet_socket_names_) {
|
||||
StopRaylet(raylet_socket);
|
||||
}
|
||||
|
||||
for (const auto &store_socket : raylet_store_socket_names_) {
|
||||
StopStore(store_socket);
|
||||
}
|
||||
}
|
||||
|
||||
JobID NextJobId() const {
|
||||
static uint32_t job_counter = 1;
|
||||
return JobID::FromInt(job_counter++);
|
||||
}
|
||||
|
||||
std::string StartStore() {
|
||||
std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex();
|
||||
std::string store_pid = store_socket_name + ".pid";
|
||||
std::string plasma_command = store_executable_ + " -m 10000000 -s " +
|
||||
store_socket_name +
|
||||
" 1> /dev/null 2> /dev/null & echo $! > " + store_pid;
|
||||
RAY_LOG(DEBUG) << plasma_command;
|
||||
RAY_CHECK(system(plasma_command.c_str()) == 0);
|
||||
usleep(200 * 1000);
|
||||
return store_socket_name;
|
||||
}
|
||||
|
||||
void StopStore(std::string store_socket_name) {
|
||||
std::string store_pid = store_socket_name + ".pid";
|
||||
std::string kill_9 = "kill -9 `cat " + store_pid + "`";
|
||||
RAY_LOG(DEBUG) << kill_9;
|
||||
ASSERT_EQ(system(kill_9.c_str()), 0);
|
||||
ASSERT_EQ(system(("rm -rf " + store_socket_name).c_str()), 0);
|
||||
ASSERT_EQ(system(("rm -rf " + store_socket_name + ".pid").c_str()), 0);
|
||||
}
|
||||
|
||||
std::string StartRaylet(std::string store_socket_name, std::string node_ip_address,
|
||||
int port, std::string redis_address, std::string resource) {
|
||||
std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex();
|
||||
std::string ray_start_cmd = raylet_executable_;
|
||||
ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name)
|
||||
.append(" --store_socket_name=" + store_socket_name)
|
||||
.append(" --object_manager_port=0 --node_manager_port=" + std::to_string(port))
|
||||
.append(" --node_ip_address=" + node_ip_address)
|
||||
.append(" --redis_address=" + redis_address)
|
||||
.append(" --redis_port=6379")
|
||||
.append(" --num_initial_workers=1")
|
||||
.append(" --maximum_startup_concurrency=10")
|
||||
.append(" --static_resource_list=" + resource)
|
||||
.append(" --python_worker_command=\"" + actor_executable_ + " " +
|
||||
store_socket_name + " " + raylet_socket_name + " " +
|
||||
std::to_string(port) + "\"")
|
||||
.append(" --config_list=initial_reconstruction_timeout_milliseconds,2000")
|
||||
.append(" & echo $! > " + raylet_socket_name + ".pid");
|
||||
|
||||
RAY_LOG(DEBUG) << "Ray Start command: " << ray_start_cmd;
|
||||
RAY_CHECK(system(ray_start_cmd.c_str()) == 0);
|
||||
usleep(200 * 1000);
|
||||
return raylet_socket_name;
|
||||
}
|
||||
|
||||
void StopRaylet(std::string raylet_socket_name) {
|
||||
std::string raylet_pid = raylet_socket_name + ".pid";
|
||||
std::string kill_9 = "kill -9 `cat " + raylet_pid + "`";
|
||||
RAY_LOG(DEBUG) << kill_9;
|
||||
ASSERT_TRUE(system(kill_9.c_str()) == 0);
|
||||
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0);
|
||||
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0);
|
||||
}
|
||||
|
||||
void InitWorker(CoreWorker &driver, ActorID &self_actor_id, ActorID &peer_actor_id,
|
||||
const queue::protobuf::StreamingQueueTestRole role,
|
||||
const std::vector<ObjectID> &queue_ids,
|
||||
const std::vector<ObjectID> &rescale_queue_ids, std::string suite_name,
|
||||
std::string test_name, uint64_t param) {
|
||||
std::string forked_serialized_str;
|
||||
Status st = driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str);
|
||||
STREAMING_CHECK(st.ok());
|
||||
STREAMING_LOG(INFO) << "forked_serialized_str: " << forked_serialized_str;
|
||||
TestInitMessage msg(role, self_actor_id, peer_actor_id, forked_serialized_str,
|
||||
queue_ids, rescale_queue_ids, suite_name, test_name, param);
|
||||
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(msg.ToBytes(), nullptr, true)));
|
||||
std::unordered_map<std::string, double> resources;
|
||||
TaskOptions options{0, true, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func{ray::Language::PYTHON, {"init"}};
|
||||
|
||||
RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids));
|
||||
}
|
||||
|
||||
void SubmitTestToActor(CoreWorker &driver, ActorID &actor_id, const std::string test) {
|
||||
uint8_t data[8];
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr, true)));
|
||||
std::unordered_map<std::string, double> resources;
|
||||
TaskOptions options{0, true, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func{ray::Language::PYTHON, {"execute_test", test}};
|
||||
|
||||
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
|
||||
}
|
||||
|
||||
bool CheckCurTest(CoreWorker &driver, ActorID &actor_id, const std::string test_name) {
|
||||
uint8_t data[8];
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(data, 8, true);
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr, true)));
|
||||
std::unordered_map<std::string, double> resources;
|
||||
TaskOptions options{1, true, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func{ray::Language::PYTHON, {"check_current_test_status"}};
|
||||
|
||||
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
|
||||
|
||||
std::vector<bool> wait_results;
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results);
|
||||
if (!wait_st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "Wait fail.";
|
||||
return false;
|
||||
}
|
||||
STREAMING_CHECK(wait_results.size() >= 1);
|
||||
if (!wait_results[0]) {
|
||||
STREAMING_LOG(WARNING) << "Wait direct call fail.";
|
||||
return false;
|
||||
}
|
||||
|
||||
Status get_st = driver.Get(return_ids, -1, &results);
|
||||
if (!get_st.ok()) {
|
||||
STREAMING_LOG(ERROR) << "Get fail.";
|
||||
return false;
|
||||
}
|
||||
STREAMING_CHECK(results.size() >= 1);
|
||||
if (results[0]->IsException()) {
|
||||
STREAMING_LOG(INFO) << "peer actor may has exceptions.";
|
||||
return false;
|
||||
}
|
||||
STREAMING_CHECK(results[0]->HasData());
|
||||
STREAMING_LOG(DEBUG) << "SendForResult result[0] DataSize: " << results[0]->GetSize();
|
||||
|
||||
const std::shared_ptr<ray::Buffer> result_buffer = results[0]->GetData();
|
||||
std::shared_ptr<LocalMemoryBuffer> return_buffer =
|
||||
std::make_shared<LocalMemoryBuffer>(result_buffer->Data(), result_buffer->Size(),
|
||||
true);
|
||||
|
||||
uint8_t *bytes = result_buffer->Data();
|
||||
uint8_t *p_cur = bytes;
|
||||
uint32_t *magic_num = (uint32_t *)p_cur;
|
||||
STREAMING_CHECK(*magic_num == Message::MagicNum);
|
||||
|
||||
p_cur += sizeof(Message::MagicNum);
|
||||
queue::protobuf::StreamingQueueMessageType *type =
|
||||
(queue::protobuf::StreamingQueueMessageType *)p_cur;
|
||||
STREAMING_CHECK(*type == queue::protobuf::StreamingQueueMessageType::
|
||||
StreamingQueueTestCheckStatusRspMsgType);
|
||||
std::shared_ptr<TestCheckStatusRspMsg> message =
|
||||
TestCheckStatusRspMsg::FromBytes(bytes);
|
||||
STREAMING_CHECK(message->TestName() == test_name);
|
||||
return message->Status();
|
||||
}
|
||||
|
||||
ActorID CreateActorHelper(CoreWorker &worker,
|
||||
const std::unordered_map<std::string, double> &resources,
|
||||
bool is_direct_call, uint64_t max_reconstructions) {
|
||||
std::unique_ptr<ActorHandle> actor_handle;
|
||||
|
||||
// Test creating actor.
|
||||
uint8_t array[] = {1, 2, 3};
|
||||
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));
|
||||
|
||||
RayFunction func{ray::Language::PYTHON, {"actor creation task"}};
|
||||
std::vector<TaskArg> args;
|
||||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
|
||||
|
||||
ActorCreationOptions actor_options{
|
||||
max_reconstructions, is_direct_call,
|
||||
/*max_concurrency*/ 1, resources, resources, {},
|
||||
/*is_detached*/ false, /*is_asyncio*/ false};
|
||||
|
||||
// Create an actor.
|
||||
ActorID actor_id;
|
||||
RAY_CHECK_OK(worker.CreateActor(func, args, actor_options, &actor_id));
|
||||
return actor_id;
|
||||
}
|
||||
|
||||
void SubmitTest(uint32_t queue_num, std::string suite_name, std::string test_name,
|
||||
uint64_t timeout_ms) {
|
||||
std::vector<ray::ObjectID> queue_id_vec;
|
||||
std::vector<ray::ObjectID> rescale_queue_id_vec;
|
||||
for (uint32_t i = 0; i < queue_num; ++i) {
|
||||
ObjectID queue_id = ray::ObjectID::FromRandom();
|
||||
queue_id_vec.emplace_back(queue_id);
|
||||
}
|
||||
|
||||
// One scale id
|
||||
ObjectID rescale_queue_id = ray::ObjectID::FromRandom();
|
||||
rescale_queue_id_vec.emplace_back(rescale_queue_id);
|
||||
|
||||
std::vector<uint64_t> channel_seq_id_vec(queue_num, 0);
|
||||
|
||||
for (size_t i = 0; i < queue_id_vec.size(); ++i) {
|
||||
STREAMING_LOG(INFO) << " qid hex => " << queue_id_vec[i].Hex();
|
||||
}
|
||||
for (auto &qid : rescale_queue_id_vec) {
|
||||
STREAMING_LOG(INFO) << " rescale qid hex => " << qid.Hex();
|
||||
}
|
||||
STREAMING_LOG(INFO) << "Sub process: writer.";
|
||||
|
||||
CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0],
|
||||
raylet_socket_names_[0], NextJobId(), gcs_options_, "", "",
|
||||
node_manager_port_, nullptr);
|
||||
|
||||
// Create writer and reader actors
|
||||
std::unordered_map<std::string, double> resources;
|
||||
auto actor_id_writer = CreateActorHelper(driver, resources, true, 0);
|
||||
auto actor_id_reader = CreateActorHelper(driver, resources, true, 0);
|
||||
|
||||
InitWorker(driver, actor_id_writer, actor_id_reader,
|
||||
queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec,
|
||||
rescale_queue_id_vec, suite_name, test_name, GetParam());
|
||||
InitWorker(driver, actor_id_reader, actor_id_writer,
|
||||
queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec,
|
||||
rescale_queue_id_vec, suite_name, test_name, GetParam());
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
|
||||
SubmitTestToActor(driver, actor_id_writer, test_name);
|
||||
SubmitTestToActor(driver, actor_id_reader, test_name);
|
||||
|
||||
uint64_t slept_time_ms = 0;
|
||||
while (slept_time_ms < timeout_ms) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000));
|
||||
STREAMING_LOG(INFO) << "Check test status.";
|
||||
if (CheckCurTest(driver, actor_id_writer, test_name) &&
|
||||
CheckCurTest(driver, actor_id_reader, test_name)) {
|
||||
STREAMING_LOG(INFO) << "Test Success, Exit.";
|
||||
return;
|
||||
}
|
||||
slept_time_ms += 5 * 1000;
|
||||
}
|
||||
|
||||
EXPECT_TRUE(false);
|
||||
STREAMING_LOG(INFO) << "Test Timeout, Exit.";
|
||||
}
|
||||
|
||||
void SetUp() {}
|
||||
|
||||
void TearDown() {}
|
||||
|
||||
protected:
|
||||
std::vector<std::string> raylet_socket_names_;
|
||||
std::vector<std::string> raylet_store_socket_names_;
|
||||
gcs::GcsClientOptions gcs_options_;
|
||||
std::string raylet_executable_;
|
||||
std::string store_executable_;
|
||||
std::string actor_executable_;
|
||||
int node_manager_port_;
|
||||
};
|
||||
|
||||
} // namespace streaming
|
||||
} // namespace ray
|
93
streaming/src/test/ring_buffer_tests.cc
Normal file
@@ -0,0 +1,93 @@
#include "gtest/gtest.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
#include <unistd.h>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include "message/message.h"
|
||||
#include "ring_buffer.h"
|
||||
|
||||
using namespace ray;
|
||||
using namespace ray::streaming;
|
||||
|
||||
size_t data_n = 1000000;
|
||||
TEST(StreamingRingBufferTest, streaming_message_ring_buffer_test) {
|
||||
for (int k = 0; k < 10000; ++k) {
|
||||
StreamingRingBuffer ring_buffer(3, StreamingRingBufferType::SPSC_LOCK);
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
uint8_t data[] = {1, 1, 3};
|
||||
data[0] = i;
|
||||
StreamingMessagePtr message =
|
||||
std::make_shared<StreamingMessage>(data, 3, i, StreamingMessageType::Message);
|
||||
EXPECT_EQ(ring_buffer.Push(message), true);
|
||||
size_t ith = i >= 3 ? 3 : (i + 1);
|
||||
EXPECT_EQ(ring_buffer.Size(), ith);
|
||||
}
|
||||
int th = 2;
|
||||
|
||||
while (!ring_buffer.IsEmpty()) {
|
||||
StreamingMessagePtr message_ptr = ring_buffer.Front();
|
||||
ring_buffer.Pop();
|
||||
EXPECT_EQ(message_ptr->GetDataSize(), 3);
|
||||
EXPECT_EQ(*(message_ptr->RawData()), th++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamingRingBufferTest, spsc_test) {
|
||||
size_t m_num = 1000;
|
||||
StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC);
|
||||
std::thread thread([&ring_buffer]() {
|
||||
for (size_t j = 0; j < data_n; ++j) {
|
||||
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
|
||||
reinterpret_cast<uint8_t *>(&j), sizeof(size_t), j,
|
||||
StreamingMessageType::Message);
|
||||
while (ring_buffer.IsFull()) {
|
||||
}
|
||||
ring_buffer.Push(message);
|
||||
}
|
||||
});
|
||||
size_t count = 0;
|
||||
while (count < data_n) {
|
||||
while (ring_buffer.IsEmpty()) {
|
||||
}
|
||||
auto &msg = ring_buffer.Front();
|
||||
EXPECT_EQ(std::memcmp(msg->RawData(), &count, sizeof(size_t)), 0);
|
||||
ring_buffer.Pop();
|
||||
count++;
|
||||
}
|
||||
thread.join();
|
||||
EXPECT_EQ(count, data_n);
|
||||
}
|
||||
|
||||
TEST(StreamingRingBufferTest, mutex_test) {
|
||||
size_t m_num = data_n;
|
||||
StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC_LOCK);
|
||||
std::thread thread([&ring_buffer]() {
|
||||
for (size_t j = 0; j < data_n; ++j) {
|
||||
StreamingMessagePtr message = std::make_shared<StreamingMessage>(
|
||||
reinterpret_cast<uint8_t *>(&j), sizeof(size_t), j,
|
||||
StreamingMessageType::Message);
|
||||
while (ring_buffer.IsFull()) {
|
||||
}
|
||||
ring_buffer.Push(message);
|
||||
}
|
||||
});
|
||||
size_t count = 0;
|
||||
while (count < data_n) {
|
||||
while (ring_buffer.IsEmpty()) {
|
||||
}
|
||||
auto msg = ring_buffer.Front();
|
||||
EXPECT_EQ(std::memcmp(msg->RawData(), &count, sizeof(size_t)), 0);
|
||||
ring_buffer.Pop();
|
||||
count++;
|
||||
}
|
||||
thread.join();
|
||||
EXPECT_EQ(count, data_n);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
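The two throughput tests above share one produce/consume pattern: a producer spins while the buffer is full, the consumer spins while it is empty, and ordering is preserved. A condensed sketch of that pattern, assuming only the StreamingRingBuffer interface exercised here (Push, Front, Pop, IsFull, IsEmpty) and, like the tests, assuming StreamingMessage copies its payload:

#include <cassert>
#include <cstring>
#include <memory>
#include <thread>

#include "message/message.h"
#include "ring_buffer.h"

using namespace ray::streaming;

// One producer thread publishes n counters while the consumer spins on the
// same SPSC ring buffer; messages come out in the order they were pushed.
void SpscSketch(size_t n) {
  StreamingRingBuffer ring_buffer(1000, StreamingRingBufferType::SPSC);
  std::thread producer([&ring_buffer, n]() {
    for (size_t j = 0; j < n; ++j) {
      // The tests above rely on StreamingMessage copying the payload, so
      // passing the address of the loop counter is safe under that assumption.
      StreamingMessagePtr message = std::make_shared<StreamingMessage>(
          reinterpret_cast<uint8_t *>(&j), sizeof(size_t), j,
          StreamingMessageType::Message);
      while (ring_buffer.IsFull()) {
      }  // busy-wait for a free slot
      ring_buffer.Push(message);
    }
  });
  for (size_t count = 0; count < n; ++count) {
    while (ring_buffer.IsEmpty()) {
    }  // busy-wait for the next message
    StreamingMessagePtr msg = ring_buffer.Front();
    assert(std::memcmp(msg->RawData(), &count, sizeof(size_t)) == 0);
    ring_buffer.Pop();
  }
  producer.join();
}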
69
streaming/src/test/run_streaming_queue_test.sh
Normal file

@@ -0,0 +1,69 @@
#!/usr/bin/env bash

# Run all streaming C++ tests using the streaming queue instead of the plasma queue.
# This needs to be run in the root directory.

# Try to find an unused port for raylet to use.
PORTS="2000 2001 2002 2003 2004 2005 2006 2007 2008 2009"
RAYLET_PORT=0
for port in $PORTS; do
    nc -z localhost $port
    if [[ $? != 0 ]]; then
        RAYLET_PORT=$port
        break
    fi
done

if [[ $RAYLET_PORT == 0 ]]; then
    echo "WARNING: Could not find unused port for raylet to use. Exiting without running tests."
    exit
fi

# Cause the script to exit if a single command fails.
set -e
set -x
export STREAMING_METRICS_MODE=DEV

# Get the directory in which this script is executing.
SCRIPT_DIR="`dirname \"$0\"`"
RAY_ROOT="$SCRIPT_DIR/../../.."
# Makes $RAY_ROOT an absolute path.
RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`"
if [ -z "$RAY_ROOT" ] ; then
    exit 1
fi

bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server"
bazel build //streaming:streaming_test_worker
bazel build //streaming:streaming_queue_tests

# Ensure we're in the right directory.
if [ ! -d "$RAY_ROOT/python" ]; then
    echo "Unable to find root Ray directory. Has this script moved?"
    exit 1
fi

REDIS_MODULE="./bazel-bin/libray_redis_module.so"
LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}"
STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server"
RAYLET_EXEC="./bazel-bin/raylet"
STREAMING_TEST_WORKER_EXEC="./bazel-bin/streaming/streaming_test_worker"

# Allow cleanup commands to fail.
bazel run //:redis-cli -- -p 6379 shutdown || true
sleep 1s
bazel run //:redis-cli -- -p 6380 shutdown || true
sleep 1s
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6379 &
sleep 2s
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 &
sleep 2s
# Run tests.
./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC
sleep 1s
bazel run //:redis-cli -- -p 6379 shutdown
bazel run //:redis-cli -- -p 6380 shutdown
sleep 1s
65
streaming/src/test/streaming_queue_tests.cc
Normal file

@@ -0,0 +1,65 @@
#define BOOST_BIND_NO_PLACEHOLDERS
#include <unistd.h>
#include "gtest/gtest.h"
#include "queue/queue_client.h"
#include "ray/core_worker/core_worker.h"

#include "data_reader.h"
#include "data_writer.h"
#include "message/message.h"
#include "message/message_bundle.h"
#include "ring_buffer.h"

#include "queue_tests_base.h"

using namespace std::placeholders;
namespace ray {
namespace streaming {

static std::string store_executable;
static std::string raylet_executable;
static std::string actor_executable;
static int node_manager_port;

class StreamingWriterTest : public StreamingQueueTestBase {
 public:
  StreamingWriterTest()
      : StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port,
                               actor_executable) {}
};

class StreamingExactlySameTest : public StreamingQueueTestBase {
 public:
  StreamingExactlySameTest()
      : StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port,
                               actor_executable) {}
};

TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) {
  STREAMING_LOG(INFO) << "StreamingWriterTest.streaming_writer_exactly_once_test";

  uint32_t queue_num = 1;

  STREAMING_LOG(INFO) << "Streaming Strategy => EXACTLY ONCE";
  SubmitTest(queue_num, "StreamingWriterTest", "streaming_writer_exactly_once_test",
             60 * 1000);
}

INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingWriterTest, testing::Values(0));

INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest,
                        testing::Values(0, 1, 5, 9));

}  // namespace streaming
}  // namespace ray

int main(int argc, char **argv) {
  // set_streaming_log_config("streaming_writer_test", StreamingLogLevel::INFO, 0);
  ::testing::InitGoogleTest(&argc, argv);
  RAY_CHECK(argc == 5);
  ray::streaming::store_executable = std::string(argv[1]);
  ray::streaming::raylet_executable = std::string(argv[2]);
  ray::streaming::node_manager_port = std::stoi(std::string(argv[3]));
  ray::streaming::actor_executable = std::string(argv[4]);
  return RUN_ALL_TESTS();
}
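Each value passed to INSTANTIATE_TEST_CASE_P above becomes GetParam() for a full run of the suite, and SubmitTest forwards that value to the worker actors as the queue-test parameter. As a hedged sketch, another case could be registered against the same fixture like this; the test name is hypothetical and would have to match one that the streaming test worker actually implements:

// Hypothetical extra case; "streaming_rescale_exactly_once_test" is an assumed
// worker-side test name, not one defined in this patch.
TEST_P(StreamingWriterTest, streaming_rescale_exactly_once_test) {
  STREAMING_LOG(INFO) << "Running with param => " << GetParam();
  SubmitTest(1, "StreamingWriterTest", "streaming_rescale_exactly_once_test",
             60 * 1000);
}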
24
streaming/src/test/streaming_util_tests.cc
Normal file

@@ -0,0 +1,24 @@
#include "gtest/gtest.h"
|
||||
|
||||
#include "util/streaming_util.h"
|
||||
|
||||
using namespace ray;
|
||||
using namespace ray::streaming;
|
||||
|
||||
TEST(StreamingUtilTest, test_Byte2hex) {
|
||||
const uint8_t data[2] = {0x11, 0x07};
|
||||
EXPECT_TRUE(Util::Byte2hex(data, 2) == "1107");
|
||||
EXPECT_TRUE(Util::Byte2hex(data, 2) != "1108");
|
||||
}
|
||||
|
||||
TEST(StreamingUtilTest, test_Hex2str) {
|
||||
const uint8_t data[2] = {0x11, 0x07};
|
||||
EXPECT_TRUE(std::memcmp(Util::Hexqid2str("1107").c_str(), data, 2) == 0);
|
||||
const uint8_t data2[2] = {0x10, 0x0f};
|
||||
EXPECT_TRUE(std::memcmp(Util::Hexqid2str("100f").c_str(), data2, 2) == 0);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
12
streaming/src/util/streaming_logging.cc
Normal file

@@ -0,0 +1,12 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "glog/log_severity.h"
#include "glog/logging.h"

#include "streaming_logging.h"

namespace ray {
namespace streaming {}  // namespace streaming
}  // namespace ray
11
streaming/src/util/streaming_logging.h
Normal file

@@ -0,0 +1,11 @@
#ifndef RAY_STREAMING_LOGGING_H
#define RAY_STREAMING_LOGGING_H
#include "ray/util/logging.h"

#define STREAMING_LOG RAY_LOG
#define STREAMING_CHECK RAY_CHECK
namespace ray {
namespace streaming {}  // namespace streaming
}  // namespace ray

#endif  // RAY_STREAMING_LOGGING_H
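Both macros simply alias the core ray/util/logging macros, so they accept the same severity levels and stream syntax already used throughout the tests above. A minimal usage sketch (the function name is hypothetical, not part of this patch):

#include "util/streaming_logging.h"

void LoggingSketch(int queue_size) {
  // Forwards to RAY_LOG: stream-style logging with a severity level.
  STREAMING_LOG(INFO) << "queue size => " << queue_size;
  // Forwards to RAY_CHECK: fails with the streamed message when the condition is false.
  STREAMING_CHECK(queue_size >= 0) << "queue size must be non-negative";
}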
42
streaming/src/util/streaming_util.cc
Normal file

@@ -0,0 +1,42 @@
#include <unordered_set>

#include "streaming_util.h"
namespace ray {
namespace streaming {

boost::any &Config::Get(ConfigEnum key) const {
  auto item = config_map_.find(key);
  STREAMING_CHECK(item != config_map_.end());
  return item->second;
}

boost::any Config::Get(ConfigEnum key, boost::any default_value) const {
  auto item = config_map_.find(key);
  if (item == config_map_.end()) {
    return default_value;
  }
  return item->second;
}

std::string Util::Byte2hex(const uint8_t *data, uint32_t data_size) {
  constexpr char hex[] = "0123456789abcdef";
  std::string result;
  for (uint32_t i = 0; i < data_size; i++) {
    unsigned short val = data[i];
    result.push_back(hex[val >> 4]);
    result.push_back(hex[val & 0xf]);
  }
  return result;
}

std::string Util::Hexqid2str(const std::string &q_id_hex) {
  std::string result;
  for (uint32_t i = 0; i < q_id_hex.size(); i += 2) {
    std::string byte = q_id_hex.substr(i, 2);
    char chr = static_cast<char>(std::strtol(byte.c_str(), nullptr, 16));
    result.push_back(chr);
  }
  return result;
}
}  // namespace streaming
}  // namespace ray
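Byte2hex emits two lowercase hex characters per input byte, and Hexqid2str parses such a string back into raw bytes two characters at a time, so for well-formed hex input the two functions are inverses. A small round-trip sketch with made-up bytes (the function name is illustrative, not part of this patch):

#include <cstdint>
#include <cstring>
#include <string>

#include "util/streaming_util.h"

void HexRoundTripSketch() {
  const uint8_t qid[4] = {0xde, 0xad, 0x00, 0x0f};
  // 4 bytes -> 8 lowercase hex characters: "dead000f".
  std::string hex = ray::streaming::Util::Byte2hex(qid, 4);
  // Parsing the hex string yields the original byte sequence.
  std::string bytes = ray::streaming::Util::Hexqid2str(hex);
  STREAMING_CHECK(bytes.size() == 4 && std::memcmp(bytes.data(), qid, 4) == 0);
}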
99
streaming/src/util/streaming_util.h
Normal file

@@ -0,0 +1,99 @@
#ifndef RAY_STREAMING_UTIL_H
#define RAY_STREAMING_UTIL_H
#include <boost/any.hpp>
#include <string>
#include <unordered_map>

#include "util/streaming_logging.h"

namespace ray {
namespace streaming {

enum class ConfigEnum : uint32_t {
  QUEUE_ID_VECTOR = 0,
  RECONSTRUCT_RETRY_TIMES,
  RECONSTRUCT_TIMEOUT_PER_MB,
  CURRENT_DRIVER_ID,
  /// For direct call
  CORE_WORKER,
  SYNC_FUNCTION,
  ASYNC_FUNCTION,
  TRANSFER_MIN = QUEUE_ID_VECTOR,
  TRANSFER_MAX = ASYNC_FUNCTION
};
}  // namespace streaming
}  // namespace ray

namespace std {
template <>
struct hash<::ray::streaming::ConfigEnum> {
  size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const {
    return static_cast<uint32_t>(config_enum_key);
  }
};

template <>
struct hash<const ::ray::streaming::ConfigEnum> {
  size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const {
    return static_cast<uint32_t>(config_enum_key);
  }
};
}  // namespace std

namespace ray {
namespace streaming {

class Config {
 public:
  template <typename ValueType>
  inline void Set(ConfigEnum key, const ValueType &any) {
    config_map_.emplace(key, any);
  }

  template <typename ValueType>
  inline void Set(ConfigEnum key, ValueType &&any) {
    config_map_.emplace(key, any);
  }

  template <typename ValueType>
  inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) {
    auto item = config_map_.find(key);
    if (item != config_map_.end()) {
      return item->second;
    }
    Set(key, any);
    return any;
  }

  boost::any &Get(ConfigEnum key) const;

  boost::any Get(ConfigEnum key, boost::any default_value) const;

  inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast<uint32_t>(Get(key)); }

  inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast<uint64_t>(Get(key)); }

  inline double GetDouble(ConfigEnum key) { return boost::any_cast<double>(Get(key)); }

  inline bool GetBool(ConfigEnum key) { return boost::any_cast<bool>(Get(key)); }

  inline std::string GetString(ConfigEnum key) {
    return boost::any_cast<std::string>(Get(key));
  }

  virtual ~Config() = default;

 protected:
  mutable std::unordered_map<ConfigEnum, boost::any> config_map_;
};

class Util {
 public:
  static std::string Byte2hex(const uint8_t *data, uint32_t data_size);

  static std::string Hexqid2str(const std::string &q_id_hex);
};
}  // namespace streaming
}  // namespace ray

#endif  // RAY_STREAMING_UTIL_H
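The std::hash specializations above are what allow ConfigEnum to key the unordered_map directly, and the typed getters are thin boost::any_cast wrappers, so the stored type must match the accessor exactly. An illustrative sketch of populating and reading back a Config; the keys exist in ConfigEnum above, but the values and the function name are invented:

#include <cstdint>
#include <string>

#include "util/streaming_util.h"

using namespace ray::streaming;

// Illustrative only: values are made up, not taken from this patch.
void ConfigSketch() {
  Config config;
  config.Set(ConfigEnum::RECONSTRUCT_RETRY_TIMES, static_cast<uint32_t>(3));
  config.Set(ConfigEnum::CURRENT_DRIVER_ID, std::string("driver-0"));

  // Typed getters any_cast the stored value back, so the stored type must
  // match the accessor (uint32_t was stored, so GetInt32 is the one to use).
  uint32_t retries = config.GetInt32(ConfigEnum::RECONSTRUCT_RETRY_TIMES);
  std::string driver = config.GetString(ConfigEnum::CURRENT_DRIVER_ID);

  // Get with a default returns the fallback when the key is absent.
  boost::any timeout =
      config.Get(ConfigEnum::RECONSTRUCT_TIMEOUT_PER_MB, static_cast<uint64_t>(1000));
  STREAMING_CHECK(retries == 3 && driver == "driver-0" &&
                  boost::any_cast<uint64_t>(timeout) == 1000);
}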
31
thirdparty/patches/grpc-cython-copts.patch
vendored

@@ -1,30 +1,49 @@
diff --git bazel/cython_library.bzl bazel/cython_library.bzl
index 48b41d74e8..6084734f59 100644
index 48b41d74e8..a9bc168e5d 100644
--- bazel/cython_library.bzl
+++ bazel/cython_library.bzl
@@ -7,7 +7,7 @@
@@ -7,18 +7,20 @@
 # been written at cython/cython and tensorflow/tensorflow. We branch from
 # Tensorflow's version as it is more actively maintained and works for gRPC
 # Python's needs.
-def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
+def pyx_library(name, deps=[], copts=[], py_deps=[], srcs=[], **kwargs):
+def pyx_library(name, deps=[], copts=[], cc_kwargs={}, py_deps=[], srcs=[], **kwargs):
 """Compiles a group of .pyx / .pxd / .py files.

 First runs Cython to create .cpp files for each input .pyx or .py + .pxd
@@ -19,6 +19,7 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
- pair. Then builds a shared object for each, passing "deps" to each cc_binary
- rule (includes Python headers by default). Finally, creates a py_library rule
- with the shared objects and any pure Python "srcs", with py_deps as its
- dependencies; the shared objects can be imported like normal Python files.
+ pair. Then builds a shared object for each, passing "deps" and `**cc_kwargs`
+ to each cc_binary rule (includes Python headers by default). Finally, creates
+ a py_library rule with the shared objects and any pure Python "srcs", with py_deps
+ as its dependencies; the shared objects can be imported like normal Python files.

 Args:
 name: Name for the rule.
 deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
+ copts: C/C++ compiler options for Cython
+ cc_kwargs: cc_binary extra arguments such as linkstatic, linkopts, features
 py_deps: Pure Python dependencies of the final library.
 srcs: .py, .pyx, or .pxd files to either compile or pass through.
 **kwargs: Extra keyword arguments passed to the py_library.
@@ -58,6 +59,7 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
@@ -57,9 +59,11 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
 shared_object_name = stem + ".so"
 native.cc_binary(
 name=shared_object_name,
 srcs=[stem + ".cpp"],
- srcs=[stem + ".cpp"],
+ srcs=[stem + ".cpp"] + cc_kwargs.pop("srcs", []),
+ copts=copts,
 deps=deps + ["@local_config_python//:python_headers"],
 linkshared=1,
+ **cc_kwargs
 )
 shared_objects.append(shared_object_name)

@@ -72,3 +76,4 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs):
 data=shared_objects,
 **kwargs)

+
--