ray/streaming/python/runtime/transfer.py
2020-04-29 13:42:08 +08:00

373 lines
13 KiB
Python

import logging
import random
from queue import Queue
from typing import List
import ray
import ray.streaming._streaming as _streaming
import ray.streaming.generated.streaming_pb2 as streaming_pb
from ray.actor import ActorHandle
from ray.streaming.config import Config
from ray._raylet import JavaFunctionDescriptor
from ray._raylet import PythonFunctionDescriptor
from ray._raylet import Language
CHANNEL_ID_LEN = 20
class ChannelID:
"""
ChannelID is used to identify a transfer channel between
a upstream worker and downstream worker.
"""
def __init__(self, channel_id_str: str):
"""
Args:
channel_id_str: string representation of channel id
"""
self.channel_id_str = channel_id_str
self.object_qid = ray.ObjectID(channel_id_str_to_bytes(channel_id_str))
def __eq__(self, other):
if other is None:
return False
if type(other) is ChannelID:
return self.channel_id_str == other.channel_id_str
else:
return False
def __hash__(self):
return hash(self.channel_id_str)
def __repr__(self):
return self.channel_id_str
@staticmethod
def gen_random_id():
"""Generate a random channel id string
"""
res = ""
for i in range(CHANNEL_ID_LEN * 2):
res += str(chr(random.randint(0, 5) + ord("A")))
return res
@staticmethod
def gen_id(from_index, to_index, ts):
"""Generate channel id, which is 20 character"""
channel_id = bytearray(20)
for i in range(11, 7, -1):
channel_id[i] = ts & 0xff
ts >>= 8
channel_id[16] = (from_index & 0xffff) >> 8
channel_id[17] = (from_index & 0xff)
channel_id[18] = (to_index & 0xffff) >> 8
channel_id[19] = (to_index & 0xff)
return channel_bytes_to_str(bytes(channel_id))
def channel_id_str_to_bytes(channel_id_str):
"""
Args:
channel_id_str: string representation of channel id
Returns:
bytes representation of channel id
"""
assert type(channel_id_str) in [str, bytes]
if isinstance(channel_id_str, bytes):
return channel_id_str
qid_bytes = bytes.fromhex(channel_id_str)
assert len(qid_bytes) == CHANNEL_ID_LEN
return qid_bytes
def channel_bytes_to_str(id_bytes):
"""
Args:
id_bytes: bytes representation of channel id
Returns:
string representation of channel id
"""
assert type(id_bytes) in [str, bytes]
if isinstance(id_bytes, str):
return id_bytes
return bytes.hex(id_bytes)
class DataMessage:
"""
DataMessage represents data between upstream and downstream operator
"""
def __init__(self,
body,
timestamp,
channel_id,
message_id_,
is_empty_message=False):
self.__body = body
self.__timestamp = timestamp
self.__channel_id = channel_id
self.__message_id = message_id_
self.__is_empty_message = is_empty_message
def __len__(self):
return len(self.__body)
def body(self):
"""Message data"""
return self.__body
def timestamp(self):
"""Get timestamp when item is written by upstream DataWriter
"""
return self.__timestamp
def channel_id(self):
"""Get string id of channel where data is coming from
"""
return self.__channel_id
def is_empty_message(self):
"""Whether this message is an empty message.
Upstream DataWriter will send an empty message when this is no data
in specified interval.
"""
return self.__is_empty_message
@property
def message_id(self):
return self.__message_id
class ChannelCreationParametersBuilder:
"""
wrap initial parameters needed by a streaming queue
"""
_java_reader_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage",
"([B)V")
_java_reader_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync",
"([B)[B")
_java_writer_async_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage",
"([B)V")
_java_writer_sync_function_descriptor = JavaFunctionDescriptor(
"io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync",
"([B)[B")
_python_reader_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_reader_message", "JobWorker")
_python_reader_sync_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_reader_message_sync", "JobWorker")
_python_writer_async_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_writer_message", "JobWorker")
_python_writer_sync_function_descriptor = PythonFunctionDescriptor(
"ray.streaming.runtime.worker", "on_writer_message_sync", "JobWorker")
def get_parameters(self):
return self._parameters
def __init__(self):
self._parameters = []
def build_input_queue_parameters(self, queue_ids_dict):
self.build_parameters(queue_ids_dict,
self._java_writer_async_function_descriptor,
self._java_writer_sync_function_descriptor,
self._python_writer_async_function_descriptor,
self._python_writer_sync_function_descriptor)
return self
def build_output_queue_parameters(self, to_actors):
self.build_parameters(to_actors,
self._java_reader_async_function_descriptor,
self._java_reader_sync_function_descriptor,
self._python_reader_async_function_descriptor,
self._python_reader_sync_function_descriptor)
return self
def build_parameters(self, actors, java_async_func, java_sync_func,
py_async_func, py_sync_func):
for handle in actors:
parameter = None
if handle._ray_actor_language == Language.PYTHON:
parameter = _streaming.ChannelCreationParameter(
handle._ray_actor_id, py_async_func, py_sync_func)
else:
parameter = _streaming.ChannelCreationParameter(
handle._ray_actor_id, java_async_func, java_sync_func)
self._parameters.append(parameter)
return self
@staticmethod
def set_python_writer_function_descriptor(async_function, sync_function):
ChannelCreationParametersBuilder. \
_python_writer_async_function_descriptor = async_function
ChannelCreationParametersBuilder. \
_python_writer_sync_function_descriptor = sync_function
@staticmethod
def set_python_reader_function_descriptor(async_function, sync_function):
ChannelCreationParametersBuilder. \
_python_reader_async_function_descriptor = async_function
ChannelCreationParametersBuilder. \
_python_reader_sync_function_descriptor = sync_function
logger = logging.getLogger(__name__)
class DataWriter:
"""Data Writer is a wrapper of streaming c++ DataWriter, which sends data
to downstream workers
"""
def __init__(self, output_channels, to_actors: List[ActorHandle],
conf: dict):
"""Get DataWriter of output channels
Args:
output_channels: output channels ids
to_actors: downstream output actors
Returns:
DataWriter
"""
assert len(output_channels) > 0
py_output_channels = [
channel_id_str_to_bytes(qid_str) for qid_str in output_channels
]
creation_parameters = ChannelCreationParametersBuilder()
creation_parameters.build_output_queue_parameters(to_actors)
channel_size = conf.get(Config.CHANNEL_SIZE,
Config.CHANNEL_SIZE_DEFAULT)
py_msg_ids = [0 for _ in range(len(output_channels))]
config_bytes = _to_native_conf(conf)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.writer = _streaming.DataWriter.create(
py_output_channels, creation_parameters.get_parameters(),
channel_size, py_msg_ids, config_bytes, is_mock)
logger.info("create DataWriter succeed")
def write(self, channel_id: ChannelID, item: bytes):
"""Write data into native channel
Args:
channel_id: channel id
item: bytes data
Returns:
msg_id
"""
assert type(item) == bytes
msg_id = self.writer.write(channel_id.object_qid, item)
return msg_id
def stop(self):
logger.info("stopping channel writer.")
self.writer.stop()
# destruct DataWriter
self.writer = None
def close(self):
logger.info("closing channel writer.")
class DataReader:
"""Data Reader is wrapper of streaming c++ DataReader, which read data
from channels of upstream workers
"""
def __init__(self, input_channels: List, from_actors: List[ActorHandle],
conf: dict):
"""Get DataReader of input channels
Args:
input_channels: input channels
from_actors: upstream input actors
Returns:
DataReader
"""
assert len(input_channels) > 0
py_input_channels = [
channel_id_str_to_bytes(qid_str) for qid_str in input_channels
]
creation_parameters = ChannelCreationParametersBuilder()
creation_parameters.build_input_queue_parameters(from_actors)
py_seq_ids = [0 for _ in range(len(input_channels))]
py_msg_ids = [0 for _ in range(len(input_channels))]
timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1))
is_recreate = bool(conf.get(Config.IS_RECREATE, False))
config_bytes = _to_native_conf(conf)
self.__queue = Queue(10000)
is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL
self.reader = _streaming.DataReader.create(
py_input_channels, creation_parameters.get_parameters(),
py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes,
is_mock)
logger.info("create DataReader succeed")
def read(self, timeout_millis):
"""Read data from channel
Args:
timeout_millis: timeout millis when there is no data in channel
for this duration
Returns:
channel item
"""
if self.__queue.empty():
msgs = self.reader.read(timeout_millis)
for msg in msgs:
msg_bytes, msg_id, timestamp, qid_bytes = msg
data_msg = DataMessage(msg_bytes, timestamp,
channel_bytes_to_str(qid_bytes), msg_id)
self.__queue.put(data_msg)
if self.__queue.empty():
return None
return self.__queue.get()
def stop(self):
logger.info("stopping Data Reader.")
self.reader.stop()
# destruct DataReader
self.reader = None
def close(self):
logger.info("closing Data Reader.")
def _to_native_conf(conf):
config = streaming_pb.StreamingConfig()
if Config.STREAMING_JOB_NAME in conf:
config.job_name = conf[Config.STREAMING_JOB_NAME]
if Config.STREAMING_WORKER_NAME in conf:
config.worker_name = conf[Config.STREAMING_WORKER_NAME]
if Config.STREAMING_OP_NAME in conf:
config.op_name = conf[Config.STREAMING_OP_NAME]
# TODO set operator type
if Config.STREAMING_RING_BUFFER_CAPACITY in conf:
config.ring_buffer_capacity = \
conf[Config.STREAMING_RING_BUFFER_CAPACITY]
if Config.STREAMING_EMPTY_MESSAGE_INTERVAL in conf:
config.empty_message_interval = \
conf[Config.STREAMING_EMPTY_MESSAGE_INTERVAL]
if Config.FLOW_CONTROL_TYPE in conf:
conf.flow_control_type = conf[Config.FLOW_CONTROL_TYPE]
if Config.WRITER_CONSUMED_STEP in conf:
conf.writer_consumed_step = \
conf[Config.WRITER_CONSUMED_STEP]
if Config.READER_CONSUMED_STEP in conf:
conf.reader_consumed_step = \
conf[Config.READER_CONSUMED_STEP]
logger.info("conf: %s", str(config))
return config.SerializeToString()
class ChannelInitException(Exception):
def __init__(self, msg, abnormal_channels):
self.abnormal_channels = abnormal_channels
self.msg = msg
class ChannelInterruptException(Exception):
def __init__(self, msg=None):
self.msg = msg