import logging
import pickle
import threading
import time
import typing
from abc import ABC, abstractmethod
from typing import Optional

from ray.streaming.collector import OutputCollector
from ray.streaming.config import Config
from ray.streaming.context import RuntimeContextImpl
from ray.streaming.generated import remote_call_pb2
from ray.streaming.runtime import serialization
from ray.streaming.runtime.command import WorkerCommitReport
from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo
from ray.streaming.runtime.remote_call import RemoteCallMst
from ray.streaming.runtime.serialization import \
    PythonSerializer, CrossLangSerializer
from ray.streaming.runtime.transfer import CheckpointBarrier
from ray.streaming.runtime.transfer import DataMessage
from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader
from ray.streaming.runtime.transfer import ChannelRecoverInfo
from ray.streaming.runtime.transfer import ChannelInterruptException

if typing.TYPE_CHECKING:
    from ray.streaming.runtime.worker import JobWorker
    from ray.streaming.runtime.processor import Processor, SourceProcessor

logger = logging.getLogger(__name__)

class StreamTask(ABC):
    """Base class for all streaming tasks. Each task runs a processor."""

    def __init__(self, task_id: int, processor: "Processor",
                 worker: "JobWorker", last_checkpoint_id: int):
        self.worker_context = worker.worker_context
        self.vertex_context = worker.execution_vertex_context
        self.task_id = task_id
        self.processor = processor
        self.worker = worker
        self.config: dict = worker.config
        self.reader: Optional[DataReader] = None
        self.writer: Optional[DataWriter] = None
        self.is_initial_state = True
        self.last_checkpoint_id: int = last_checkpoint_id
        self.thread = threading.Thread(target=self.run, daemon=True)
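
    # Threading note: `self.thread` runs `run()` as a daemon thread and is
    # started from `recover()`. The concrete subclasses catch exceptions
    # inside `run()` and route them to `request_rollback()` instead of
    # letting them kill the worker process.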

    def do_checkpoint(self, checkpoint_id: int, input_points):
        logger.info("Start to do checkpoint, cp id {}, inputPoints {}.".format(
            checkpoint_id, input_points))

        output_points = None
        if self.writer is not None:
            output_points = self.writer.get_output_checkpoints()

        operator_checkpoint = self.processor.save_checkpoint()
        op_checkpoint_info = OpCheckpointInfo(
            operator_checkpoint, input_points, output_points, checkpoint_id)
        self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id)

        barrier_pb = remote_call_pb2.Barrier()
        barrier_pb.id = checkpoint_id
        byte_buffer = barrier_pb.SerializeToString()
        if self.writer is not None:
            self.writer.broadcast_barrier(checkpoint_id, byte_buffer)
        logger.info("Operator checkpoint {} finished.".format(checkpoint_id))
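
    # The checkpoint sequence above is: snapshot the writer's output offsets,
    # snapshot the operator state via `processor.save_checkpoint()`, persist
    # the combined `OpCheckpointInfo` and report the commit to the master,
    # then broadcast the barrier downstream so downstream tasks start their
    # own checkpoints. A minimal sketch of the save/load contract a processor
    # is expected to follow (hypothetical operator, not part of this module):
    #
    #     class WordCountProcessor:
    #         def __init__(self):
    #             self.counts = {}
    #
    #         def save_checkpoint(self):
    #             # return a picklable snapshot of the operator state
    #             return dict(self.counts)
    #
    #         def load_checkpoint(self, checkpoint):
    #             self.counts = dict(checkpoint or {})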

    def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id):
        logger.info(
            "Start to save cp state and report, checkpoint id is {}.".format(
                checkpoint_id))
        self.__save_cp(op_checkpoint_info, checkpoint_id)
        self.__report_commit(checkpoint_id)
        self.last_checkpoint_id = checkpoint_id

    def __save_cp(self, op_checkpoint_info, checkpoint_id):
        logger.info("Save operator cp, op_checkpoint_info={}.".format(
            op_checkpoint_info))
        cp_bytes = pickle.dumps(op_checkpoint_info)
        self.worker.context_backend.put(
            self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes)

    def __report_commit(self, checkpoint_id: int):
        logger.info("Report commit, checkpoint id {}.".format(checkpoint_id))
        report = WorkerCommitReport(self.vertex_context.actor_id.binary(),
                                    checkpoint_id)
        RemoteCallMst.report_job_worker_commit(self.worker.master_actor,
                                               report)

    def clear_expired_cp_state(self, checkpoint_id):
        cp_key = self.__gen_op_checkpoint_key(checkpoint_id)
        self.worker.context_backend.remove(cp_key)

    def clear_expired_queue_msg(self, checkpoint_id):
        # Clear expired messages in the output channels up to this checkpoint.
        if self.writer is not None:
            self.writer.clear_checkpoint(checkpoint_id)

    def request_rollback(self, exception_msg: str):
        self.worker.request_rollback(exception_msg)

    def __gen_op_checkpoint_key(self, checkpoint_id):
        op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str(
            self.vertex_context.job_name) + "_" + str(
                self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id)
        logger.info(
            "Generated op checkpoint key {}.".format(op_checkpoint_key))
        return op_checkpoint_key
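
    # For example, with the configured prefix, job name "wordcount",
    # execution vertex name "2-MapOperator-0" and checkpoint id 5, the key is
    # "<prefix>wordcount_2-MapOperator-0_5" (the concrete names here are
    # illustrative only).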

    def prepare_task(self, is_recreate: bool):
        logger.info(
            "Preparing stream task, is_recreate={}.".format(is_recreate))
        channel_conf = dict(self.worker.config)
        channel_size = int(
            self.worker.config.get(Config.CHANNEL_SIZE,
                                   Config.CHANNEL_SIZE_DEFAULT))
        channel_conf[Config.CHANNEL_SIZE] = channel_size
        channel_conf[Config.CHANNEL_TYPE] = self.worker.config \
            .get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL)

        execution_vertex_context = self.worker.execution_vertex_context
        build_time = execution_vertex_context.build_time

        # When the memory state backend is used, the state is lost if the
        # actor throws an exception.
        op_checkpoint_info = OpCheckpointInfo()

        cp_bytes = None
        # Get the operator checkpoint.
        if is_recreate:
            cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id)
            logger.info("Getting task checkpoints from state, "
                        "cpKey={}, checkpointId={}.".format(
                            cp_key, self.last_checkpoint_id))
            cp_bytes = self.worker.context_backend.get(cp_key)
            if cp_bytes is None:
                msg = "Task recover failed, checkpoint is null! " \
                      "cpKey={}".format(cp_key)
                raise RuntimeError(msg)

        if cp_bytes is not None:
            op_checkpoint_info = pickle.loads(cp_bytes)
            self.processor.load_checkpoint(op_checkpoint_info.operator_point)
            logger.info("Stream task recovered from checkpoint state, "
                        "checkpoint bytes len={}, checkpointInfo={}.".format(
                            len(cp_bytes), op_checkpoint_info))

        # writers
        collectors = []
        output_actors_map = {}
        for edge in execution_vertex_context.output_execution_edges:
            target_task_id = edge.target_execution_vertex_id
            target_actor = execution_vertex_context \
                .get_target_actor_by_execution_vertex_id(target_task_id)
            channel_name = ChannelID.gen_id(self.task_id, target_task_id,
                                            build_time)
            output_actors_map[channel_name] = target_actor

        if len(output_actors_map) > 0:
            channel_str_ids = list(output_actors_map.keys())
            target_actors = list(output_actors_map.values())
            logger.info("Create DataWriter, channel_ids {}, "
                        "target_actors {}, output_points={}.".format(
                            channel_str_ids, target_actors,
                            op_checkpoint_info.output_points))
            self.writer = DataWriter(channel_str_ids, target_actors,
                                     channel_conf)
            logger.info("Create DataWriter succeeded, channel_ids {}, "
                        "target_actors {}.".format(channel_str_ids,
                                                   target_actors))
            for edge in execution_vertex_context.output_execution_edges:
                collectors.append(
                    OutputCollector(self.writer, channel_str_ids,
                                    target_actors, edge.partition))
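
        # One OutputCollector is created per outgoing edge; all of them share
        # the single DataWriter, and each edge's partition decides how records
        # are spread across the writer's channels.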

        # readers
        input_actor_map = {}
        for edge in execution_vertex_context.input_execution_edges:
            source_task_id = edge.source_execution_vertex_id
            source_actor = execution_vertex_context \
                .get_source_actor_by_execution_vertex_id(source_task_id)
            channel_name = ChannelID.gen_id(source_task_id, self.task_id,
                                            build_time)
            input_actor_map[channel_name] = source_actor

        if len(input_actor_map) > 0:
            channel_str_ids = list(input_actor_map.keys())
            from_actors = list(input_actor_map.values())
            logger.info("Create DataReader, channels {}, "
                        "input_actors {}, input_points={}.".format(
                            channel_str_ids, from_actors,
                            op_checkpoint_info.input_points))
            self.reader = DataReader(channel_str_ids, from_actors,
                                     channel_conf)
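
        # Note that channel ids were generated symmetrically above with
        # ChannelID.gen_id(from_task_id, to_task_id, build_time), so the
        # upstream writer and the downstream reader of the same edge resolve
        # the same channel name.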

        def exit_handler():
            # Make the DataReader stop reading data when the MockQueue
            # destructor gets called, to avoid a crash.
            self.cancel_task()

        import atexit
        atexit.register(exit_handler)

        runtime_context = RuntimeContextImpl(
            self.worker.task_id,
            execution_vertex_context.execution_vertex.execution_vertex_index,
            execution_vertex_context.get_parallelism(),
            config=channel_conf,
            job_config=channel_conf)
        logger.info("Opening processor {}.".format(self.processor))
        self.processor.open(collectors, runtime_context)

        # Save the checkpoint immediately, in case failover happens during
        # checkpoint 0 or an old checkpoint would otherwise be used in a
        # multi-node failover.
        self.__save_cp(op_checkpoint_info, self.last_checkpoint_id)

    def recover(self, is_recreate: bool):
        self.prepare_task(is_recreate)

        recover_info = ChannelRecoverInfo()
        if self.reader is not None:
            recover_info = self.reader.get_channel_recover_info()

        self.thread.start()

        logger.info("Operator started successfully.")
        return recover_info

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def cancel_task(self):
        pass

    @abstractmethod
    def commit_trigger(self, barrier: Barrier) -> bool:
        pass


class InputStreamTask(StreamTask):
    """Base class for stream tasks that execute a
    :class:`runtime.processor.OneInputProcessor` or
    :class:`runtime.processor.TwoInputProcessor`."""

    def commit_trigger(self, barrier):
        raise RuntimeError(
            "commit_trigger is only supported in SourceStreamTask.")

    def __init__(self, task_id, processor_instance, worker,
                 last_checkpoint_id):
        super().__init__(task_id, processor_instance, worker,
                         last_checkpoint_id)
        self.running = True
        self.stopped = False
        self.read_timeout_millis = \
            int(worker.config.get(Config.READ_TIMEOUT_MS,
                                  Config.DEFAULT_READ_TIMEOUT_MS))
        self.python_serializer = PythonSerializer()
        self.cross_lang_serializer = CrossLangSerializer()

    def run(self):
        logger.info("Input task thread start.")
        try:
            while self.running:
                self.worker.initial_state_lock.acquire()
                try:
                    item = self.reader.read(self.read_timeout_millis)
                    self.is_initial_state = False
                finally:
                    self.worker.initial_state_lock.release()

                if item is None:
                    continue

                if isinstance(item, DataMessage):
                    msg_data = item.body
                    type_id = msg_data[0]
                    if type_id == serialization.PYTHON_TYPE_ID:
                        msg = self.python_serializer.deserialize(msg_data[1:])
                    else:
                        msg = self.cross_lang_serializer.deserialize(
                            msg_data[1:])
                    self.processor.process(msg)
                elif isinstance(item, CheckpointBarrier):
                    logger.info("Got barrier: {}.".format(item))
                    logger.info("Start to do checkpoint {}.".format(
                        item.checkpoint_id))

                    input_points = item.get_input_checkpoints()

                    self.do_checkpoint(item.checkpoint_id, input_points)
                    logger.info("Do checkpoint {} success.".format(
                        item.checkpoint_id))
                else:
                    raise RuntimeError(
                        "Unknown item type! item={}".format(item))

        except ChannelInterruptException:
            logger.info("Queue has stopped.")
        except BaseException as e:
            logger.exception(
                "Last successful checkpointId={}, an error occurred.".format(
                    self.last_checkpoint_id))
            self.request_rollback(str(e))

        logger.info("Input task thread exit.")
        self.stopped = True
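
    # Wire framing consumed above: the first byte of a DataMessage body is a
    # serializer type id and the rest is the serialized record. A sketch of
    # the matching writer-side encoding (the helper name below is
    # hypothetical; the real encoding lives in the serialization/collector
    # code):
    #
    #     def encode_python_record(record):
    #         payload = PythonSerializer().serialize(record)
    #         return bytes([serialization.PYTHON_TYPE_ID]) + payload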

    def cancel_task(self):
        self.running = False
        while not self.stopped:
            time.sleep(0.5)


class OneInputStreamTask(InputStreamTask):
    """A stream task for executing
    :class:`runtime.processor.OneInputProcessor`."""

    def __init__(self, task_id, processor_instance, worker,
                 last_checkpoint_id):
        super().__init__(task_id, processor_instance, worker,
                         last_checkpoint_id)


class SourceStreamTask(StreamTask):
    """A stream task for executing
    :class:`runtime.processor.SourceProcessor`."""

    processor: "SourceProcessor"

    def __init__(self, task_id: int, processor_instance: "SourceProcessor",
                 worker: "JobWorker", last_checkpoint_id):
        super().__init__(task_id, processor_instance, worker,
                         last_checkpoint_id)
        self.running = True
        self.stopped = False
        self.__pending_barrier: Optional[Barrier] = None

    def run(self):
        logger.info("Source task thread start.")
        try:
            while self.running:
                self.processor.fetch()
                # Check whether a checkpoint barrier is pending.
                if self.__pending_barrier is not None:
                    # A source fetcher only has output points.
                    barrier = self.__pending_barrier
                    logger.info("Start to do checkpoint {}.".format(
                        barrier.id))
                    self.do_checkpoint(barrier.id, barrier)
                    logger.info("Finished doing checkpoint {}.".format(
                        barrier.id))
                    self.__pending_barrier = None

        except ChannelInterruptException:
            logger.info("Queue has stopped.")
        except Exception as e:
            logger.exception(
                "Last successful checkpointId={}, an error occurred.".format(
                    self.last_checkpoint_id))
            self.request_rollback(str(e))

        logger.info("Source fetcher thread exit.")
        self.stopped = True

    def commit_trigger(self, barrier):
        if self.__pending_barrier is not None:
            logger.warning(
                "The last barrier has not been broadcast yet, skipping this "
                "barrier trigger.")
            return False

        self.__pending_barrier = barrier
        return True
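
    # commit_trigger() is called from outside the source loop (presumably by
    # the JobWorker when a checkpoint is triggered); it only parks the barrier
    # in self.__pending_barrier. The loop in run() picks it up after the next
    # fetch() and performs the actual checkpoint, so returning False simply
    # means the previous barrier has not been consumed yet.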

    def cancel_task(self):
        self.running = False
        while not self.stopped:
            time.sleep(0.5)