import logging import random from queue import Queue from typing import List from enum import Enum from abc import ABC, abstractmethod import ray import ray.streaming._streaming as _streaming import ray.streaming.generated.streaming_pb2 as streaming_pb from ray.actor import ActorHandle from ray.streaming.config import Config from ray._raylet import JavaFunctionDescriptor from ray._raylet import PythonFunctionDescriptor from ray._raylet import Language CHANNEL_ID_LEN = ray.ObjectID.nil().size() logger = logging.getLogger(__name__) class ChannelID: """ ChannelID is used to identify a transfer channel between a upstream worker and downstream worker. """ def __init__(self, channel_id_str: str): """ Args: channel_id_str: string representation of channel id """ self.channel_id_str = channel_id_str self.object_qid = ray.ObjectRef( channel_id_str_to_bytes(channel_id_str)) def __eq__(self, other): if other is None: return False if type(other) is ChannelID: return self.channel_id_str == other.channel_id_str else: return False def __hash__(self): return hash(self.channel_id_str) def __repr__(self): return self.channel_id_str @staticmethod def gen_random_id(): """Generate a random channel id string """ res = "" for i in range(CHANNEL_ID_LEN * 2): res += str(chr(random.randint(0, 5) + ord("A"))) return res @staticmethod def gen_id(from_index, to_index, ts): """Generate channel id, which is `CHANNEL_ID_LEN` character""" channel_id = bytearray(CHANNEL_ID_LEN) for i in range(11, 7, -1): channel_id[i] = ts & 0xff ts >>= 8 channel_id[16] = (from_index & 0xffff) >> 8 channel_id[17] = (from_index & 0xff) channel_id[18] = (to_index & 0xffff) >> 8 channel_id[19] = (to_index & 0xff) return channel_bytes_to_str(bytes(channel_id)) def channel_id_str_to_bytes(channel_id_str): """ Args: channel_id_str: string representation of channel id Returns: bytes representation of channel id """ assert type(channel_id_str) in [str, bytes] if isinstance(channel_id_str, bytes): return channel_id_str qid_bytes = bytes.fromhex(channel_id_str) assert len(qid_bytes) == CHANNEL_ID_LEN return qid_bytes def channel_bytes_to_str(id_bytes): """ Args: id_bytes: bytes representation of channel id Returns: string representation of channel id """ assert type(id_bytes) in [str, bytes] if isinstance(id_bytes, str): return id_bytes return bytes.hex(id_bytes) class Message(ABC): @property @abstractmethod def body(self): """Message data""" pass @property @abstractmethod def timestamp(self): """Get timestamp when item is written by upstream DataWriter """ pass @property @abstractmethod def channel_id(self): """Get string id of channel where data is coming from""" pass @property @abstractmethod def message_id(self): """Get message id of the message""" pass class DataMessage(Message): """ DataMessage represents data between upstream and downstream operator. """ def __init__(self, body, timestamp, message_id, channel_id, is_empty_message=False): self.__body = body self.__timestamp = timestamp self.__channel_id = channel_id self.__message_id = message_id self.__is_empty_message = is_empty_message def __len__(self): return len(self.__body) @property def body(self): return self.__body @property def timestamp(self): return self.__timestamp @property def channel_id(self): return self.__channel_id @property def message_id(self): return self.__message_id @property def is_empty_message(self): """Whether this message is an empty message. Upstream DataWriter will send an empty message when this is no data in specified interval. """ return self.__is_empty_message class CheckpointBarrier(Message): """ CheckpointBarrier separates the records in the data stream into the set of records that goes into the current snapshot, and the records that go into the next snapshot. Each barrier carries the ID of the snapshot whose records it pushed in front of it. """ def __init__(self, barrier_data, timestamp, message_id, channel_id, offsets, barrier_id, barrier_type): self.__barrier_data = barrier_data self.__timestamp = timestamp self.__message_id = message_id self.__channel_id = channel_id self.checkpoint_id = barrier_id self.offsets = offsets self.barrier_type = barrier_type @property def body(self): return self.__barrier_data @property def timestamp(self): return self.__timestamp @property def channel_id(self): return self.__channel_id @property def message_id(self): return self.__message_id def get_input_checkpoints(self): return self.offsets def __str__(self): return "Barrier(Checkpoint id : {})".format(self.checkpoint_id) class ChannelCreationParametersBuilder: """ wrap initial parameters needed by a streaming queue """ _java_reader_async_function_descriptor = JavaFunctionDescriptor( "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessage", "([B)V") _java_reader_sync_function_descriptor = JavaFunctionDescriptor( "io.ray.streaming.runtime.worker.JobWorker", "onReaderMessageSync", "([B)[B") _java_writer_async_function_descriptor = JavaFunctionDescriptor( "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessage", "([B)V") _java_writer_sync_function_descriptor = JavaFunctionDescriptor( "io.ray.streaming.runtime.worker.JobWorker", "onWriterMessageSync", "([B)[B") _python_reader_async_function_descriptor = PythonFunctionDescriptor( "ray.streaming.runtime.worker", "on_reader_message", "JobWorker") _python_reader_sync_function_descriptor = PythonFunctionDescriptor( "ray.streaming.runtime.worker", "on_reader_message_sync", "JobWorker") _python_writer_async_function_descriptor = PythonFunctionDescriptor( "ray.streaming.runtime.worker", "on_writer_message", "JobWorker") _python_writer_sync_function_descriptor = PythonFunctionDescriptor( "ray.streaming.runtime.worker", "on_writer_message_sync", "JobWorker") def get_parameters(self): return self._parameters def __init__(self): self._parameters = [] def build_input_queue_parameters(self, from_actors): self.build_parameters(from_actors, self._java_writer_async_function_descriptor, self._java_writer_sync_function_descriptor, self._python_writer_async_function_descriptor, self._python_writer_sync_function_descriptor) return self def build_output_queue_parameters(self, to_actors): self.build_parameters(to_actors, self._java_reader_async_function_descriptor, self._java_reader_sync_function_descriptor, self._python_reader_async_function_descriptor, self._python_reader_sync_function_descriptor) return self def build_parameters(self, actors, java_async_func, java_sync_func, py_async_func, py_sync_func): for handle in actors: parameter = None if handle._ray_actor_language == Language.PYTHON: parameter = _streaming.ChannelCreationParameter( handle._ray_actor_id, py_async_func, py_sync_func) else: parameter = _streaming.ChannelCreationParameter( handle._ray_actor_id, java_async_func, java_sync_func) self._parameters.append(parameter) return self @staticmethod def set_python_writer_function_descriptor(async_function, sync_function): ChannelCreationParametersBuilder. \ _python_writer_async_function_descriptor = async_function ChannelCreationParametersBuilder. \ _python_writer_sync_function_descriptor = sync_function @staticmethod def set_python_reader_function_descriptor(async_function, sync_function): ChannelCreationParametersBuilder. \ _python_reader_async_function_descriptor = async_function ChannelCreationParametersBuilder. \ _python_reader_sync_function_descriptor = sync_function class DataWriter: """Data Writer is a wrapper of streaming c++ DataWriter, which sends data to downstream workers """ def __init__(self, output_channels, to_actors: List[ActorHandle], conf: dict): """Get DataWriter of output channels Args: output_channels: output channels ids to_actors: downstream output actors Returns: DataWriter """ assert len(output_channels) > 0 py_output_channels = [ channel_id_str_to_bytes(qid_str) for qid_str in output_channels ] creation_parameters = ChannelCreationParametersBuilder() creation_parameters.build_output_queue_parameters(to_actors) channel_size = conf.get(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT) py_msg_ids = [0 for _ in range(len(output_channels))] config_bytes = _to_native_conf(conf) is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL self.writer = _streaming.DataWriter.create( py_output_channels, creation_parameters.get_parameters(), channel_size, py_msg_ids, config_bytes, is_mock) logger.info("create DataWriter succeed") def write(self, channel_id: ChannelID, item: bytes): """Write data into native channel Args: channel_id: channel id item: bytes data Returns: msg_id """ assert type(item) == bytes msg_id = self.writer.write(channel_id.object_qid, item) return msg_id def broadcast_barrier(self, checkpoint_id: int, body: bytes): """Broadcast barriers to all downstream channels Args: checkpoint_id: the checkpoint_id body: barrier payload """ self.writer.broadcast_barrier(checkpoint_id, body) def get_output_checkpoints(self) -> List[int]: """Get output offsets of all downstream channels Returns: a list contains current msg_id of each downstream channel """ return self.writer.get_output_checkpoints() def clear_checkpoint(self, checkpoint_id): logger.info("producer start to clear checkpoint, checkpoint_id={}" .format(checkpoint_id)) self.writer.clear_checkpoint(checkpoint_id) def stop(self): logger.info("stopping channel writer.") self.writer.stop() # destruct DataWriter self.writer = None def close(self): logger.info("closing channel writer.") class DataReader: """Data Reader is wrapper of streaming c++ DataReader, which read data from channels of upstream workers """ def __init__(self, input_channels: List, from_actors: List[ActorHandle], conf: dict): """Get DataReader of input channels Args: input_channels: input channels from_actors: upstream input actors Returns: DataReader """ assert len(input_channels) > 0 py_input_channels = [ channel_id_str_to_bytes(qid_str) for qid_str in input_channels ] creation_parameters = ChannelCreationParametersBuilder() creation_parameters.build_input_queue_parameters(from_actors) py_msg_ids = [0 for _ in range(len(input_channels))] timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1)) config_bytes = _to_native_conf(conf) self.__queue = Queue(10000) is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL self.reader, queues_creation_status = _streaming.DataReader.create( py_input_channels, creation_parameters.get_parameters(), py_msg_ids, timer_interval, config_bytes, is_mock) self.__creation_status = {} for q, status in queues_creation_status.items(): self.__creation_status[q] = ChannelCreationStatus(status) logger.info("create DataReader succeed, creation_status={}".format( self.__creation_status)) def read(self, timeout_millis): """Read data from channel Args: timeout_millis: timeout millis when there is no data in channel for this duration Returns: channel item """ if self.__queue.empty(): messages = self.reader.read(timeout_millis) for message in messages: self.__queue.put(message) if self.__queue.empty(): return None return self.__queue.get() def get_channel_recover_info(self): return ChannelRecoverInfo(self.__creation_status) def stop(self): logger.info("stopping Data Reader.") self.reader.stop() # destruct DataReader self.reader = None def close(self): logger.info("closing Data Reader.") def _to_native_conf(conf): config = streaming_pb.StreamingConfig() if Config.STREAMING_JOB_NAME in conf: config.job_name = conf[Config.STREAMING_JOB_NAME] if Config.STREAMING_WORKER_NAME in conf: config.worker_name = conf[Config.STREAMING_WORKER_NAME] if Config.STREAMING_OP_NAME in conf: config.op_name = conf[Config.STREAMING_OP_NAME] # TODO set operator type if Config.STREAMING_RING_BUFFER_CAPACITY in conf: config.ring_buffer_capacity = \ conf[Config.STREAMING_RING_BUFFER_CAPACITY] if Config.STREAMING_EMPTY_MESSAGE_INTERVAL in conf: config.empty_message_interval = \ conf[Config.STREAMING_EMPTY_MESSAGE_INTERVAL] if Config.FLOW_CONTROL_TYPE in conf: conf.flow_control_type = conf[Config.FLOW_CONTROL_TYPE] if Config.WRITER_CONSUMED_STEP in conf: conf.writer_consumed_step = \ conf[Config.WRITER_CONSUMED_STEP] if Config.READER_CONSUMED_STEP in conf: conf.reader_consumed_step = \ conf[Config.READER_CONSUMED_STEP] logger.info("conf: %s", str(config)) return config.SerializeToString() class ChannelInitException(Exception): def __init__(self, msg, abnormal_channels): self.abnormal_channels = abnormal_channels self.msg = msg class ChannelInterruptException(Exception): def __init__(self, msg=None): self.msg = msg class ChannelRecoverInfo: def __init__(self, queue_creation_status_map=None): if queue_creation_status_map is None: queue_creation_status_map = {} self.__queue_creation_status_map = queue_creation_status_map def get_creation_status(self): return self.__queue_creation_status_map def get_data_lost_queues(self): data_lost_queues = set() for (q, status) in self.__queue_creation_status_map.items(): if status == ChannelCreationStatus.DataLost: data_lost_queues.add(q) return data_lost_queues def __str__(self): return "QueueRecoverInfo [dataLostQueues=%s]" \ % (self.get_data_lost_queues()) class ChannelCreationStatus(Enum): FreshStarted = 0 PullOk = 1 Timeout = 2 DataLost = 3 def channel_id_bytes_to_str(id_bytes): """ Args: id_bytes: bytes representation of channel id Returns: string representation of channel id """ assert type(id_bytes) in [str, bytes] if isinstance(id_bytes, str): return id_bytes return bytes.hex(id_bytes)