# ray/streaming/python/runtime/task.py

import atexit
import logging
import threading
import time
from abc import ABC, abstractmethod

from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
    PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader

logger = logging.getLogger(__name__)


class StreamTask(ABC):
    """Base class for all streaming tasks. Each task runs a processor."""

    def __init__(self, task_id, processor, worker):
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.config = worker.config
        self.reader = None  # DataReader
        self.writers = {}  # ExecutionEdge -> DataWriter
        self.thread = None
        self.prepare_task()
        # Run the task body on a daemon thread so a stuck task cannot block
        # worker process exit.
        self.thread = threading.Thread(target=self.run, daemon=True)
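
    # prepare_task wires this task into the physical DAG: it creates a
    # DataWriter over all output edges, a DataReader over all input edges,
    # and then opens the processor with its runtime context.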
    def prepare_task(self):
        channel_conf = dict(self.worker.config)
        channel_size = int(
            self.worker.config.get(Config.CHANNEL_SIZE,
                                   Config.CHANNEL_SIZE_DEFAULT))
        channel_conf[Config.CHANNEL_SIZE] = channel_size
        channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
            .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)

        execution_vertex_context = self.worker.execution_vertex_context
        build_time = execution_vertex_context.build_time

        # writers
        collectors = []
        output_actors_map = {}
        for edge in execution_vertex_context.output_execution_edges:
            target_task_id = edge.target_execution_vertex_id
            target_actor = execution_vertex_context \
                .get_target_actor_by_vertex_id(target_task_id)
            channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                            build_time)
            output_actors_map[channel_name] = target_actor
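
        # A single DataWriter multiplexes all output channels. Note that
        # `edge` below is the loop variable left over from the last output
        # edge, which appears to assume a single output edge (or a single
        # partition scheme) per vertex.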
        if len(output_actors_map) > 0:
            channel_ids = list(output_actors_map.keys())
            target_actors = list(output_actors_map.values())
            logger.info(
                "Create DataWriter channel_ids {}, target_actors {}."
                .format(channel_ids, target_actors))
            writer = DataWriter(channel_ids, target_actors, channel_conf)
            self.writers[edge] = writer
            collectors.append(
                OutputCollector(writer, channel_ids, target_actors,
                                edge.partition))

        # readers
        input_actor_map = {}
        for edge in execution_vertex_context.input_execution_edges:
            source_task_id = edge.source_execution_vertex_id
            source_actor = execution_vertex_context \
                .get_source_actor_by_vertex_id(source_task_id)
            channel_name = ChannelID.gen_id(source_task_id, self.task_id,
                                            build_time)
            input_actor_map[channel_name] = source_actor
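
        # A single DataReader merges records from all input channels.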
        if len(input_actor_map) > 0:
            channel_ids = list(input_actor_map.keys())
            from_actors = list(input_actor_map.values())
            logger.info("Create DataReader, channels {}, input_actors {}."
                        .format(channel_ids, from_actors))
            self.reader = DataReader(channel_ids, from_actors, channel_conf)

            def exit_handler():
                # Stop the DataReader before the MockQueue destructor runs
                # at interpreter exit, to avoid a crash.
                self.cancel_task()

            atexit.register(exit_handler)

        # TODO(chaokunyang) add task/job config
        runtime_context = RuntimeContextImpl(
            self.worker.task_id,
            execution_vertex_context.execution_vertex.execution_vertex_index,
            execution_vertex_context.get_parallelism())
        logger.info("open Processor {}".format(self.processor))
        self.processor.open(collectors, runtime_context)
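
    # Lifecycle hooks: concrete tasks implement init/run/cancel_task;
    # start() launches the run loop on the daemon thread.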
    @abstractmethod
    def init(self):
        pass

    def start(self):
        self.thread.start()

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def cancel_task(self):
        pass


class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
        self.running = True
        self.stopped = False
        self.read_timeout_millis = \
            int(worker.config.get(Config.READ_TIMEOUT_MS,
                                  Config.DEFAULT_READ_TIMEOUT_MS))
        self.python_serializer = PythonSerializer()
        self.cross_lang_serializer = CrossLangSerializer()

    def init(self):
        pass

    def run(self):
        while self.running:
            item = self.reader.read(self.read_timeout_millis)
            if item is not None:
                msg_data = item.body()
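                # The first byte is a type tag identifying which serializer
                # encoded the payload: Python-only or cross-language.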
                type_id = msg_data[:1]
                if type_id == serialization._PYTHON_TYPE_ID:
                    msg = self.python_serializer.deserialize(msg_data[1:])
                else:
                    msg = self.cross_lang_serializer.deserialize(msg_data[1:])
                self.processor.process(msg)
        self.stopped = True

    def cancel_task(self):
        self.running = False
        while not self.stopped:
            # Wait until the run loop has observed the flag and exited its
            # final iteration; a short sleep avoids burning a core.
            time.sleep(0.001)


class OneInputStreamTask(InputStreamTask):
    """A stream task for executing
    :class:`runtime.processor.OneInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)


class SourceStreamTask(StreamTask):
    """A stream task for executing
    :class:`runtime.processor.SourceProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)

    def init(self):
        pass

    def run(self):
        # A source task is self-driving: the processor produces records
        # until it finishes or the job is cancelled.
        self.processor.run()

    def cancel_task(self):
        pass
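

# Typical lifecycle (illustrative sketch only; `task_id`, `processor`, and
# `worker` are supplied by the streaming runtime, not constructed by hand):
#
#     task = OneInputStreamTask(task_id, processor, worker)
#     task.start()          # spawns the daemon thread running run()
#     ...
#     task.cancel_task()    # stops the read loop and waits for it to drain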