# ray/streaming/python/runtime/task.py

import logging
import pickle
import threading
import time
import typing
from abc import ABC, abstractmethod
from typing import Optional
from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.generated import remote_call_pb2
from ray.streaming.runtime import serialization
from ray.streaming.runtime.command import WorkerCommitReport
from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo
from ray.streaming.runtime.remote_call import RemoteCallMst
from ray.streaming.runtime.serialization import \
PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import CheckpointBarrier
from ray.streaming.runtime.transfer import DataMessage
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
from ray.streaming.runtime.transfer import ChannelRecoverInfo
from ray.streaming.runtime.transfer import ChannelInterruptException

if typing.TYPE_CHECKING:
    from ray.streaming.runtime.worker import JobWorker
    from ray.streaming.runtime.processor import Processor, SourceProcessor

logger = logging.getLogger(__name__)


class StreamTask(ABC):
"""Base class for all streaming tasks. Each task runs a processor."""
def __init__(self, task_id: int, processor: "Processor",
worker: "JobWorker", last_checkpoint_id: int):
self.worker_context = worker.worker_context
self.vertex_context = worker.execution_vertex_context
self.task_id = task_id
self.processor = processor
self.worker = worker
self.config: dict = worker.config
self.reader: Optional[DataReader] = None
self.writer: Optional[DataWriter] = None
self.is_initial_state = True
self.last_checkpoint_id: int = last_checkpoint_id
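        # The task's run loop executes in this daemon thread; it is started
        # later in ``recover``.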
self.thread = threading.Thread(target=self.run, daemon=True)
def do_checkpoint(self, checkpoint_id: int, input_points):
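        """Take a checkpoint for ``checkpoint_id``: collect the writer's
        output points, snapshot the processor state, persist and report the
        checkpoint, then broadcast a barrier to downstream channels.
        """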
logger.info("Start do checkpoint, cp id {}, inputPoints {}.".format(
checkpoint_id, input_points))
output_points = None
if self.writer is not None:
output_points = self.writer.get_output_checkpoints()
operator_checkpoint = self.processor.save_checkpoint()
op_checkpoint_info = OpCheckpointInfo(
operator_checkpoint, input_points, output_points, checkpoint_id)
self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id)
barrier_pb = remote_call_pb2.Barrier()
barrier_pb.id = checkpoint_id
byte_buffer = barrier_pb.SerializeToString()
if self.writer is not None:
self.writer.broadcast_barrier(checkpoint_id, byte_buffer)
logger.info("Operator checkpoint {} finish.".format(checkpoint_id))
def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id):
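        """Persist the checkpoint state to the context backend, report the
        commit to the master actor, and advance ``last_checkpoint_id``.
        """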
logger.info(
"Start to save cp state and report, checkpoint id is {}.".format(
checkpoint_id))
self.__save_cp(op_checkpoint_info, checkpoint_id)
self.__report_commit(checkpoint_id)
self.last_checkpoint_id = checkpoint_id
def __save_cp(self, op_checkpoint_info, checkpoint_id):
logger.info("save operator cp, op_checkpoint_info={}".format(
op_checkpoint_info))
cp_bytes = pickle.dumps(op_checkpoint_info)
self.worker.context_backend.put(
self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes)
def __report_commit(self, checkpoint_id: int):
logger.info("Report commit, checkpoint id {}.".format(checkpoint_id))
report = WorkerCommitReport(self.vertex_context.actor_id.binary(),
checkpoint_id)
RemoteCallMst.report_job_worker_commit(self.worker.master_actor,
report)
def clear_expired_cp_state(self, checkpoint_id):
cp_key = self.__gen_op_checkpoint_key(checkpoint_id)
self.worker.context_backend.remove(cp_key)
def clear_expired_queue_msg(self, checkpoint_id):
        # Clear expired messages buffered in the writer
        # up to this checkpoint.
if self.writer is not None:
self.writer.clear_checkpoint(checkpoint_id)
def request_rollback(self, exception_msg: str):
self.worker.request_rollback(exception_msg)
def __gen_op_checkpoint_key(self, checkpoint_id):
op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str(
self.vertex_context.job_name) + "_" + str(
self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id)
logger.info(
"Generate op checkpoint key {}. ".format(op_checkpoint_key))
return op_checkpoint_key
def prepare_task(self, is_recreate: bool):
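        """Set up the task's runtime: optionally restore the operator
        checkpoint (when ``is_recreate`` is True), create the DataWriter,
        DataReader and output collectors from the execution edges, open the
        processor, and persist an initial checkpoint snapshot.
        """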
logger.info(
"Preparing stream task, is_recreate={}.".format(is_recreate))
channel_conf = dict(self.worker.config)
channel_size = int(
self.worker.config.get(Config.CHANNEL_SIZE,
Config.CHANNEL_SIZE_DEFAULT))
channel_conf[Config.CHANNEL_SIZE] = channel_size
channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)
execution_vertex_context = self.worker.execution_vertex_context
build_time = execution_vertex_context.build_time
        # When using memory-backed state, the checkpoint state is lost if the
        # actor throws an exception.
op_checkpoint_info = OpCheckpointInfo()
cp_bytes = None
# get operator checkpoint
if is_recreate:
cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id)
logger.info("Getting task checkpoints from state, "
"cpKey={}, checkpointId={}.".format(
cp_key, self.last_checkpoint_id))
cp_bytes = self.worker.context_backend.get(cp_key)
            if cp_bytes is None:
                msg = "Task recovery failed, checkpoint is null! " \
                      "cpKey={}".format(cp_key)
                raise RuntimeError(msg)
if cp_bytes is not None:
op_checkpoint_info = pickle.loads(cp_bytes)
self.processor.load_checkpoint(op_checkpoint_info.operator_point)
logger.info("Stream task recover from checkpoint state,"
"checkpoint bytes len={}, checkpointInfo={}.".format(
cp_bytes.__len__(), op_checkpoint_info))
# writers
collectors = []
output_actors_map = {}
for edge in execution_vertex_context.output_execution_edges:
target_task_id = edge.target_execution_vertex_id
target_actor = execution_vertex_context \
.get_target_actor_by_execution_vertex_id(target_task_id)
channel_name = ChannelID.gen_id(self.task_id, target_task_id,
build_time)
output_actors_map[channel_name] = target_actor
if len(output_actors_map) > 0:
channel_str_ids = list(output_actors_map.keys())
target_actors = list(output_actors_map.values())
logger.info("Create DataWriter channel_ids {},"
"target_actors {}, output_points={}.".format(
channel_str_ids, target_actors,
op_checkpoint_info.output_points))
self.writer = DataWriter(channel_str_ids, target_actors,
channel_conf)
logger.info("Create DataWriter succeed channel_ids {}, "
"target_actors {}.".format(channel_str_ids,
target_actors))
for edge in execution_vertex_context.output_execution_edges:
collectors.append(
OutputCollector(self.writer, channel_str_ids,
target_actors, edge.partition))
# readers
input_actor_map = {}
for edge in execution_vertex_context.input_execution_edges:
source_task_id = edge.source_execution_vertex_id
source_actor = execution_vertex_context \
.get_source_actor_by_execution_vertex_id(source_task_id)
channel_name = ChannelID.gen_id(source_task_id, self.task_id,
build_time)
input_actor_map[channel_name] = source_actor
if len(input_actor_map) > 0:
channel_str_ids = list(input_actor_map.keys())
from_actors = list(input_actor_map.values())
logger.info("Create DataReader, channels {},"
"input_actors {}, input_points={}.".format(
channel_str_ids, from_actors,
op_checkpoint_info.input_points))
self.reader = DataReader(channel_str_ids, from_actors,
channel_conf)
def exit_handler():
                # Make the DataReader stop reading data when the MockQueue
                # destructor gets called, to avoid a crash.
self.cancel_task()
import atexit
atexit.register(exit_handler)
runtime_context = RuntimeContextImpl(
self.worker.task_id,
execution_vertex_context.execution_vertex.execution_vertex_index,
execution_vertex_context.get_parallelism(),
config=channel_conf,
job_config=channel_conf)
logger.info("open Processor {}".format(self.processor))
self.processor.open(collectors, runtime_context)
        # Save the checkpoint immediately, in case a failover happens during
        # checkpoint 0 or an old checkpoint is reused in a multi-node
        # failover.
self.__save_cp(op_checkpoint_info, self.last_checkpoint_id)
def recover(self, is_recreate: bool):
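        """Prepare the task, collect channel recover info from the reader
        (if any), and start the task thread.
        """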
self.prepare_task(is_recreate)
recover_info = ChannelRecoverInfo()
if self.reader is not None:
recover_info = self.reader.get_channel_recover_info()
self.thread.start()
logger.info("Start operator success.")
return recover_info
@abstractmethod
def run(self):
pass
@abstractmethod
def cancel_task(self):
pass
@abstractmethod
def commit_trigger(self, barrier: Barrier) -> bool:
pass


class InputStreamTask(StreamTask):
"""Base class for stream tasks that execute a
:class:`runtime.processor.OneInputProcessor` or
:class:`runtime.processor.TwoInputProcessor` """
def commit_trigger(self, barrier):
raise RuntimeError(
"commit_trigger is only supported in SourceStreamTask.")
def __init__(self, task_id, processor_instance, worker,
last_checkpoint_id):
super().__init__(task_id, processor_instance, worker,
last_checkpoint_id)
self.running = True
self.stopped = False
self.read_timeout_millis = \
int(worker.config.get(Config.READ_TIMEOUT_MS,
Config.DEFAULT_READ_TIMEOUT_MS))
self.python_serializer = PythonSerializer()
self.cross_lang_serializer = CrossLangSerializer()
def run(self):
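        """Read loop: pull items from the DataReader, deserialize and process
        data messages, and perform a checkpoint whenever a barrier arrives.
        """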
logger.info("Input task thread start.")
try:
while self.running:
self.worker.initial_state_lock.acquire()
try:
item = self.reader.read(self.read_timeout_millis)
self.is_initial_state = False
finally:
self.worker.initial_state_lock.release()
if item is None:
continue
if isinstance(item, DataMessage):
msg_data = item.body
type_id = msg_data[0]
if type_id == serialization.PYTHON_TYPE_ID:
msg = self.python_serializer.deserialize(msg_data[1:])
else:
msg = self.cross_lang_serializer.deserialize(
msg_data[1:])
self.processor.process(msg)
elif isinstance(item, CheckpointBarrier):
logger.info("Got barrier:{}".format(item))
logger.info("Start to do checkpoint {}.".format(
item.checkpoint_id))
input_points = item.get_input_checkpoints()
self.do_checkpoint(item.checkpoint_id, input_points)
logger.info("Do checkpoint {} success.".format(
item.checkpoint_id))
else:
raise RuntimeError(
"Unknown item type! item={}".format(item))
        except ChannelInterruptException:
            logger.info("Queue has stopped.")
        except BaseException as e:
            logger.exception(
                "Error occurred; last successful checkpointId={}.".format(
                    self.last_checkpoint_id))
            self.request_rollback(str(e))
        logger.info("Input task thread exit.")
self.stopped = True
def cancel_task(self):
        self.running = False
        # Block until the run loop has fully exited.
        while not self.stopped:
            time.sleep(0.5)


class OneInputStreamTask(InputStreamTask):
"""A stream task for executing :class:`runtime.processor.OneInputProcessor`
"""
def __init__(self, task_id, processor_instance, worker,
last_checkpoint_id):
super().__init__(task_id, processor_instance, worker,
last_checkpoint_id)


class SourceStreamTask(StreamTask):
"""A stream task for executing :class:`runtime.processor.SourceProcessor`
"""
processor: "SourceProcessor"
def __init__(self, task_id: int, processor_instance: "SourceProcessor",
worker: "JobWorker", last_checkpoint_id):
super().__init__(task_id, processor_instance, worker,
last_checkpoint_id)
self.running = True
self.stopped = False
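        # Barrier handed over by ``commit_trigger``; the run loop consumes it
        # and performs the checkpoint.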
self.__pending_barrier: Optional[Barrier] = None
def run(self):
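        """Fetch loop: repeatedly call the source processor's ``fetch`` and
        checkpoint whenever a pending barrier has been set by
        ``commit_trigger``.
        """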
logger.info("Source task thread start.")
try:
while self.running:
self.processor.fetch()
# check checkpoint
if self.__pending_barrier is not None:
                    # The source fetcher only has output points.
barrier = self.__pending_barrier
logger.info("Start to do checkpoint {}.".format(
barrier.id))
self.do_checkpoint(barrier.id, barrier)
logger.info("Finish to do checkpoint {}.".format(
barrier.id))
self.__pending_barrier = None
        except ChannelInterruptException:
            logger.info("Queue has stopped.")
        except Exception as e:
            logger.exception(
                "Error occurred; last successful checkpointId={}.".format(
                    self.last_checkpoint_id))
            self.request_rollback(str(e))
logger.info("Source fetcher thread exit.")
self.stopped = True
def commit_trigger(self, barrier):
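        """Record ``barrier`` as pending so the run loop checkpoints on its
        next iteration; return False if a previous barrier is still pending.
        """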
        if self.__pending_barrier is not None:
            logger.warning(
                "Last barrier has not been broadcast yet, "
                "skipping this barrier trigger.")
            return False
self.__pending_barrier = barrier
return True
def cancel_task(self):
        self.running = False
        # Block until the run loop has fully exited.
        while not self.stopped:
            time.sleep(0.5)