ray/streaming/python/runtime/worker.py

386 lines
15 KiB
Python

import enum
import logging.config
import os
import threading
import time
from typing import Optional
import ray
import ray.streaming.runtime.processor as processor
from ray.actor import ActorHandle
from ray.streaming.generated import remote_call_pb2
from ray.streaming.runtime.command import WorkerRollbackRequest
from ray.streaming.runtime.failover import Barrier
from ray.streaming.runtime.graph import ExecutionVertexContext, ExecutionVertex
from ray.streaming.runtime.remote_call import CallResult, RemoteCallMst
from ray.streaming.runtime.context_backend import ContextBackendFactory
from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask
from ray.streaming.runtime.transfer import channel_bytes_to_str
from ray.streaming.config import Config
import ray.streaming._streaming as _streaming
logger = logging.getLogger(__name__)
# Placeholder reply (four ASCII spaces) that the synchronous transfer
# handlers return while the matching reader/writer client does not exist
# yet, i.e. "this actor is not ready".
_NOT_READY_FLAG_ = b"\x20" * 4
@ray.remote
class JobWorker(object):
    """A streaming job worker executes a user-defined function (operator)
    and interacts with the `JobMaster`.

    Lifecycle as implemented here:

    * ``__init__`` parses the serialized ``ExecutionVertex``. If the actor
      was reconstructed, it reloads its saved worker context from the
      context backend and asks the master for a rollback.
    * ``init`` receives the full ``PythonJobWorkerContext``, persists it and
      builds the stream processor.
    * ``rollback`` (re)creates the stream task from a checkpoint.
    * ``on_*_message*`` relay channel-transfer messages to the native
      reader/writer clients.
    """
    # Handle of the master actor; set by `init`.
    master_actor: Optional[ActorHandle]
    # Deserialized context last passed to `init`.
    worker_context: Optional[remote_call_pb2.PythonJobWorkerContext]
    # Vertex context built from `worker_context`; None until `init`.
    execution_vertex_context: Optional[ExecutionVertexContext]
    # True until a `rollback` succeeds; set again by `request_rollback`.
    __need_rollback: bool

    def __init__(self, execution_vertex_pb_bytes):
        """Create the worker from a serialized execution vertex.

        Args:
            execution_vertex_pb_bytes: bytes of a
                ``remote_call_pb2.ExecutionVertexContext.ExecutionVertex``
                message.

        Errors while recovering from a checkpoint are logged rather than
        raised, so actor creation itself does not fail.
        """
        logger.info("Creating job worker, pid={}".format(os.getpid()))
        execution_vertex_pb = remote_call_pb2\
            .ExecutionVertexContext.ExecutionVertex()
        execution_vertex_pb.ParseFromString(execution_vertex_pb_bytes)
        self.execution_vertex = ExecutionVertex(execution_vertex_pb)
        self.config = self.execution_vertex.config
        self.worker_context = None
        self.execution_vertex_context = None
        self.task_id = None
        self.task = None
        self.stream_processor = None
        self.master_actor = None
        self.context_backend = ContextBackendFactory.get_context_backend(
            self.config)
        # Guards the "task is already in its initial state" check in
        # `rollback`. NOTE(review): presumably also taken by the stream task
        # while it changes that state -- confirm in runtime.task.
        self.initial_state_lock = threading.Lock()
        self.__rollback_cnt: int = 0
        self.__is_recreate: bool = False
        self.__state = WorkerState()
        # Cleared only by a successful `rollback`; while set, expired
        # checkpoint/queue cleanup is refused.
        self.__need_rollback = True
        self.reader_client = None
        self.writer_client = None
        try:
            # load checkpoint
            was_reconstructed = ray.get_runtime_context(
            ).was_current_actor_reconstructed
            logger.info(
                "Worker was reconstructed: {}".format(was_reconstructed))
            if was_reconstructed:
                job_worker_context_key = self.__get_job_worker_context_key()
                logger.info("Worker get checkpoint state by key: {}.".format(
                    job_worker_context_key))
                context_bytes = self.context_backend.get(
                    job_worker_context_key)
                if context_bytes is not None and len(context_bytes) > 0:
                    # Re-run `init` with the persisted context, then ask the
                    # master to drive a rollback of this worker.
                    self.init(context_bytes)
                    self.request_rollback(
                        "Python worker recover from checkpoint.")
                else:
                    logger.error(
                        "Error! Worker get checkpoint state by key {}"
                        " returns None, please check your state backend"
                        ", only reliable state backend supports fail-over."
                        .format(job_worker_context_key))
        except Exception:
            logger.exception("Error in __init__ of JobWorker")
        logger.info("Creating job worker succeeded. worker config {}".format(
            self.config))

    def init(self, worker_context_bytes):
        """Initialize the worker from a serialized worker context.

        Deserializes ``PythonJobWorkerContext``, restores the master actor
        handle, builds the vertex context and stream processor, and persists
        the raw context bytes in the context backend. The backend key is
        derived from the execution vertex id, so the saved context can be
        reloaded after a reconstruction.

        Args:
            worker_context_bytes: serialized
                ``remote_call_pb2.PythonJobWorkerContext``.

        Returns:
            True on success; False if any step raised (the error is logged).
        """
        logger.info("Start to init job worker")
        try:
            # deserialize context
            worker_context = remote_call_pb2.PythonJobWorkerContext()
            worker_context.ParseFromString(worker_context_bytes)
            self.worker_context = worker_context
            self.master_actor = ActorHandle._deserialization_helper(
                worker_context.master_actor)
            # build vertex context from pb
            self.execution_vertex_context = ExecutionVertexContext(
                worker_context.execution_vertex_context)
            self.execution_vertex = self\
                .execution_vertex_context.execution_vertex
            # save context (key depends on the vertex assigned just above)
            job_worker_context_key = self.__get_job_worker_context_key()
            self.context_backend.put(job_worker_context_key,
                                     worker_context_bytes)
            # use vertex id as task id
            self.task_id = self.execution_vertex_context.get_task_id()
            # build and get processor from operator
            operator = self.execution_vertex_context.stream_operator
            self.stream_processor = processor.build_processor(operator)
            logger.info("Initializing job worker, exe_vertex_name={}, "
                        "task_id: {}, operator: {}, pid={}".format(
                            self.execution_vertex_context.exe_vertex_name,
                            self.task_id, self.stream_processor, os.getpid()))
            # get config from vertex
            self.config = self.execution_vertex_context.config
            # NOTE(review): this is truthy for any non-empty channel type,
            # so the native clients are created unless CHANNEL_TYPE is
            # empty. If they are only needed for the native channel, compare
            # against Config.NATIVE_CHANNEL instead -- confirm before changing.
            if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL):
                self.reader_client = _streaming.ReaderClient()
                self.writer_client = _streaming.WriterClient()
            logger.info("Job worker init succeeded.")
        except Exception:
            logger.exception("Error when init job worker.")
            return False
        return True

    def create_stream_task(self, checkpoint_id):
        """Build the stream task matching the processor kind.

        Args:
            checkpoint_id: checkpoint the new task starts from.

        Returns:
            A ``SourceStreamTask`` for a ``SourceProcessor``, or a
            ``OneInputStreamTask`` for a ``OneInputProcessor``.

        Raises:
            Exception: for any other processor type.
        """
        if isinstance(self.stream_processor, processor.SourceProcessor):
            return SourceStreamTask(self.task_id, self.stream_processor, self,
                                    checkpoint_id)
        elif isinstance(self.stream_processor, processor.OneInputProcessor):
            return OneInputStreamTask(self.task_id, self.stream_processor,
                                      self, checkpoint_id)
        else:
            raise Exception("Unsupported processor type: " +
                            str(type(self.stream_processor)))

    def rollback(self, checkpoint_id_bytes):
        """Restart the stream task from the given checkpoint.

        The rollback is skipped when the current task thread is alive, was
        created for the same checkpoint, and is still in its initial state.

        Args:
            checkpoint_id_bytes: serialized ``remote_call_pb2.CheckpointId``.

        Returns:
            Serialized ``remote_call_pb2.CallResult``. It is "skipped" in the
            case above, "success" carrying the queue recover info, or "fail"
            if restarting raised (the exception is logged).
        """
        checkpoint_id_pb = remote_call_pb2.CheckpointId()
        checkpoint_id_pb.ParseFromString(checkpoint_id_bytes)
        checkpoint_id = checkpoint_id_pb.checkpoint_id
        logger.info("Start rollback, checkpoint_id={}".format(checkpoint_id))
        self.__rollback_cnt += 1
        # Every rollback after the first one recreates the task.
        if self.__rollback_cnt > 1:
            self.__is_recreate = True
        # skip useless rollback
        with self.initial_state_lock:
            if self.task is not None and self.task.thread.is_alive()\
                    and checkpoint_id == self.task.last_checkpoint_id\
                    and self.task.is_initial_state:
                logger.info(
                    "Task is already in initial state, skip this rollback.")
                return self.__gen_call_result(
                    CallResult.skipped(
                        "Task is already in initial state, skip this rollback."
                    ))
        # restart task
        try:
            if self.task is not None:
                # make sure the runner is closed
                self.task.cancel_task()
                # Reset (rather than `del`) so the attribute stays defined
                # if creating the replacement task raises below; otherwise
                # every later `self.task is not None` check would raise
                # AttributeError.
                self.task = None
            self.task = self.create_stream_task(checkpoint_id)
            q_recover_info = self.task.recover(self.__is_recreate)
            self.__state.set_type(StateType.RUNNING)
            self.__need_rollback = False
            logger.info(
                "Rollback success, checkpoint is {}, qRecoverInfo is {}.".
                format(checkpoint_id, q_recover_info))
            return self.__gen_call_result(CallResult.success(q_recover_info))
        except Exception:
            logger.exception("Rollback has exception.")
            return self.__gen_call_result(CallResult.fail())

    def on_reader_message(self, *buffers):
        """Called by upstream queue writer to send data message to downstream
        queue reader.

        Dropped (with a log line) until the reader client exists.
        """
        if self.reader_client is None:
            logger.info("reader_client is None, skip reader transfer")
            return
        self.reader_client.on_reader_message(*buffers)

    def on_reader_message_sync(self, buffer: bytes):
        """Called by upstream queue writer to send
        control message to downstream queue reader.

        Returns:
            The reader client's reply as bytes, or ``_NOT_READY_FLAG_`` when
            the reader client does not exist yet.
        """
        if self.reader_client is None:
            logger.info("reader_client is None, skip reader transfer")
            return _NOT_READY_FLAG_
        result = self.reader_client.on_reader_message_sync(buffer)
        return result.to_pybytes()

    def on_writer_message(self, buffer: bytes):
        """Called by downstream queue reader to send notify message to
        upstream queue writer.

        Dropped (with a log line) until the writer client exists.
        """
        if self.writer_client is None:
            logger.info("writer_client is None, skip writer transfer")
            return
        self.writer_client.on_writer_message(buffer)

    def on_writer_message_sync(self, buffer: bytes):
        """Called by downstream queue reader to send control message to
        upstream queue writer.

        Returns:
            The writer client's reply as bytes, or ``_NOT_READY_FLAG_`` when
            the writer client does not exist yet.
        """
        if self.writer_client is None:
            return _NOT_READY_FLAG_
        result = self.writer_client.on_writer_message_sync(buffer)
        return result.to_pybytes()

    def shutdown_without_reconstruction(self):
        """Exit this actor via ``ray.actor.exit_actor()``."""
        logger.info("Python worker shutdown without reconstruction.")
        ray.actor.exit_actor()

    def notify_checkpoint_timeout(self, checkpoint_id_bytes):
        """No-op: checkpoint timeouts are currently ignored by this worker."""
        pass

    def commit(self, barrier_bytes):
        """Forward a checkpoint barrier to the running task, if any.

        Args:
            barrier_bytes: serialized ``remote_call_pb2.Barrier``.

        Returns:
            Serialized ``remote_call_pb2.BoolResult``, always True.
        """
        barrier_pb = remote_call_pb2.Barrier()
        barrier_pb.ParseFromString(barrier_bytes)
        barrier = Barrier(barrier_pb.id)
        logger.info("Receive trigger, barrier is {}.".format(barrier))
        if self.task is not None:
            self.task.commit_trigger(barrier)
        ret = remote_call_pb2.BoolResult()
        ret.boolRes = True
        return ret.SerializeToString()

    def clear_expired_cp(self, state_checkpoint_id_bytes,
                         queue_checkpoint_id_bytes):
        """Clear expired checkpoint state and expired queue messages.

        State is cleared only for a positive state checkpoint id. Both steps
        are refused (reported as False) while a rollback is pending.

        Args:
            state_checkpoint_id_bytes: serialized ``CheckpointId`` for state.
            queue_checkpoint_id_bytes: serialized ``CheckpointId`` for queues.

        Returns:
            Serialized ``remote_call_pb2.BoolResult``: the AND of both steps.
        """
        state_checkpoint_id = self.__parse_to_checkpoint_id(
            state_checkpoint_id_bytes)
        queue_checkpoint_id = self.__parse_to_checkpoint_id(
            queue_checkpoint_id_bytes)
        logger.info("Start to clear expired checkpoint, checkpoint_id={}, "
                    "queue_checkpoint_id={}, exe_vertex_name={}.".format(
                        state_checkpoint_id, queue_checkpoint_id,
                        self.execution_vertex_context.exe_vertex_name))
        ret = remote_call_pb2.BoolResult()
        ret.boolRes = self.__clear_expired_cp_state(state_checkpoint_id) \
            if state_checkpoint_id > 0 else True
        ret.boolRes &= self.__clear_expired_queue_msg(queue_checkpoint_id)
        logger.info(
            "Clear expired checkpoint done, result={}, checkpoint_id={}, "
            "queue_checkpoint_id={}, exe_vertex_name={}.".format(
                ret.boolRes, state_checkpoint_id, queue_checkpoint_id,
                self.execution_vertex_context.exe_vertex_name))
        return ret.SerializeToString()

    def __clear_expired_cp_state(self, checkpoint_id):
        """Ask the task to drop state older than ``checkpoint_id``.

        Returns False (doing nothing) while a rollback is pending, else True.
        """
        if self.__need_rollback:
            logger.warning("Need rollback, skip clear_expired_cp_state"
                           ", checkpoint id: {}".format(checkpoint_id))
            return False
        logger.info("Clear expired checkpoint state, cp id is {}.".format(
            checkpoint_id))
        if self.task is not None:
            self.task.clear_expired_cp_state(checkpoint_id)
        return True

    def __clear_expired_queue_msg(self, checkpoint_id):
        """Ask the task to drop queue messages older than ``checkpoint_id``.

        Returns False (doing nothing) while a rollback is pending, else True.
        """
        if self.__need_rollback:
            logger.warning("Need rollback, skip clear_expired_queue_msg"
                           ", checkpoint id: {}".format(checkpoint_id))
            return False
        logger.info("Clear expired queue msg, checkpoint_id is {}.".format(
            checkpoint_id))
        if self.task is not None:
            self.task.clear_expired_queue_msg(checkpoint_id)
        return True

    def __parse_to_checkpoint_id(self, checkpoint_id_bytes):
        """Return the ``checkpoint_id`` field of a serialized CheckpointId."""
        checkpoint_id_pb = remote_call_pb2.CheckpointId()
        checkpoint_id_pb.ParseFromString(checkpoint_id_bytes)
        return checkpoint_id_pb.checkpoint_id

    def check_if_need_rollback(self):
        """Return a serialized ``BoolResult`` holding the rollback flag."""
        ret = remote_call_pb2.BoolResult()
        ret.boolRes = self.__need_rollback
        return ret.SerializeToString()

    def request_rollback(self, exception_msg="Python exception."):
        """Ask the master to roll this worker back, retrying on failure.

        Marks the worker as needing a rollback, then calls
        ``RemoteCallMst.request_job_worker_rollback`` up to
        ``Config.REQUEST_ROLLBACK_RETRY_TIMES`` times, sleeping 1s after each
        unsuccessful attempt. If every attempt fails, the worker calls
        ``shutdown_without_reconstruction``.

        Args:
            exception_msg: reason forwarded to the master in the request.
        """
        logger.info("Request rollback.")
        self.__need_rollback = True
        self.__is_recreate = True
        request_ret = False
        for i in range(Config.REQUEST_ROLLBACK_RETRY_TIMES):
            logger.info("request rollback {} time".format(i))
            try:
                request_ret = RemoteCallMst.request_job_worker_rollback(
                    self.master_actor,
                    WorkerRollbackRequest(
                        self.execution_vertex_context.actor_id.binary(),
                        "Exception msg=%s, retry time=%d." % (exception_msg,
                                                              i)))
            except Exception:
                logger.exception("Unexpected error when rollback")
            logger.info("request rollback {} time, ret={}".format(
                i, request_ret))
            if not request_ret:
                logger.warning(
                    "Request rollback return false"
                    ", maybe it's invalid request, try to sleep 1s.")
                time.sleep(1)
            else:
                break
        if not request_ret:
            logger.warning("Request failed after retry {} times, "
                           "now worker shutdown without reconstruction."
                           .format(Config.REQUEST_ROLLBACK_RETRY_TIMES))
            self.shutdown_without_reconstruction()
        # NOTE(review): likely unreachable after the shutdown above if
        # exit_actor() does not return -- confirm against Ray docs.
        self.__state.set_type(StateType.WAIT_ROLLBACK)

    def __gen_call_result(self, call_result):
        """Serialize a ``CallResult`` into ``remote_call_pb2.CallResult``.

        When ``result_obj`` is present, it must expose
        ``get_creation_status()``. That mapping's channel-id keys are
        converted with ``channel_bytes_to_str`` and its values are stored as
        enum ``.value``s.
        """
        call_result_pb = remote_call_pb2.CallResult()
        call_result_pb.success = call_result.success
        call_result_pb.result_code = call_result.result_code.value
        if call_result.result_msg is not None:
            call_result_pb.result_msg = call_result.result_msg
        if call_result.result_obj is not None:
            q_recover_info = call_result.result_obj
            for q, status in q_recover_info.get_creation_status().items():
                call_result_pb.result_obj.creation_status[channel_bytes_to_str(
                    q)] = status.value
        return call_result_pb.SerializeToString()

    def _gen_unique_key(self, key_prefix):
        """Return ``key_prefix + <job name> + "_" + <execution vertex id>``."""
        return key_prefix \
            + str(self.config.get(Config.STREAMING_JOB_NAME)) \
            + "_" + str(self.execution_vertex.execution_vertex_id)

    def __get_job_worker_context_key(self) -> str:
        """Context-backend key under which `init` saves the worker context."""
        return self._gen_unique_key(Config.JOB_WORKER_CONTEXT_KEY)
class WorkerState:
    """Mutable holder for a worker's current ``StateType``.

    A new instance starts in ``StateType.INIT``.
    """

    def __init__(self):
        self.set_type(StateType.INIT)

    def set_type(self, type):
        """Record ``type`` as the current state."""
        self.__type = type

    def get_type(self):
        """Return the most recently recorded state."""
        current = self.__type
        return current
class StateType(enum.Enum):
    """Lifecycle states recorded in a ``WorkerState``."""
    # Initial state of a new WorkerState.
    INIT = enum.auto()
    # Set after a successful rollback has (re)started the task.
    RUNNING = enum.auto()
    # Set after a rollback has been requested from the master.
    WAIT_ROLLBACK = enum.auto()