ray/streaming/python/partition.py
2020-06-18 15:11:07 +08:00

129 lines
3.7 KiB
Python

import importlib
import inspect
import zlib
from abc import ABC, abstractmethod

from ray import cloudpickle
from ray.streaming.runtime import gateway_client
class Partition(ABC):
    """Abstract partitioning strategy used to route records downstream.

    Concrete strategies implement :meth:`partition`, which decides for each
    record which downstream partition(s) should receive it.
    """

    @abstractmethod
    def partition(self, record, num_partition: int):
        """Select the downstream partition(s) for ``record``.

        Args:
            record: The record being routed.
            num_partition: Number of downstream partitions.

        Returns:
            IDs of the downstream partitions that should receive the record.
        """
class BroadcastPartition(Partition):
    """Send every record to each of the downstream partitions.

    The list of partition ids ``[0, ..., num_partition - 1]`` is cached and
    rebuilt only when its length no longer matches ``num_partition``; the
    same list object is otherwise returned on every call.
    """

    def __init__(self):
        self.__partitions = []

    def partition(self, record, num_partition: int):
        cached = self.__partitions
        if len(cached) == num_partition:
            return cached
        # Downstream parallelism changed (or first call): rebuild the ids.
        self.__partitions = list(range(num_partition))
        return self.__partitions
class KeyPartition(Partition):
    """Route each record to a single downstream partition chosen by its key.

    ``str`` and ``bytes`` keys are hashed with CRC32 rather than the built-in
    ``hash``. The built-in hash of these types is salted per interpreter
    process (see PYTHONHASHSEED), so separate upstream worker processes would
    otherwise send the same key to different downstream partitions.
    """

    def __init__(self):
        # Single-element result list, reused and overwritten on every call.
        self.__partitions = [-1]

    def partition(self, key_record, num_partition: int):
        """Pick the downstream partition for ``key_record``.

        Args:
            key_record: A record exposing a ``key`` attribute.
            num_partition: Number of downstream partitions.

        Returns:
            A one-element list holding the target partition id. The same list
            object is reused (and overwritten) by the next call.
        """
        # TODO support key group
        key_hash = self._stable_hash(key_record.key)
        self.__partitions[0] = key_hash % num_partition
        return self.__partitions

    @staticmethod
    def _stable_hash(key):
        """Return a non-negative hash of ``key``.

        The hash is process-independent for ``str``/``bytes``/``bytearray``.
        """
        if isinstance(key, str):
            # surrogatepass so lone surrogates cannot make encoding fail.
            key = key.encode("utf-8", "surrogatepass")
        if isinstance(key, (bytes, bytearray)):
            return zlib.crc32(key)
        # NOTE(review): other key types fall back to the built-in hash, which
        # is not guaranteed stable across processes for every type (e.g.
        # tuples containing strings are still salted).
        return abs(hash(key))
class RoundRobinPartition(Partition):
    """Distribute records over downstream partitions in round-robin manner.

    The counter is advanced before it is used. With more than one partition,
    the first record therefore goes to partition 1; with a single partition
    it goes to partition 0.
    """

    def __init__(self):
        # Single-element result list, reused and overwritten on every call.
        self.__partitions = [-1]
        self.seq = 0

    def partition(self, key_record, num_partition: int):
        next_seq = self.seq + 1
        self.seq = next_seq % num_partition
        self.__partitions[0] = self.seq
        return self.__partitions
class ForwardPartition(Partition):
    """Default partition for an operator that can be chained with the
    operator following it.

    Every record is sent to downstream partition 0, regardless of the record
    or of ``num_partition``.
    """

    def __init__(self):
        # Constant result, returned as the same list object on every call.
        self.__partitions = [0]

    def partition(self, key_record, num_partition: int):
        forward_target = self.__partitions
        return forward_target
class SimplePartition(Partition):
    """Adapter that makes a plain callable usable as a :class:`Partition`.

    The wrapped callable is invoked as ``func(record, num_partition)`` and
    its result is returned unchanged.
    """

    def __init__(self, func):
        self.func = func

    def partition(self, record, num_partition: int):
        partition_func = self.func
        return partition_func(record, num_partition)
def serialize(partition_func):
    """Pickle ``partition_func`` with cloudpickle.

    The returned bytes can be turned back into the object with
    :func:`deserialize`.
    """
    payload = cloudpickle.dumps(partition_func)
    return payload
def deserialize(partition_bytes):
    """Rebuild a partition function from bytes produced by :func:`serialize`.

    Unpickling can execute arbitrary code, so ``partition_bytes`` must come
    from a trusted source.
    """
    partition_func = cloudpickle.loads(partition_bytes)
    return partition_func
def load_partition(descriptor_partition_bytes: bytes):
    """
    Deserialize `descriptor_partition_bytes` to get partition info, then
    get or load the partition function.

    Note that this function must be kept in sync with
    `io.ray.streaming.runtime.python.GraphPbBuilder.serializePartition`

    Args:
        descriptor_partition_bytes: serialized partition info.

    Returns:
        One of the following partition objects:

        * the object unpickled from the embedded partition bytes, when present;
        * a :class:`SimplePartition` wrapping the named module-level function;
        * a new, argument-less instance of the named :class:`Partition`
          subclass.

    Raises:
        ValueError: if the descriptor is empty, or it carries no pickled
            partition and lacks a module or function name.
        TypeError: if the named attribute is neither a plain python function
            nor a subclass of :class:`Partition`.
    """
    # Explicit checks rather than ``assert`` so validation survives -O.
    if not descriptor_partition_bytes:
        raise ValueError("descriptor_partition_bytes must not be empty")
    partition_bytes, module_name, function_name = \
        gateway_client.deserialize(descriptor_partition_bytes)
    if partition_bytes:
        return deserialize(partition_bytes)

    if not module_name:
        raise ValueError("partition descriptor has no module name")
    if not function_name:
        raise ValueError("partition descriptor has no function name")
    mod = importlib.import_module(module_name)
    target = getattr(mod, function_name)

    # A plain python function is the user's partition logic; wrap it.
    if inspect.isfunction(target):
        return SimplePartition(target)
    # Otherwise it must be a Partition subclass, instantiated without args.
    # isclass is checked first because issubclass raises on non-classes.
    if inspect.isclass(target) and issubclass(target, Partition):
        return target()
    raise TypeError(
        "{}.{} must be a python function or a subclass of Partition, "
        "got {!r}".format(module_name, function_name, target))