mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
318 lines
No EOL
12 KiB
Cython
318 lines
No EOL
12 KiB
Cython
# flake8: noqa
|
|
|
|
from libc.stdint cimport *
|
|
from libcpp cimport bool as c_bool
|
|
from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
|
|
from libcpp.string cimport string as c_string
|
|
from libcpp.vector cimport vector as c_vector
|
|
from libcpp.list cimport list as c_list
|
|
|
|
from ray.includes.common cimport (
|
|
CRayFunction,
|
|
LANGUAGE_PYTHON,
|
|
CBuffer
|
|
)
|
|
|
|
from ray.includes.unique_ids cimport (
|
|
CActorID,
|
|
CObjectID
|
|
)
|
|
from ray._raylet cimport (
|
|
Buffer,
|
|
CoreWorker,
|
|
ActorID,
|
|
ObjectID,
|
|
FunctionDescriptor,
|
|
)
|
|
|
|
from ray.includes.libcoreworker cimport CCoreWorker
|
|
|
|
cimport ray.streaming.includes.libstreaming as libstreaming
|
|
from ray.streaming.includes.libstreaming cimport (
|
|
CStreamingStatus,
|
|
CStreamingMessage,
|
|
CStreamingMessageBundle,
|
|
CRuntimeContext,
|
|
CDataBundle,
|
|
CDataWriter,
|
|
CDataReader,
|
|
CReaderClient,
|
|
CWriterClient,
|
|
CLocalMemoryBuffer,
|
|
)
|
|
|
|
import logging
|
|
|
|
|
|
# Module-level logger used by the DataWriter / DataReader wrappers below.
channel_logger = logging.getLogger(__name__)
|
|
|
|
|
|
cdef class ReaderClient:
    """Python-facing wrapper that owns a native ``CReaderClient``.

    The native client is allocated in ``__cinit__`` and freed in
    ``__dealloc__``; the ``on_reader_message*`` methods pass raw message
    bytes to it as a ``CLocalMemoryBuffer``.
    """
    cdef:
        CReaderClient *client

    def __cinit__(self,
                  CoreWorker worker,
                  FunctionDescriptor async_func,
                  FunctionDescriptor sync_func):
        # Wrap both descriptors as Python-language ray functions and hand
        # them, together with the native core worker, to the native client.
        cdef CCoreWorker *native_worker = worker.core_worker.get()
        cdef CRayFunction async_fn = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
        cdef CRayFunction sync_fn = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
        self.client = new CReaderClient(native_worker, async_fn, sync_fn)

    def __dealloc__(self):
        if self.client != NULL:
            del self.client
        self.client = NULL

    def on_reader_message(self, const unsigned char[:] value):
        """Pass ``value`` to the native ``OnReaderMessage`` with the GIL released."""
        # NOTE(review): the trailing ``True`` constructor argument presumably
        # makes the buffer copy the bytes - confirm against CLocalMemoryBuffer.
        cdef:
            size_t nbytes = value.nbytes
            uint8_t *raw = <uint8_t *>(&value[0])
            shared_ptr[CLocalMemoryBuffer] request = \
                make_shared[CLocalMemoryBuffer](raw, nbytes, True)
        with nogil:
            self.client.OnReaderMessage(request)

    def on_reader_message_sync(self, const unsigned char[:] value):
        """Pass ``value`` to ``OnReaderMessageSync`` and wrap its reply as a ``Buffer``."""
        cdef:
            size_t nbytes = value.nbytes
            uint8_t *raw = <uint8_t *>(&value[0])
            shared_ptr[CLocalMemoryBuffer] request = \
                make_shared[CLocalMemoryBuffer](raw, nbytes, True)
            shared_ptr[CLocalMemoryBuffer] reply
        with nogil:
            reply = self.client.OnReaderMessageSync(request)
        return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](reply))
|
|
|
|
|
|
cdef class WriterClient:
    """Python-facing wrapper that owns a native ``CWriterClient``.

    Allocation happens in ``__cinit__`` and deletion in ``__dealloc__``.
    The ``on_writer_message*`` methods forward raw message bytes to the
    native client wrapped in a ``CLocalMemoryBuffer``.
    """
    cdef:
        CWriterClient * client

    def __cinit__(self,
                  CoreWorker worker,
                  FunctionDescriptor async_func,
                  FunctionDescriptor sync_func):
        cdef CCoreWorker *native_worker = worker.core_worker.get()
        # Both callbacks are registered as Python-language ray functions.
        cdef CRayFunction async_fn = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
        cdef CRayFunction sync_fn = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
        self.client = new CWriterClient(native_worker, async_fn, sync_fn)

    def __dealloc__(self):
        if self.client != NULL:
            del self.client
        self.client = NULL

    def on_writer_message(self, const unsigned char[:] value):
        """Pass ``value`` to the native ``OnWriterMessage`` with the GIL released."""
        # NOTE(review): the trailing ``True`` constructor argument presumably
        # makes the buffer copy the bytes - confirm against CLocalMemoryBuffer.
        cdef:
            size_t nbytes = value.nbytes
            uint8_t *raw = <uint8_t *>(&value[0])
            shared_ptr[CLocalMemoryBuffer] request = \
                make_shared[CLocalMemoryBuffer](raw, nbytes, True)
        with nogil:
            self.client.OnWriterMessage(request)

    def on_writer_message_sync(self, const unsigned char[:] value):
        """Pass ``value`` to ``OnWriterMessageSync`` and wrap its reply as a ``Buffer``."""
        cdef:
            size_t nbytes = value.nbytes
            uint8_t *raw = <uint8_t *>(&value[0])
            shared_ptr[CLocalMemoryBuffer] request = \
                make_shared[CLocalMemoryBuffer](raw, nbytes, True)
            shared_ptr[CLocalMemoryBuffer] reply
        with nogil:
            reply = self.client.OnWriterMessageSync(request)
        return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](reply))
|
|
|
|
|
|
cdef class DataWriter:
    """Python handle owning a native ``CDataWriter``.

    Build instances with :meth:`create`; calling the class directly raises.
    The native writer is deleted in ``__dealloc__`` when it is non-NULL.
    """
    cdef:
        CDataWriter *writer

    def __init__(self):
        raise Exception("use create() to create DataWriter")

    @staticmethod
    def create(list py_output_channels,
               list output_actor_ids: list[ActorID],
               uint64_t queue_size,
               list py_msg_ids,
               bytes config_bytes,
               c_bool is_mock):
        """Create, initialize and start a native writer.

        Args:
            py_output_channels: output channel ids, each a binary object id
                (bytes); converted by ``bytes_list_to_qid_vec``.
            output_actor_ids: ``ActorID`` objects forwarded to ``Init``.
            queue_size: queue size used for every output channel.
            py_msg_ids: message ids, each cast to uint64 and forwarded to
                ``Init``.
            config_bytes: serialized config handed to ``SetConfig``; skipped
                when empty.
            is_mock: when true, the runtime context is marked as a mock test.

        Returns:
            DataWriter: a wrapper owning the initialized, running native writer.

        Raises:
            ray.streaming.runtime.transfer.ChannelInitException: if ``Init``
                does not return ``StatusOK`` (the native writer is deleted
                before raising).
        """
        cdef:
            c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
            c_vector[CActorID] actor_ids
            c_vector[uint64_t] msg_ids
            c_vector[uint64_t] queue_size_vec
            c_vector[CObjectID] remain_id_vec
            CDataWriter *c_writer
            CStreamingStatus status
        cdef const unsigned char[:] config_data
        for actor_id in output_actor_ids:
            actor_ids.push_back((<ActorID>actor_id).data)
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(<uint64_t>py_msg_id)

        cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
        if is_mock:
            ctx.get().MarkMockTest()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&config_data[0]), config_data.nbytes)
        c_writer = new CDataWriter(ctx)

        # Every channel gets the same queue size.
        queue_size_vec.assign(channel_ids.size(), queue_size)
        status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec)
        # NOTE(review): nothing in this function fills remain_id_vec, so this
        # warning never fires and ChannelInitException always receives an
        # empty list - confirm whether Init should report failed channels.
        if remain_id_vec.size() != 0:
            channel_logger.warning("failed queue amounts => %s", remain_id_vec.size())
        if <uint32_t>status != <uint32_t> libstreaming.StatusOK:
            msg = "initialize writer failed, status={}".format(<uint32_t>status)
            channel_logger.error(msg)
            del c_writer
            # Imported lazily, presumably to avoid a cyclic import at load time.
            import ray.streaming.runtime.transfer as transfer
            raise transfer.ChannelInitException(msg, qid_vector_to_list(remain_id_vec))

        c_writer.Run()
        channel_logger.info("create native writer succeed")
        cdef DataWriter writer = DataWriter.__new__(DataWriter)
        writer.writer = c_writer
        return writer

    def __dealloc__(self):
        if self.writer != NULL:
            del self.writer
            channel_logger.info("deleted DataWriter")
            self.writer = NULL

    def write(self, ObjectID qid, const unsigned char[:] value):
        """support zero-copy bytes, bytearray, array of unsigned char

        Write ``value`` to channel ``qid`` through the native buffer ring
        (with the GIL released) and return the message id it reports.

        Raises:
            ValueError: if ``value`` is larger than ``UINT32_MAX`` bytes; the
                native call takes a uint32 size, so a larger length would
                otherwise be silently truncated.
            IndexError: for an empty ``value`` (raised by ``&value[0]`` when
                bounds checking is enabled).
        """
        cdef:
            CObjectID native_id = qid.data
            uint64_t msg_id
            uint8_t *data
            uint32_t size
        if <uint64_t>value.nbytes > UINT32_MAX:
            raise ValueError(
                "message of {} bytes exceeds the maximum size of {} bytes".format(
                    value.nbytes, UINT32_MAX))
        data = <uint8_t *>(&value[0])
        size = <uint32_t>value.nbytes
        with nogil:
            msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size)
        return msg_id

    def stop(self):
        self.writer.Stop()
        channel_logger.info("stopped DataWriter")
|
|
|
|
|
|
cdef class DataReader:
    """Python handle owning a native ``CDataReader``.

    Build instances with :meth:`create`; calling the class directly raises.
    The native reader is deleted in ``__dealloc__`` when it is non-NULL.
    """
    cdef:
        CDataReader *reader
        # NOTE(review): meta and data are exposed read-only but are not
        # assigned anywhere in this class - confirm whether they are used.
        readonly bytes meta
        readonly bytes data

    def __init__(self):
        raise Exception("use create() to create DataReader")

    @staticmethod
    def create(list py_input_queues,
               list input_actor_ids: list[ActorID],
               list py_seq_ids,
               list py_msg_ids,
               int64_t timer_interval,
               c_bool is_recreate,
               bytes config_bytes,
               c_bool is_mock):
        """Create and initialize a native reader.

        Args:
            py_input_queues: input channel ids, each a binary object id
                (bytes); converted by ``bytes_list_to_qid_vec``.
            input_actor_ids: ``ActorID`` objects forwarded to ``Init``.
            py_seq_ids: sequence ids, cast to uint64 and forwarded to ``Init``.
            py_msg_ids: message ids, cast to uint64 and forwarded to ``Init``.
            timer_interval: forwarded to ``Init`` unchanged.
            is_recreate: NOTE(review): accepted but not used by this method.
            config_bytes: serialized config handed to ``SetConfig``; skipped
                when empty.
            is_mock: when true, the runtime context is marked as a mock test.

        Returns:
            DataReader: a wrapper owning the native reader.
        """
        cdef:
            c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
            c_vector[CActorID] actor_ids
            c_vector[uint64_t] seq_ids
            c_vector[uint64_t] msg_ids
            CDataReader *c_reader
        cdef const unsigned char[:] config_data
        for actor_id in input_actor_ids:
            actor_ids.push_back((<ActorID>actor_id).data)
        for py_seq_id in py_seq_ids:
            seq_ids.push_back(<uint64_t>py_seq_id)
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(<uint64_t>py_msg_id)
        cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&(config_data[0])), config_data.nbytes)
        if is_mock:
            ctx.get().MarkMockTest()
        c_reader = new CDataReader(ctx)
        # NOTE(review): unlike DataWriter.create, the result of Init is not
        # checked here - confirm whether Init reports failures to the caller.
        c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval)
        channel_logger.info("create native reader succeed")
        cdef DataReader reader = DataReader.__new__(DataReader)
        reader.reader = c_reader
        return reader

    def __dealloc__(self):
        if self.reader != NULL:
            del self.reader
            channel_logger.info("deleted DataReader")
            self.reader = NULL

    def read(self, uint32_t timeout_millis):
        """Fetch the next bundle from the native reader.

        Args:
            timeout_millis: timeout forwarded to ``GetBundle``.

        Returns:
            list: ``(msg_bytes, msg_id, timestamp, qid_bytes)`` tuples for a
            message bundle; ``[]`` for an empty bundle or when no bundle was
            returned.

        Raises:
            ray.streaming.runtime.transfer.ChannelInterruptException: when the
                reader reports ``StatusInterrupted``.
            Exception: on ``StatusInitQueueFailed``,
                ``StatusWaitQueueTimeOut`` or an unsupported bundle type.
        """
        cdef:
            shared_ptr[CDataBundle] bundle
            CStreamingStatus status
            uint32_t bundle_type
            uint32_t msg_nums
            CObjectID queue_id
            c_list[shared_ptr[CStreamingMessage]] msg_list
            list msgs = []
            uint64_t timestamp
            uint64_t msg_id
        with nogil:
            status = self.reader.GetBundle(timeout_millis, bundle)
        if <uint32_t> status != <uint32_t> libstreaming.StatusOK:
            if <uint32_t> status == <uint32_t> libstreaming.StatusInterrupted:
                # avoid cyclic import
                import ray.streaming.runtime.transfer as transfer
                raise transfer.ChannelInterruptException("reader interrupted")
            elif <uint32_t> status == <uint32_t> libstreaming.StatusInitQueueFailed:
                raise Exception("init channel failed")
            elif <uint32_t> status == <uint32_t> libstreaming.StatusWaitQueueTimeOut:
                raise Exception("wait channel object timeout")
            # Other non-OK statuses fall through so that any bundle the
            # native reader still produced is inspected below.
        # The bundle is only inspected after the status check: a failed
        # GetBundle may leave it (or its meta) unset, and dereferencing it
        # would crash the process.
        if bundle.get() == NULL or bundle.get().meta.get() == NULL:
            channel_logger.warning("reader returned no bundle, status=%s", <uint32_t> status)
            return []
        bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())

        if bundle_type == <uint32_t> libstreaming.BundleTypeBundle:
            msg_nums = bundle.get().meta.get().GetMessageListSize()
            CStreamingMessageBundle.GetMessageListFromRawData(
                bundle.get().data + libstreaming.kMessageBundleHeaderSize,
                bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
                msg_nums,
                msg_list)
            timestamp = bundle.get().meta.get().GetMessageBundleTs()
            # NOTE(review): queue_id is never assigned, so every message gets
            # the binary of a default-constructed CObjectID - confirm whether
            # it should come from the bundle. Loop-invariant, so computed once.
            qid_bytes = queue_id.Binary()
            for msg in msg_list:
                msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()]
                msg_id = msg.get().GetMessageSeqId()
                msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
            return msgs
        elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
            return []
        else:
            raise Exception("Unsupported bundle type {}".format(bundle_type))

    def stop(self):
        self.reader.Stop()
        channel_logger.info("stopped DataReader")
|
|
|
|
|
|
cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *:
    """Convert a non-empty list of binary object ids into a C++ vector.

    Each element must convert to a ``c_string`` (i.e. be ``bytes``) whose
    length equals ``CObjectID.Size()``. Validation raises explicitly rather
    than using ``assert``. Assertions can be compiled out, and a wrong-sized
    id would then reach ``CObjectID.FromBinary`` unchecked.

    Raises:
        ValueError: if ``py_queue_ids`` is empty/None or an id has the wrong
            length.
    """
    if not py_queue_ids:
        raise ValueError("queue id list must not be empty")
    cdef:
        c_vector[CObjectID] queue_id_vec
        c_string q_id_data
    queue_id_vec.reserve(len(py_queue_ids))
    for q_id in py_queue_ids:
        q_id_data = q_id
        if q_id_data.size() != CObjectID.Size():
            raise ValueError(
                "queue id must be {} bytes, got {}".format(
                    CObjectID.Size(), q_id_data.size()))
        queue_id_vec.push_back(CObjectID.FromBinary(q_id_data))
    return queue_id_vec
|
|
|
|
cdef list qid_vector_to_list(c_vector[CObjectID] queue_id_vec):
    """Return the binary form (bytes) of every id in ``queue_id_vec``, in order.

    Declared to return a Python ``list``, which is what the body builds.
    Callers handing the result to Python code therefore receive it without an
    intermediate ``vector[string]`` conversion. Exceptions raised while
    building it propagate, as for any ``cdef`` function returning an object.
    """
    cdef list queues = []
    for obj_id in queue_id_vec:
        queues.append(obj_id.Binary())
    return queues