ray/streaming/python/includes/transfer.pxi
2020-12-30 10:45:52 +08:00

397 lines
15 KiB
Cython

# flake8: noqa
from libc.stdint cimport *
from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
from libcpp.list cimport list as c_list
from libcpp.unordered_map cimport unordered_map as c_unordered_map
from cython.operator cimport dereference, postincrement
from ray.includes.common cimport (
CRayFunction,
LANGUAGE_PYTHON,
LANGUAGE_JAVA,
CBuffer
)
from ray.includes.unique_ids cimport (
CActorID,
CObjectID
)
from ray._raylet cimport (
Buffer,
ActorID,
ObjectRef,
FunctionDescriptor,
)
cimport ray.streaming.includes.libstreaming as libstreaming
from ray.streaming.includes.libstreaming cimport (
CStreamingStatus,
CStreamingMessage,
CStreamingMessageBundle,
CRuntimeContext,
CDataBundle,
CDataWriter,
CDataReader,
CReaderClient,
CWriterClient,
CLocalMemoryBuffer,
CChannelCreationParameter,
CTransferCreationStatus,
CConsumerChannelInfo,
CStreamingBarrierHeader,
kBarrierHeaderSize,
)
from ray._raylet import JavaFunctionDescriptor
import logging
# Module-level logger used by the DataWriter / DataReader wrappers below.
channel_logger = logging.getLogger(__name__)
cdef class ChannelCreationParameter:
    """Python-visible wrapper around a native ``CChannelCreationParameter``.

    The wrapped struct records an actor id plus two ``CRayFunction`` objects
    (stored as ``async_function`` and ``sync_function``) built from the given
    function descriptors.
    """
    cdef:
        CChannelCreationParameter parameter

    def __cinit__(self, ActorID actor_id not None,
                  FunctionDescriptor async_func not None,
                  FunctionDescriptor sync_func not None):
        """Fill in the native parameter struct.

        Args:
            actor_id: actor id whose native value is copied into the struct.
            async_func: descriptor wrapped as the struct's ``async_function``.
            sync_func: descriptor wrapped as the struct's ``sync_function``.

        Raises:
            TypeError: if any argument is ``None``. The C-level fields
                (``.data`` / ``.descriptor``) are read directly below, so a
                ``None`` argument must be rejected up front instead of being
                dereferenced as if it were an extension-type instance.
        """
        self.parameter = CChannelCreationParameter()
        self.parameter.actor_id = actor_id.data

        # Java descriptors are tagged LANGUAGE_JAVA; every other descriptor
        # is tagged LANGUAGE_PYTHON.
        if isinstance(async_func, JavaFunctionDescriptor):
            self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_JAVA, async_func.descriptor)
        else:
            self.parameter.async_function = make_shared[CRayFunction](LANGUAGE_PYTHON, async_func.descriptor)

        if isinstance(sync_func, JavaFunctionDescriptor):
            self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_JAVA, sync_func.descriptor)
        else:
            self.parameter.sync_function = make_shared[CRayFunction](LANGUAGE_PYTHON, sync_func.descriptor)

    cdef CChannelCreationParameter get_parameter(self):
        """Return the wrapped native parameter by value."""
        return self.parameter
cdef class ReaderClient:
    """Owns a native ``CReaderClient`` and forwards raw messages to it."""
    cdef:
        CReaderClient *client

    def __cinit__(self):
        # The native client is heap-allocated here and released in __dealloc__.
        self.client = new CReaderClient()

    def __dealloc__(self):
        del self.client
        self.client = NULL

    def on_reader_message(self, const unsigned char[:] value):
        """Hand ``value`` to the native client's ``OnReaderMessage``.

        Args:
            value: contiguous byte buffer holding the message; its first
                element is addressed directly, so it must not be empty.
        """
        cdef size_t num_bytes = value.nbytes
        # NOTE(review): the trailing ``True`` presumably asks
        # CLocalMemoryBuffer to copy the bytes -- confirm against its ctor.
        cdef shared_ptr[CLocalMemoryBuffer] message = \
            make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), num_bytes, True)
        # The native call runs without the GIL.
        with nogil:
            self.client.OnReaderMessage(message)

    def on_reader_message_sync(self, const unsigned char[:] value):
        """Hand ``value`` to ``OnReaderMessageSync`` and return its reply.

        Args:
            value: contiguous byte buffer holding the message; its first
                element is addressed directly, so it must not be empty.

        Returns:
            A ``Buffer`` wrapping the buffer returned by the native client.
        """
        cdef size_t num_bytes = value.nbytes
        cdef shared_ptr[CLocalMemoryBuffer] message = \
            make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), num_bytes, True)
        cdef shared_ptr[CLocalMemoryBuffer] reply
        with nogil:
            reply = self.client.OnReaderMessageSync(message)
        return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](reply))
cdef class WriterClient:
    """Owns a native ``CWriterClient`` and forwards raw messages to it."""
    cdef:
        CWriterClient * client

    def __cinit__(self):
        # The native client is heap-allocated here and released in __dealloc__.
        self.client = new CWriterClient()

    def __dealloc__(self):
        del self.client
        self.client = NULL

    def on_writer_message(self, const unsigned char[:] value):
        """Hand ``value`` to the native client's ``OnWriterMessage``.

        Args:
            value: contiguous byte buffer holding the message; its first
                element is addressed directly, so it must not be empty.
        """
        cdef size_t num_bytes = value.nbytes
        # NOTE(review): the trailing ``True`` presumably asks
        # CLocalMemoryBuffer to copy the bytes -- confirm against its ctor.
        cdef shared_ptr[CLocalMemoryBuffer] message = \
            make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), num_bytes, True)
        # The native call runs without the GIL.
        with nogil:
            self.client.OnWriterMessage(message)

    def on_writer_message_sync(self, const unsigned char[:] value):
        """Hand ``value`` to ``OnWriterMessageSync`` and return its reply.

        Args:
            value: contiguous byte buffer holding the message; its first
                element is addressed directly, so it must not be empty.

        Returns:
            A ``Buffer`` wrapping the buffer returned by the native client.
        """
        cdef size_t num_bytes = value.nbytes
        cdef shared_ptr[CLocalMemoryBuffer] message = \
            make_shared[CLocalMemoryBuffer](<uint8_t *>(&value[0]), num_bytes, True)
        cdef shared_ptr[CLocalMemoryBuffer] reply
        with nogil:
            reply = self.client.OnWriterMessageSync(message)
        return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](reply))
cdef class DataWriter:
    """Wrapper around a native ``CDataWriter``.

    Instances are obtained through :meth:`create`; calling the constructor
    directly raises.
    """
    cdef:
        CDataWriter *writer

    def __init__(self):
        raise Exception("use create() to create DataWriter")

    @staticmethod
    def create(list py_output_channels,
               list output_creation_parameters: list[ChannelCreationParameter],
               uint64_t queue_size,
               list py_msg_ids,
               bytes config_bytes,
               c_bool is_mock):
        """Build, initialise and start a native writer.

        Args:
            py_output_channels: non-empty list of binary channel ids.
            output_creation_parameters: ``ChannelCreationParameter`` objects
                whose native values are forwarded to ``Init``.
            queue_size: size value repeated once per output channel.
            py_msg_ids: message ids forwarded to ``Init`` as ``uint64_t``.
            config_bytes: serialized config; skipped when empty or ``None``.
            is_mock: when true, the runtime context is marked as a mock test.

        Returns:
            A ``DataWriter`` owning the started native writer.

        Raises:
            ChannelInitException: when ``Init`` returns a status other than
                ``StatusOK``.
        """
        cdef:
            c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels)
            c_vector[CChannelCreationParameter] initial_parameters
            c_vector[uint64_t] msg_ids
            c_vector[CObjectID] remain_id_vec
            c_vector[uint64_t] queue_size_vec
            ChannelCreationParameter creation_parameter
            CDataWriter *c_writer
            CStreamingStatus status
            shared_ptr[CRuntimeContext] ctx
            DataWriter writer
        cdef const unsigned char[:] config_data

        for py_parameter in output_creation_parameters:
            # Typed assignment gives access to the cdef get_parameter() method.
            creation_parameter = py_parameter
            initial_parameters.push_back(creation_parameter.get_parameter())
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(<uint64_t>py_msg_id)

        ctx = make_shared[CRuntimeContext]()
        if is_mock:
            ctx.get().MarkMockTest()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("DataWriter load config, config bytes size: %s", config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&config_data[0]), config_data.nbytes)
        c_writer = new CDataWriter(ctx)

        # Every channel gets the same queue size.
        for _ in range(channel_ids.size()):
            queue_size_vec.push_back(queue_size)
        status = c_writer.Init(channel_ids, initial_parameters, msg_ids, queue_size_vec)

        # NOTE(review): nothing visible here ever fills remain_id_vec, so this
        # warning and the id list attached to the exception below look
        # permanently empty -- confirm whether Init should report failed ids.
        if remain_id_vec.size() != 0:
            channel_logger.warning("failed queue amounts => %s", remain_id_vec.size())
        if <uint32_t>status != <uint32_t> libstreaming.StatusOK:
            error_msg = "initialize writer failed, status={}".format(<uint32_t>status)
            channel_logger.error(error_msg)
            # Release the native writer before propagating the failure.
            del c_writer
            import ray.streaming.runtime.transfer as transfer
            raise transfer.ChannelInitException(error_msg, qid_vector_to_list(remain_id_vec))

        c_writer.Run()
        channel_logger.info("create native writer succeed")
        # __new__ bypasses __init__, which deliberately raises.
        writer = DataWriter.__new__(DataWriter)
        writer.writer = c_writer
        return writer

    def __dealloc__(self):
        if self.writer != NULL:
            del self.writer
            channel_logger.info("deleted DataWriter")
            self.writer = NULL

    def write(self, ObjectRef qid, const unsigned char[:] value):
        """Write one message to the channel identified by ``qid``.

        Args:
            qid: channel id; its native value is passed to the writer.
            value: contiguous byte buffer (bytes, bytearray, array of
                unsigned char, ...); its first element is addressed directly,
                so it must not be empty.

        Returns:
            The message id returned by ``WriteMessageToBufferRing``.
        """
        cdef:
            CObjectID channel_id = qid.data
            uint8_t *payload = <uint8_t *>(&value[0])
            uint32_t payload_size = value.nbytes
            uint64_t msg_id
        with nogil:
            msg_id = self.writer.WriteMessageToBufferRing(channel_id, payload, payload_size)
        return msg_id

    def broadcast_barrier(self, uint64_t checkpoint_id, const unsigned char[:] value):
        """Forward a barrier payload to the native ``BroadcastBarrier``.

        Args:
            checkpoint_id: checkpoint id passed through to the native call.
            value: contiguous, non-empty byte buffer with the barrier payload.
        """
        cdef:
            uint8_t *payload = <uint8_t *>(&value[0])
            uint32_t payload_size = value.nbytes
        with nogil:
            self.writer.BroadcastBarrier(checkpoint_id, payload, payload_size)

    def get_output_checkpoints(self):
        """Return the values filled in by the native ``GetChannelOffset``."""
        cdef c_vector[uint64_t] offsets
        self.writer.GetChannelOffset(offsets)
        return offsets

    def clear_checkpoint(self, checkpoint_id):
        """Forward ``checkpoint_id`` to the native ``ClearCheckpoint``."""
        cdef uint64_t c_checkpoint_id = checkpoint_id
        with nogil:
            self.writer.ClearCheckpoint(c_checkpoint_id)

    def stop(self):
        """Stop the native writer."""
        self.writer.Stop()
        channel_logger.info("stopped DataWriter")
cdef class DataReader:
    """Wrapper around a native ``CDataReader``.

    Instances are obtained through :meth:`create`; calling the constructor
    directly raises.
    """
    cdef:
        CDataReader *reader
        # Read-only attributes; nothing in this class assigns them.
        readonly bytes meta
        readonly bytes data

    def __init__(self):
        raise Exception("use create() to create DataReader")

    @staticmethod
    def create(list py_input_queues,
               list input_creation_parameters: list[ChannelCreationParameter],
               list py_msg_ids,
               int64_t timer_interval,
               bytes config_bytes,
               c_bool is_mock):
        """Build and initialise a native reader.

        Args:
            py_input_queues: non-empty list of binary queue ids.
            input_creation_parameters: ``ChannelCreationParameter`` objects
                whose native values are forwarded to ``Init``.
            py_msg_ids: message ids forwarded to ``Init`` as ``uint64_t``.
            timer_interval: value passed through to ``Init``.
            config_bytes: serialized config; skipped when empty or ``None``.
            is_mock: when true, the runtime context is marked as a mock test.

        Returns:
            ``(reader, creation_status_map)`` where the map goes from binary
            queue id to the creation status reported by ``Init`` (as an
            unsigned integer). The map is empty when ``Init`` reports nothing.
        """
        cdef:
            c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues)
            c_vector[CChannelCreationParameter] initial_parameters
            c_vector[uint64_t] msg_ids
            c_vector[CTransferCreationStatus] c_creation_status
            ChannelCreationParameter creation_parameter
            CDataReader *c_reader
            shared_ptr[CRuntimeContext] ctx
            DataReader reader
        cdef const unsigned char[:] config_data

        for py_parameter in input_creation_parameters:
            # Typed assignment gives access to the cdef get_parameter() method.
            creation_parameter = py_parameter
            initial_parameters.push_back(creation_parameter.get_parameter())
        for py_msg_id in py_msg_ids:
            msg_ids.push_back(<uint64_t>py_msg_id)

        ctx = make_shared[CRuntimeContext]()
        if config_bytes:
            config_data = config_bytes
            channel_logger.info("DataReader load config, config bytes size: %s", config_data.nbytes)
            ctx.get().SetConfig(<uint8_t *>(&(config_data[0])), config_data.nbytes)
        if is_mock:
            ctx.get().MarkMockTest()

        c_reader = new CDataReader(ctx)
        c_reader.Init(queue_id_vec, initial_parameters, msg_ids, c_creation_status, timer_interval)

        creation_status_map = {}
        if not c_creation_status.empty():
            # Statuses are read positionally, one per requested queue id.
            for i in range(queue_id_vec.size()):
                creation_status_map[queue_id_vec[i].Binary()] = <uint64_t>c_creation_status[i]
        channel_logger.info("create native reader succeed")

        # __new__ bypasses __init__, which deliberately raises.
        reader = DataReader.__new__(DataReader)
        reader.reader = c_reader
        return reader, creation_status_map

    def __dealloc__(self):
        if self.reader != NULL:
            del self.reader
            channel_logger.info("deleted DataReader")
            self.reader = NULL

    def read(self, uint32_t timeout_millis):
        """Fetch one bundle from the native reader and convert it.

        Args:
            timeout_millis: timeout passed to the native ``GetBundle``.

        Returns:
            * ``[]`` when ``GetBundle`` reports ``StatusGetBundleTimeOut``;
            * a list of ``DataMessage`` (one per message) for a data bundle;
            * a one-element list holding an empty ``DataMessage`` for an
              empty bundle;
            * a one-element list holding a ``CheckpointBarrier`` for a
              barrier bundle.

        Raises:
            ChannelInterruptException: on ``StatusInterrupted``.
            ChannelInitException: on ``StatusInitQueueFailed``.
            Exception: on any other non-OK status or unknown bundle type.
        """
        cdef:
            shared_ptr[CDataBundle] bundle
            CStreamingStatus status
            uint32_t status_code
            uint32_t bundle_type
            uint32_t msg_nums
            CObjectID queue_id
            c_list[shared_ptr[CStreamingMessage]] msg_list
            list msgs
            uint64_t timestamp
            uint64_t msg_id
            c_unordered_map[CObjectID, CConsumerChannelInfo] *offset_map = NULL
            shared_ptr[CStreamingMessage] barrier
            CStreamingBarrierHeader barrier_header
            c_unordered_map[CObjectID, CConsumerChannelInfo].iterator it
            bytes py_queue_id

        with nogil:
            status = self.reader.GetBundle(timeout_millis, bundle)

        status_code = <uint32_t> status
        if status_code != <uint32_t> libstreaming.StatusOK:
            if status_code == <uint32_t> libstreaming.StatusGetBundleTimeOut:
                return []
            # Imported lazily to avoid a cyclic import.
            import ray.streaming.runtime.transfer as transfer
            if status_code == <uint32_t> libstreaming.StatusInterrupted:
                raise transfer.ChannelInterruptException("reader interrupted")
            if status_code == <uint32_t> libstreaming.StatusInitQueueFailed:
                # DataWriter.create raises this exception with
                # (message, channel list); use the same two-argument form
                # here. No failed-channel ids are available at this point,
                # so the list is empty.
                raise transfer.ChannelInitException("init channel failed", [])
            raise Exception("no such status " + str(status_code))

        queue_id = bundle.get().c_from
        # Converted once; every branch below reports the same source queue.
        py_queue_id = queue_id.Binary()
        bundle_type = <uint32_t>(bundle.get().meta.get().GetBundleType())
        # Imported lazily to avoid a cyclic import.
        from ray.streaming.runtime.transfer import DataMessage

        if bundle_type == <uint32_t> libstreaming.BundleTypeBundle:
            msg_nums = bundle.get().meta.get().GetMessageListSize()
            # Messages start right after the bundle header.
            CStreamingMessageBundle.GetMessageListFromRawData(
                bundle.get().data + libstreaming.kMessageBundleHeaderSize,
                bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
                msg_nums,
                msg_list)
            timestamp = bundle.get().meta.get().GetMessageBundleTs()
            msgs = []
            for msg in msg_list:
                msg_bytes = msg.get().Payload()[:msg.get().PayloadSize()]
                msg_id = msg.get().GetMessageId()
                msgs.append(
                    DataMessage(msg_bytes, timestamp, msg_id, py_queue_id))
            return msgs

        if bundle_type == <uint32_t> libstreaming.BundleTypeEmpty:
            timestamp = bundle.get().meta.get().GetMessageBundleTs()
            msg_id = bundle.get().meta.get().GetLastMessageId()
            return [DataMessage(None, timestamp, msg_id, py_queue_id, True)]

        if bundle.get().meta.get().IsBarrier():
            # Snapshot the current message id of every channel the reader
            # reports through GetOffsetInfo.
            py_offset_map = {}
            self.reader.GetOffsetInfo(offset_map)
            it = offset_map.begin()
            while it != offset_map.end():
                queue_id_bytes = dereference(it).first.Binary()
                current_message_id = dereference(it).second.current_message_id
                py_offset_map[queue_id_bytes] = current_message_id
                postincrement(it)

            msg_nums = bundle.get().meta.get().GetMessageListSize()
            CStreamingMessageBundle.GetMessageListFromRawData(
                bundle.get().data + libstreaming.kMessageBundleHeaderSize,
                bundle.get().data_size - libstreaming.kMessageBundleHeaderSize,
                msg_nums,
                msg_list)
            timestamp = bundle.get().meta.get().GetMessageBundleTs()
            # Only the first message of the bundle is used as the barrier.
            barrier = msg_list.front()
            msg_id = barrier.get().GetMessageId()
            CStreamingMessage.GetBarrierIdFromRawData(barrier.get().Payload(), &barrier_header)
            barrier_id = barrier_header.barrier_id
            # The barrier payload is a header followed by the user data.
            barrier_data = (barrier.get().Payload() + kBarrierHeaderSize)[
                :barrier.get().PayloadSize() - kBarrierHeaderSize]
            barrier_type = <uint64_t> barrier_header.barrier_type
            from ray.streaming.runtime.transfer import CheckpointBarrier
            return [CheckpointBarrier(
                barrier_data, timestamp, msg_id, py_queue_id, py_offset_map,
                barrier_id, barrier_type)]

        raise Exception("Unsupported bundle type {}".format(bundle_type))

    def stop(self):
        """Stop the native reader."""
        self.reader.Stop()
        channel_logger.info("stopped DataReader")
cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *:
    """Convert a non-empty list of binary ids into a vector of ``CObjectID``.

    Each entry must be exactly ``CObjectID.Size()`` bytes long; both
    conditions are enforced with ``assert``.
    """
    cdef:
        c_vector[CObjectID] queue_id_vec
        c_string q_id_data
    assert len(py_queue_ids) > 0
    for raw_id in py_queue_ids:
        # Coerce to a C++ string so the size can be checked before parsing.
        q_id_data = raw_id
        assert q_id_data.size() == CObjectID.Size(), f"{q_id_data.size()}, {CObjectID.Size()}"
        queue_id_vec.push_back(CObjectID.FromBinary(q_id_data))
    return queue_id_vec
cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec):
    """Return the binary form of every id in ``queue_id_vec``.

    The result is built directly as the declared C++ return type instead of
    going through an intermediate Python list; callers that use the value in
    a Python context still receive a list of ``bytes`` through Cython's
    automatic vector/string coercion.
    """
    cdef c_vector[c_string] queues
    for obj_id in queue_id_vec:
        queues.push_back(obj_id.Binary())
    return queues