ray/streaming/python/includes/transfer.pxi

# flake8: noqa
from libc.stdint cimport *
from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
from libcpp.list cimport list as c_list
from ray.includes.common cimport (
    CRayFunction,
    LANGUAGE_PYTHON,
    CBuffer
)
from ray.includes.unique_ids cimport (
    CActorID,
    CObjectID
)
from ray._raylet cimport (
    Buffer,
    CoreWorker,
    ActorID,
    ObjectID,
    FunctionDescriptor,
)
from ray.includes.libcoreworker cimport CCoreWorker
cimport ray.streaming.includes.libstreaming as libstreaming
from ray.streaming.includes.libstreaming cimport (
    CStreamingStatus,
    CStreamingMessage,
    CStreamingMessageBundle,
    CRuntimeContext,
    CDataBundle,
    CDataWriter,
    CDataReader,
    CReaderClient,
    CWriterClient,
    CLocalMemoryBuffer,
)
import logging
channel_logger = logging.getLogger(__name__)
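

# ReaderClient wraps the native CReaderClient. It is constructed from the
# current CoreWorker plus the async/sync Python callback descriptors, which
# are turned into CRayFunction handles for the C++ side. Incoming queue
# messages (raw bytes) are wrapped in a CLocalMemoryBuffer and forwarded to
# the native client with the GIL released; the sync variant returns the
# native reply as a Buffer.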
cdef class ReaderClient:
    cdef:
        CReaderClient *client

    def __cinit__(self,
                  CoreWorker worker,
                  FunctionDescriptor async_func,
                  FunctionDescriptor sync_func):
        cdef:
            CCoreWorker *core_worker = worker.core_worker.get()
            CRayFunction async_native_func
            CRayFunction sync_native_func
        async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
        sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
        self.client = new CReaderClient(core_worker, async_native_func, sync_native_func)

    def __dealloc__(self):
        del self.client
        self.client = NULL

    def on_reader_message(self, const unsigned char[:] value):
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
        with nogil:
            self.client.OnReaderMessage(local_buf)

    def on_reader_message_sync(self, const unsigned char[:] value):
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
            shared_ptr[CLocalMemoryBuffer] result_buffer
        with nogil:
            result_buffer = self.client.OnReaderMessageSync(local_buf)
        return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))
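

# WriterClient is the writer-side counterpart of ReaderClient: it forwards
# writer-related queue messages to the native CWriterClient, again releasing
# the GIL around the native calls.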
cdef class WriterClient:
    cdef:
        CWriterClient *client

    def __cinit__(self,
                  CoreWorker worker,
                  FunctionDescriptor async_func,
                  FunctionDescriptor sync_func):
        cdef:
            CCoreWorker *core_worker = worker.core_worker.get()
            CRayFunction async_native_func
            CRayFunction sync_native_func
        async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor)
        sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor)
        self.client = new CWriterClient(core_worker, async_native_func, sync_native_func)

    def __dealloc__(self):
        del self.client
        self.client = NULL

    def on_writer_message(self, const unsigned char[:] value):
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
        with nogil:
            self.client.OnWriterMessage(local_buf)

    def on_writer_message_sync(self, const unsigned char[:] value):
        cdef:
            size_t size = value.nbytes
            shared_ptr[CLocalMemoryBuffer] local_buf = \
                make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), size, True)
            shared_ptr[CLocalMemoryBuffer] result_buffer
        with nogil:
            result_buffer = self.client.OnWriterMessageSync(local_buf)
        return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer))
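

# DataWriter owns a native CDataWriter. Instances are built through create(),
# which converts the Python channel/actor/message-id lists into C++ vectors,
# applies the serialized config to the runtime context, initializes the
# native writer with one queue-size entry per channel, and starts it.
# write() hands a zero-copy view of the payload to the native buffer ring and
# returns the message id assigned by the native writer.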
cdef class DataWriter:
    cdef:
        CDataWriter *writer

    def __init__(self):
        raise Exception("use create() to create DataWriter")

    @staticmethod
    def create(list py_output_channels,
               list output_actor_ids: list[ActorID],
               uint64_t queue_size,
               list py_msg_ids,
               bytes config_bytes,
               c_bool is_mock):
        cdef:
            c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
            c_vector[CActorID] actor_ids
            c_vector[uint64_t] msg_ids
            CDataWriter *c_writer
        cdef const unsigned char[:] config_data
        for actor_id in output_actor_ids:
            actor_ids.push_back((<ActorID>actor_id).data)
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(<uint64_t>py_msg_id)
        cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
        if is_mock:
            ctx.get().MarkMockTest()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&config_data[0]), config_data.nbytes)
        c_writer = new CDataWriter(ctx)
        cdef:
            c_vector[CObjectID] remain_id_vec
            c_vector[uint64_t] queue_size_vec
        for i in range(channel_ids.size()):
            queue_size_vec.push_back(queue_size)
        cdef CStreamingStatus status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec)
        if remain_id_vec.size() != 0:
            channel_logger.warning("failed queue amounts => %s", remain_id_vec.size())
        if <uint32_t>status != <uint32_t> libstreaming.StatusOK:
            msg = "initialize writer failed, status={}".format(<uint32_t>status)
            channel_logger.error(msg)
            del c_writer
            import ray.streaming.runtime.transfer as transfer
            raise transfer.ChannelInitException(msg, qid_vector_to_list(remain_id_vec))
        c_writer.Run()
        channel_logger.info("create native writer succeed")
        cdef DataWriter writer = DataWriter.__new__(DataWriter)
        writer.writer = c_writer
        return writer

    def __dealloc__(self):
        if self.writer != NULL:
            del self.writer
            channel_logger.info("deleted DataWriter")
            self.writer = NULL

    def write(self, ObjectID qid, const unsigned char[:] value):
        """support zero-copy bytes, bytearray, array of unsigned char"""
        cdef:
            CObjectID native_id = qid.data
            uint64_t msg_id
            uint8_t *data = <uint8_t *>(&value[0])
            uint32_t size = value.nbytes
        with nogil:
            msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size)
        return msg_id

    def stop(self):
        self.writer.Stop()
        channel_logger.info("stopped DataWriter")
cdef class DataReader:
    cdef:
        CDataReader *reader
        readonly bytes meta
        readonly bytes data

    def __init__(self):
        raise Exception("use create() to create DataReader")

    @staticmethod
    def create(list py_input_queues,
               list input_actor_ids: list[ActorID],
               list py_seq_ids,
               list py_msg_ids,
               int64_t timer_interval,
               c_bool is_recreate,
               bytes config_bytes,
               c_bool is_mock):
        cdef:
            c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
            c_vector[CActorID] actor_ids
            c_vector[uint64_t] seq_ids
            c_vector[uint64_t] msg_ids
            CDataReader *c_reader
        cdef const unsigned char[:] config_data
        for actor_id in input_actor_ids:
            actor_ids.push_back((<ActorID>actor_id).data)
        for py_seq_id in py_seq_ids:
            seq_ids.push_back(<uint64_t>py_seq_id)
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(<uint64_t>py_msg_id)
        cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("load config, config bytes size: %s", config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&(config_data[0])), config_data.nbytes)
        if is_mock:
            ctx.get().MarkMockTest()
        c_reader = new CDataReader(ctx)
        c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval)
        channel_logger.info("create native reader succeed")
        cdef DataReader reader = DataReader.__new__(DataReader)
        reader.reader = c_reader
        return reader

    def __dealloc__(self):
        if self.reader != NULL:
            del self.reader
            channel_logger.info("deleted DataReader")
            self.reader = NULL

    def read(self, uint32_t timeout_millis):
        cdef:
            shared_ptr[CDataBundle] bundle
            CStreamingStatus status
        with nogil:
            status = self.reader.GetBundle(timeout_millis, bundle)
        cdef uint32_t bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())
        if <uint32_t> status != <uint32_t> libstreaming.StatusOK:
            if <uint32_t> status == <uint32_t> libstreaming.StatusInterrupted:
                # avoid cyclic import
                import ray.streaming.runtime.transfer as transfer
                raise transfer.ChannelInterruptException("reader interrupted")
            elif <uint32_t> status == <uint32_t> libstreaming.StatusInitQueueFailed:
                raise Exception("init channel failed")
            elif <uint32_t> status == <uint32_t> libstreaming.StatusWaitQueueTimeOut:
                raise Exception("wait channel object timeout")
        cdef:
            uint32_t msg_nums
            CObjectID queue_id
            c_list[shared_ptr[CStreamingMessage]] msg_list
            list msgs = []
            uint64_t timestamp
            uint64_t msg_id
        if bundle_type == <uint32_t> libstreaming.BundleTypeBundle:
            msg_nums = bundle.get().meta.get().GetMessageListSize()
            CStreamingMessageBundle.GetMessageListFromRawData(
                bundle.get().data + libstreaming.kMessageBundleHeaderSize,
                bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
                msg_nums,
                msg_list)
            timestamp = bundle.get().meta.get().GetMessageBundleTs()
            for msg in msg_list:
                msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()]
                qid_bytes = queue_id.Binary()
                msg_id = msg.get().GetMessageSeqId()
                msgs.append((msg_bytes, msg_id, timestamp, qid_bytes))
            return msgs
        elif bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
            return []
        else:
            raise Exception("Unsupported bundle type {}".format(bundle_type))

    def stop(self):
        self.reader.Stop()
        channel_logger.info("stopped DataReader")
cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *:
    assert len(py_queue_ids) > 0
    cdef:
        c_vector[CObjectID] queue_id_vec
        c_string q_id_data
    for q_id in py_queue_ids:
        q_id_data = q_id
        assert q_id_data.size() == CObjectID.Size()
        obj_id = CObjectID.FromBinary(q_id_data)
        queue_id_vec.push_back(obj_id)
    return queue_id_vec
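

# Helper: convert a c_vector[CObjectID] back into a list of binary strings,
# used when reporting channels that failed to initialize.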
cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec):
    queues = []
    for obj_id in queue_id_vec:
        queues.append(obj_id.Binary())
    return queues