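"""Stream task implementations for Ray streaming.

A stream task wraps a processor and runs it on a daemon thread, wiring up
the data readers and writers for the input and output execution edges of
its execution vertex.
"""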
import atexit
import logging
import threading
from abc import ABC, abstractmethod

from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.runtime import serialization
from ray.streaming.runtime.serialization import \
    PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader

logger = logging.getLogger(__name__)


class StreamTask(ABC):
    """Base class for all streaming tasks. Each task runs a processor."""
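
    # Typical lifecycle, as driven by the worker (a sketch):
    #
    #     task = OneInputStreamTask(task_id, processor, worker)  # wires channels
    #     task.start()        # spawns the daemon thread that executes run()
    #     task.cancel_task()  # asks the processing loop to stop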
    def __init__(self, task_id, processor, worker):
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.config = worker.config
        self.reader = None  # DataReader
        self.writers = {}  # ExecutionEdge -> DataWriter
        self.thread = None
        self.prepare_task()
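        # run() executes on a daemon thread; nothing happens until
        # start() is called.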
        self.thread = threading.Thread(target=self.run, daemon=True)

    def prepare_task(self):
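        # Build the per-channel configuration from the worker config,
        # filling in defaults for channel size and channel type.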
        channel_conf = dict(self.worker.config)
        channel_size = int(
            self.worker.config.get(Config.CHANNEL_SIZE,
                                   Config.CHANNEL_SIZE_DEFAULT))
        channel_conf[Config.CHANNEL_SIZE] = channel_size
        channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
            .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)

        execution_vertex_context = self.worker.execution_vertex_context
        build_time = execution_vertex_context.build_time

        # writers
        collectors = []
        output_actors_map = {}
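        # One output channel per outgoing edge; both ends derive the same
        # channel id from the (writer task, reader task, build_time) triple.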
        for edge in execution_vertex_context.output_execution_edges:
            target_task_id = edge.target_execution_vertex_id
            target_actor = execution_vertex_context\
                .get_target_actor_by_vertex_id(target_task_id)
            channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                            build_time)
            output_actors_map[channel_name] = target_actor

            if len(output_actors_map) > 0:
                channel_ids = list(output_actors_map.keys())
                target_actors = list(output_actors_map.values())
                logger.info(
                    "Create DataWriter, channel_ids {}, target_actors {}."
                    .format(channel_ids, target_actors))
                writer = DataWriter(channel_ids, target_actors, channel_conf)
                self.writers[edge] = writer
                collectors.append(
                    OutputCollector(writer, channel_ids, target_actors,
                                    edge.partition))

        # readers
        input_actor_map = {}
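        # Unlike writers, all input channels are served by a single
        # DataReader, created below once the channel map is complete.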
        for edge in execution_vertex_context.input_execution_edges:
            source_task_id = edge.source_execution_vertex_id
            source_actor = execution_vertex_context\
                .get_source_actor_by_vertex_id(source_task_id)
            channel_name = ChannelID.gen_id(source_task_id, self.task_id,
                                            build_time)
            input_actor_map[channel_name] = source_actor

        if len(input_actor_map) > 0:
            channel_ids = list(input_actor_map.keys())
            from_actors = list(input_actor_map.values())
            logger.info("Create DataReader, channels {}, input_actors {}."
                        .format(channel_ids, from_actors))
            self.reader = DataReader(channel_ids, from_actors, channel_conf)

            def exit_handler():
                # Stop the DataReader before the MockQueue destructor runs,
                # to avoid a crash at interpreter exit.
                self.cancel_task()

            atexit.register(exit_handler)

        # TODO(chaokunyang) add task/job config
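        # Open the processor with its output collectors and a runtime
        # context describing this vertex (task id, subtask index,
        # parallelism).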
        runtime_context = RuntimeContextImpl(
            self.worker.task_id,
            execution_vertex_context.execution_vertex.execution_vertex_index,
            execution_vertex_context.get_parallelism())
        logger.info("Opening processor {}".format(self.processor))
        self.processor.open(collectors, runtime_context)

    @abstractmethod
    def init(self):
        pass

    def start(self):
        self.thread.start()

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def cancel_task(self):
        pass


class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)
        self.running = True
        self.stopped = False
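        # Upper bound on how long a single reader.read() call blocks
        # waiting for data before returning None.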
        self.read_timeout_millis = \
            int(worker.config.get(Config.READ_TIMEOUT_MS,
                                  Config.DEFAULT_READ_TIMEOUT_MS))
        self.python_serializer = PythonSerializer()
        self.cross_lang_serializer = CrossLangSerializer()

    def init(self):
        pass

    def run(self):
        while self.running:
            item = self.reader.read(self.read_timeout_millis)
            if item is not None:
                msg_data = item.body()
                type_id = msg_data[:1]
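                # The first byte tags the payload: records written by a
                # Python worker use the Python serializer; records from
                # other languages use the cross-language serializer.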
                if type_id == serialization._PYTHON_TYPE_ID:
                    msg = self.python_serializer.deserialize(msg_data[1:])
                else:
                    msg = self.cross_lang_serializer.deserialize(msg_data[1:])
                self.processor.process(msg)
        self.stopped = True

    def cancel_task(self):
        self.running = False
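        # Busy-wait until run() observes the flag and marks the task
        # stopped on its way out.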
        while not self.stopped:
            pass


class OneInputStreamTask(InputStreamTask):
    """A stream task for executing
    :class:`runtime.processor.OneInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)


class SourceStreamTask(StreamTask):
    """A stream task for executing
    :class:`runtime.processor.SourceProcessor`."""

    def __init__(self, task_id, processor_instance, worker):
        super().__init__(task_id, processor_instance, worker)

    def init(self):
        pass

    def run(self):
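        # A source task has no reader; the source processor generates
        # records itself.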
        self.processor.run()

    def cancel_task(self):
        pass