# flake8: noqa
from libc.stdint cimport *
from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
from libcpp.list cimport list as c_list

from ray.includes.common cimport (
    CRayFunction,
    LANGUAGE_PYTHON,
    CBuffer
)
from ray.includes.unique_ids cimport (
    CActorID,
    CObjectID
)
from ray._raylet cimport (
    Buffer,
    CoreWorker,
    ActorID,
    ObjectID,
    FunctionDescriptor,
)
from ray.includes.libcoreworker cimport CCoreWorker

cimport ray.streaming.includes.libstreaming as libstreaming
from ray.streaming.includes.libstreaming cimport (
    CStreamingStatus,
    CStreamingMessage,
    CStreamingMessageBundle,
    CRuntimeContext,
    CDataBundle,
    CDataWriter,
    CDataReader,
    CReaderClient,
    CWriterClient,
    CLocalMemoryBuffer,
)

import logging

channel_logger = logging.getLogger(__name__)


cdef class ReaderClient:
    cdef:
        CReaderClient *client

    def __cinit__(self,
                  CoreWorker worker,
                  FunctionDescriptor async_func,
                  FunctionDescriptor sync_func):
        cdef:
            CCoreWorker *core_worker = worker.core_worker.get()
            CRayFunction async_native_func
            CRayFunction sync_native_func
        async_native_func = CRayFunction(
            LANGUAGE_PYTHON, async_func.descriptor)
        sync_native_func = CRayFunction(
            LANGUAGE_PYTHON, sync_func.descriptor)
        self.client = new CReaderClient(
            core_worker, async_native_func, sync_native_func)

    def __dealloc__(self):
        del self.client
        self.client = NULL

    def on_reader_message(self, const unsigned char[:] value):
        # Wrap the incoming bytes in a native LocalMemoryBuffer and hand it
        # to the reader client asynchronously.
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](
                    <uint8_t *>(&value[0]), size, True)
        with nogil:
            self.client.OnReaderMessage(local_buf)

    def on_reader_message_sync(self, const unsigned char[:] value):
        # Synchronous variant: returns the reply buffer produced by the
        # native reader client.
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](
                    <uint8_t *>(&value[0]), size, True)
            shared_ptr[CLocalMemoryBuffer] result_buffer
        with nogil:
            result_buffer = self.client.OnReaderMessageSync(local_buf)
        return Buffer.make(
            dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))


cdef class WriterClient:
    cdef:
        CWriterClient *client

    def __cinit__(self,
                  CoreWorker worker,
                  FunctionDescriptor async_func,
                  FunctionDescriptor sync_func):
        cdef:
            CCoreWorker *core_worker = worker.core_worker.get()
            CRayFunction async_native_func
            CRayFunction sync_native_func
        async_native_func = CRayFunction(
            LANGUAGE_PYTHON, async_func.descriptor)
        sync_native_func = CRayFunction(
            LANGUAGE_PYTHON, sync_func.descriptor)
        self.client = new CWriterClient(
            core_worker, async_native_func, sync_native_func)

    def __dealloc__(self):
        del self.client
        self.client = NULL

    def on_writer_message(self, const unsigned char[:] value):
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](
                    <uint8_t *>(&value[0]), size, True)
        with nogil:
            self.client.OnWriterMessage(local_buf)

    def on_writer_message_sync(self, const unsigned char[:] value):
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](
                    <uint8_t *>(&value[0]), size, True)
            shared_ptr[CLocalMemoryBuffer] result_buffer
        with nogil:
            result_buffer = self.client.OnWriterMessageSync(local_buf)
        return Buffer.make(
            dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))


cdef class DataWriter:
    cdef:
        CDataWriter *writer

    def __init__(self):
        raise Exception("use create() to create DataWriter")

    @staticmethod
    def create(list py_output_channels,
               list output_actor_ids: list[ActorID],
               uint64_t queue_size,
               list py_msg_ids,
               bytes config_bytes,
               c_bool is_mock):
        cdef:
            c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(
                py_output_channels)
            c_vector[CActorID] actor_ids
            c_vector[uint64_t] msg_ids
            CDataWriter *c_writer
        cdef const unsigned char[:] config_data
        for actor_id in output_actor_ids:
            actor_ids.push_back((<ActorID>actor_id).data)
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(py_msg_id)
        cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
        if is_mock:
            ctx.get().MarkMockTest()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("load config, config bytes size: %s",
                                config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&config_data[0]),
                                config_data.nbytes)
        c_writer = new CDataWriter(ctx)
        cdef:
            c_vector[CObjectID] remain_id_vec
            c_vector[uint64_t] queue_size_vec
        for i in range(channel_ids.size()):
            queue_size_vec.push_back(queue_size)
        cdef CStreamingStatus status = c_writer.Init(
            channel_ids, actor_ids, msg_ids, queue_size_vec)
        if remain_id_vec.size() != 0:
            channel_logger.warning("failed queue amounts => %s",
                                   remain_id_vec.size())
        if status != libstreaming.StatusOK:
            msg = "initialize writer failed, status={}".format(status)
            channel_logger.error(msg)
            del c_writer
            # avoid cyclic import
            import ray.streaming.runtime.transfer as transfer
            raise transfer.ChannelInitException(
                msg, qid_vector_to_list(remain_id_vec))
        c_writer.Run()
        channel_logger.info("create native writer succeed")
        cdef DataWriter writer = DataWriter.__new__(DataWriter)
        writer.writer = c_writer
        return writer

    def __dealloc__(self):
        if self.writer != NULL:
            del self.writer
            channel_logger.info("deleted DataWriter")
        self.writer = NULL

    def write(self, ObjectID qid, const unsigned char[:] value):
        """support zero-copy bytes, bytearray, array of unsigned char"""
        cdef:
            CObjectID native_id = qid.data
            uint64_t msg_id
            uint8_t *data = <uint8_t *>(&value[0])
            uint32_t size = value.nbytes
        with nogil:
            msg_id = self.writer.WriteMessageToBufferRing(
                native_id, data, size)
        return msg_id

    def stop(self):
        self.writer.Stop()
        channel_logger.info("stopped DataWriter")


cdef class DataReader:
    cdef:
        CDataReader *reader
        readonly bytes meta
        readonly bytes data

    def __init__(self):
        raise Exception("use create() to create DataReader")

    @staticmethod
    def create(list py_input_queues,
               list input_actor_ids: list[ActorID],
               list py_seq_ids,
               list py_msg_ids,
               int64_t timer_interval,
               c_bool is_recreate,
               bytes config_bytes,
               c_bool is_mock):
        cdef:
            c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(
                py_input_queues)
            c_vector[CActorID] actor_ids
            c_vector[uint64_t] seq_ids
            c_vector[uint64_t] msg_ids
            CDataReader *c_reader
        cdef const unsigned char[:] config_data
        for actor_id in input_actor_ids:
            actor_ids.push_back((<ActorID>actor_id).data)
        for py_seq_id in py_seq_ids:
            seq_ids.push_back(py_seq_id)
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(py_msg_id)
        cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("load config, config bytes size: %s",
                                config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&(config_data[0])),
                                config_data.nbytes)
        if is_mock:
            ctx.get().MarkMockTest()
        c_reader = new CDataReader(ctx)
        c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids,
                      timer_interval)
        channel_logger.info("create native reader succeed")
        cdef DataReader reader = DataReader.__new__(DataReader)
        reader.reader = c_reader
        return reader

    def __dealloc__(self):
        if self.reader != NULL:
            del self.reader
            channel_logger.info("deleted DataReader")
        self.reader = NULL

    def read(self, uint32_t timeout_millis):
        cdef:
            shared_ptr[CDataBundle] bundle
            CStreamingStatus status
        with nogil:
            status = self.reader.GetBundle(timeout_millis, bundle)
        cdef uint32_t bundle_type = <uint32_t>(
            bundle.get().meta.get().GetBundleType())
        if status != libstreaming.StatusOK:
            if status == libstreaming.StatusInterrupted:
                # avoid cyclic import
                import ray.streaming.runtime.transfer as transfer
                raise transfer.ChannelInterruptException("reader interrupted")
            elif status == libstreaming.StatusInitQueueFailed:
                raise Exception("init channel failed")
            elif status == libstreaming.StatusWaitQueueTimeOut:
                raise Exception("wait channel object timeout")
        cdef:
            uint32_t msg_nums
            CObjectID queue_id
            c_list[shared_ptr[CStreamingMessage]] msg_list
            list msgs = []
            uint64_t timestamp
            uint64_t msg_id
        if bundle_type == libstreaming.BundleTypeBundle:
            msg_nums = bundle.get().meta.get().GetMessageListSize()
            # Deserialize the raw bundle payload (past the bundle header)
            # into individual streaming messages.
            CStreamingMessageBundle.GetMessageListFromRawData(
                bundle.get().data + libstreaming.kMessageBundleHeaderSize,
                bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
                msg_nums,
                msg_list)
            timestamp = bundle.get().meta.get().GetMessageBundleTs()
            for msg in msg_list:
                msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()]
                # queue_id is never assigned above, so qid_bytes is the
                # default-constructed id's binary form.
                qid_bytes = queue_id.Binary()
                msg_id = msg.get().GetMessageSeqId()
                msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
            return msgs
        elif bundle_type == libstreaming.BundleTypeEmpty:
            return []
        else:
            raise Exception("Unsupported bundle type {}".format(bundle_type))

    def stop(self):
        self.reader.Stop()
        channel_logger.info("stopped DataReader")


# Convert a list of binary queue/channel ids into a vector of CObjectIDs.
cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *:
    assert len(py_queue_ids) > 0
    cdef:
        c_vector[CObjectID] queue_id_vec
        c_string q_id_data
    for q_id in py_queue_ids:
        q_id_data = q_id
        assert q_id_data.size() == CObjectID.Size()
        obj_id = CObjectID.FromBinary(q_id_data)
        queue_id_vec.push_back(obj_id)
    return queue_id_vec


# Convert a vector of CObjectIDs back to a list of their binary forms.
cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec):
    queues = []
    for obj_id in queue_id_vec:
        queues.append(obj_id.Binary())
    return queues