import atexit
import logging
import pickle
import threading
import time
from abc import ABC, abstractmethod

import ray
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader

logger = logging.getLogger(__name__)


class StreamTask(ABC):
    """Base class for all streaming tasks. Each task runs a processor."""

    def __init__(self, task_id, processor, worker):
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.reader = None  # DataReader
        self.writers = {}  # ExecutionEdge -> DataWriter
        self.thread = None
        self.prepare_task()
        self.thread = threading.Thread(target=self.run, daemon=True)

    def prepare_task(self):
        channel_conf = dict(self.worker.config)
        channel_size = int(
            self.worker.config.get(Config.CHANNEL_SIZE,
                                   Config.CHANNEL_SIZE_DEFAULT))
        channel_conf[Config.CHANNEL_SIZE] = channel_size
        channel_conf[Config.TASK_JOB_ID] = \
            ray.runtime_context._get_runtime_context().current_driver_id
        channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
            .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)

        execution_graph = self.worker.execution_graph
        execution_node = self.worker.execution_node

        # writers: one DataWriter per output edge, covering all target tasks
        # of that edge
        collectors = []
        for edge in execution_node.output_edges:
            output_actor_ids = {}
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.target_node_id)
            for target_task_id, target_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                                execution_graph.build_time())
                output_actor_ids[channel_name] = target_actor

            if len(output_actor_ids) > 0:
                channel_ids = list(output_actor_ids.keys())
                to_actor_ids = list(output_actor_ids.values())
                writer = DataWriter(channel_ids, to_actor_ids, channel_conf)
                logger.info("Created DataWriter successfully.")
                self.writers[edge] = writer
                collectors.append(
                    OutputCollector(channel_ids, writer, edge.partition))
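
        # Note on channel naming, inferred from the two gen_id call sites in
        # this method: for an edge A -> B, the DataWriter on A and the
        # DataReader on B must agree on the channel ID. Both sides call
        # ChannelID.gen_id(from_task_id, to_task_id, build_time), so the
        # upstream gen_id(self.task_id, target_task_id, ...) and the
        # downstream gen_id(src_task_id, self.task_id, ...) produce the same
        # ID for the same pair of tasks.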
        # readers: a single DataReader covering all input edges
        input_actor_ids = {}
        for edge in execution_node.input_edges:
            task_id2_worker = execution_graph.get_task_id2_worker_by_node_id(
                edge.src_node_id)
            for src_task_id, src_actor in task_id2_worker.items():
                channel_name = ChannelID.gen_id(src_task_id, self.task_id,
                                                execution_graph.build_time())
                input_actor_ids[channel_name] = src_actor

        if len(input_actor_ids) > 0:
            channel_ids = list(input_actor_ids.keys())
            from_actor_ids = list(input_actor_ids.values())
            logger.info("Creating DataReader, channels {}.".format(channel_ids))
            self.reader = DataReader(channel_ids, from_actor_ids, channel_conf)

            def exit_handler():
                # Make the DataReader stop reading data when the MockQueue
                # destructor gets called, to avoid a crash.
                self.cancel_task()

            atexit.register(exit_handler)

        runtime_context = RuntimeContextImpl(
            self.worker.execution_task.task_id,
            self.worker.execution_task.task_index, execution_node.parallelism)
        logger.info("Opening processor {}.".format(self.processor))
        self.processor.open(collectors, runtime_context)

    @abstractmethod
    def init(self):
        pass

    def start(self):
        self.thread.start()

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def cancel_task(self):
        pass


class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
        self.running = True
        self.stopped = False
        self.read_timeout_millis = int(
            worker.config.get(Config.READ_TIMEOUT_MS,
                              Config.DEFAULT_READ_TIMEOUT_MS))

    def init(self):
        pass

    def run(self):
        # Poll the reader with a timeout so that cancel_task() can stop the
        # loop even when no data is arriving.
        while self.running:
            item = self.reader.read(self.read_timeout_millis)
            if item is not None:
                msg_data = item.body()
                msg = pickle.loads(msg_data)
                self.processor.process(msg)
        self.stopped = True

    def cancel_task(self):
        self.running = False
        # Wait for run() to observe the flag and exit; sleep briefly instead
        # of busy-spinning.
        while not self.stopped:
            time.sleep(0.01)


class OneInputStreamTask(InputStreamTask):
    """A stream task for executing a
    :class:`runtime.processor.OneInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)


class SourceStreamTask(StreamTask):
    """A stream task for executing a
    :class:`runtime.processor.SourceProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)

    def init(self):
        pass

    def run(self):
        self.processor.run()

    def cancel_task(self):
        pass
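

if __name__ == "__main__":
    # Illustrative sketch, not part of the runtime: the polling read loop
    # that InputStreamTask.run implements, reduced to a self-contained demo.
    # _FakeReader is a hypothetical stand-in for DataReader; the real class
    # reads pickled payloads from transfer channels rather than a local
    # queue.Queue.
    import queue

    class _FakeReader:
        """read(timeout_millis) returns a pickled payload (standing in for
        item.body()) or None on timeout, like DataReader."""

        def __init__(self):
            self._queue = queue.Queue()

        def put(self, msg):
            # Producers enqueue pickled messages, as DataWriter would.
            self._queue.put(pickle.dumps(msg))

        def read(self, timeout_millis):
            try:
                return self._queue.get(timeout=timeout_millis / 1000.0)
            except queue.Empty:
                return None

    reader = _FakeReader()
    for i in range(3):
        reader.put({"value": i})

    # Same shape as InputStreamTask.run: poll with a timeout, deserialize,
    # then hand the message to the processor (here, just print it).
    received = 0
    while received < 3:
        msg_data = reader.read(100)
        if msg_data is not None:
            print(pickle.loads(msg_data))
            received += 1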