from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import logging
import pickle
import sys
import time

import ray
import ray.streaming.runtime.transfer as transfer
from ray.streaming.config import Config
from ray.streaming.operator import PStrategy
from ray.streaming.runtime.transfer import ChannelID

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Forward and broadcast stream partitioning strategies
forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast]

# Used to choose output channel in case of hash-based shuffling
def _hash(value):
    if isinstance(value, int):
        return value
    try:
        return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16)
    except AttributeError:
        return int(hashlib.sha1(value).hexdigest(), 16)
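
# Illustration (a sketch, not part of the module API): given a list `channels`
# of output channels that all lead to one downstream operator, hash-based
# shuffling picks a stable channel for a given key, e.g.
#
#     c = channels[_hash("user-42") % len(channels)]
#
# which mirrors how DataOutput.push() selects shuffle channels below.
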
class DataChannel(object):
    """A data channel for actor-to-actor communication.

    Attributes:
        src_operator_id (UUID): The id of the source operator of the channel.
        src_instance_index (int): The index of the source instance.
        dst_operator_id (UUID): The id of the destination operator of the
            channel.
        dst_instance_index (int): The index of the destination instance.
        str_qid (str): The string id of the underlying transfer channel.
        qid (ChannelID): The id of the underlying transfer channel.
    """

    def __init__(self, src_operator_id, src_instance_index, dst_operator_id,
                 dst_instance_index, str_qid):
        self.src_operator_id = src_operator_id
        self.src_instance_index = src_instance_index
        self.dst_operator_id = dst_operator_id
        self.dst_instance_index = dst_instance_index
        self.str_qid = str_qid
        self.qid = ChannelID(str_qid)

    def __repr__(self):
        return "(src({},{}),dst({},{}), qid({}))".format(
            self.src_operator_id, self.src_instance_index,
            self.dst_operator_id, self.dst_instance_index, self.str_qid)
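
# Sketch of how a channel is described (illustration only; the operator ids
# and the queue id "qid-0" are made-up values):
#
#     channel = DataChannel(src_operator_id, 0, dst_operator_id, 1, "qid-0")
#     channel.qid  # ChannelID("qid-0"), used by the native transfer layer
#
# DataInput and DataOutput below use str_qid/qid for reading and writing, and
# the operator/instance ids only to look up the peer actors in the execution
# graph.
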
_CLOSE_FLAG = b" "


# Pulls and merges data from multiple input channels
class DataInput(object):
    """An input gate of an operator instance.

    The input gate pulls records from all input channels in a round-robin
    fashion.

    Attributes:
        input_channels (list): The list of input channels.
        channel_index (int): The index of the next channel to pull from.
        max_index (int): The number of input channels.
        closed (dict): Maps the id of each input channel that has delivered
            _CLOSE_FLAG (i.e. has been marked as 'closed') to True.
    """

    def __init__(self, env, channels):
        assert len(channels) > 0
        self.env = env
        self.reader = None  # created in `init` method
        self.input_channels = channels
        self.channel_index = 0
        self.max_index = len(channels)
        # Tracks the channels that have been closed. qid: close status
        self.closed = {}

    def init(self):
        """Initialize the DataInput, creating the underlying DataReader."""
        channels = [c.str_qid for c in self.input_channels]
        input_actors = []
        for c in self.input_channels:
            actor = self.env.execution_graph.get_actor(c.src_operator_id,
                                                       c.src_instance_index)
            input_actors.append(actor)
        logger.info("DataInput input_actors %s", input_actors)
        conf = {
            Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
            .current_driver_id,
            Config.CHANNEL_TYPE: self.env.config.channel_type
        }
        self.reader = transfer.DataReader(channels, input_actors, conf)

    def pull(self):
        # Poll the reader; read() returns None when no data is currently
        # available.
        item = self.reader.read(100)
        while item is None:
            time.sleep(0.001)
            item = self.reader.read(100)
        msg_data = item.body()
        if msg_data == _CLOSE_FLAG:
            # The upstream instance behind this channel is done.
            self.closed[item.channel_id] = True
            if len(self.closed) == len(self.input_channels):
                return None
            else:
                return self.pull()
        else:
            return pickle.loads(msg_data)

    def close(self):
        self.reader.stop()
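
# Consumer-side sketch (illustration only; `env` and `channels` are assumed to
# come from the surrounding execution graph, and `process` is a hypothetical
# user function):
#
#     input_gate = DataInput(env, channels)
#     input_gate.init()
#     while True:
#         record = input_gate.pull()
#         if record is None:  # every upstream channel sent _CLOSE_FLAG
#             break
#         process(record)
#     input_gate.close()
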
# Selects output channel(s) and pushes data
class DataOutput(object):
    """An output gate of an operator instance.

    The output gate pushes records to output channels according to the
    user-defined partitioning scheme.

    Attributes:
        partitioning_schemes (dict): A mapping from destination operator ids
            to partitioning schemes (see: PScheme in operator.py).
        forward_channels (list): A list of channels to forward records.
        round_robin_channels (list(list)): A list of output channels used for
            round-robin dispatch, grouped by destination operator.
        shuffle_channels (list(list)): A list of output channels to shuffle
            records grouped by destination operator.
        shuffle_key_channels (list(list)): A list of output channels to
            shuffle records by a key grouped by destination operator.
        shuffle_exists (bool): A flag indicating that there exists at least
            one shuffle_channel.
        shuffle_key_exists (bool): A flag indicating that there exists at
            least one shuffle_key_channel.
    """

    def __init__(self, env, channels, partitioning_schemes):
        assert len(channels) > 0
        self.env = env
        self.writer = None  # created in `init` method
        self.channels = channels
        self.key_selector = None
        self.partitioning_schemes = partitioning_schemes
        # Prepare output -- collect channels by type
        self.forward_channels = []  # Forward and broadcast channels
        slots = sum(1 for scheme in self.partitioning_schemes.values()
                    if scheme.strategy == PStrategy.RoundRobin)
        # RoundRobin channels (one independent list per destination)
        self.round_robin_channels = [[] for _ in range(slots)]
        self.round_robin_indexes = [-1] * slots
        slots = sum(1 for scheme in self.partitioning_schemes.values()
                    if scheme.strategy == PStrategy.Shuffle)
        # Flag used to avoid hashing when there is no shuffling
        self.shuffle_exists = slots > 0
        # Shuffle channels (one independent list per destination)
        self.shuffle_channels = [[] for _ in range(slots)]
        slots = sum(1 for scheme in self.partitioning_schemes.values()
                    if scheme.strategy == PStrategy.ShuffleByKey)
        # Flag used to avoid hashing when there is no shuffling by key
        self.shuffle_key_exists = slots > 0
        # Shuffle by key channels (one independent list per destination)
        self.shuffle_key_channels = [[] for _ in range(slots)]
        # Distinct shuffle destinations
        shuffle_destinations = {}
        # Distinct shuffle by key destinations
        shuffle_by_key_destinations = {}
        # Distinct round robin destinations
        round_robin_destinations = {}
        index_1 = 0
        index_2 = 0
        index_3 = 0
        for channel in channels:
            p_scheme = self.partitioning_schemes[channel.dst_operator_id]
            strategy = p_scheme.strategy
            if strategy in forward_broadcast_strategies:
                self.forward_channels.append(channel)
            elif strategy == PStrategy.Shuffle:
                pos = shuffle_destinations.setdefault(channel.dst_operator_id,
                                                      index_1)
                self.shuffle_channels[pos].append(channel)
                if pos == index_1:
                    index_1 += 1
            elif strategy == PStrategy.ShuffleByKey:
                pos = shuffle_by_key_destinations.setdefault(
                    channel.dst_operator_id, index_2)
                self.shuffle_key_channels[pos].append(channel)
                if pos == index_2:
                    index_2 += 1
            elif strategy == PStrategy.RoundRobin:
                pos = round_robin_destinations.setdefault(
                    channel.dst_operator_id, index_3)
                self.round_robin_channels[pos].append(channel)
                if pos == index_3:
                    index_3 += 1
            else:  # TODO (john): Add support for other strategies
                sys.exit("Unrecognized or unsupported partitioning strategy.")
        # A KeyedDataStream can only be shuffled by key
        assert not (self.shuffle_exists and self.shuffle_key_exists)
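
    # Layout sketch (illustration only): if two downstream operators B and C
    # are both shuffle destinations with two instances each, __init__ above
    # groups the four outgoing channels as
    #
    #     shuffle_channels = [[A->B0, A->B1], [A->C0, A->C1]]
    #
    # so push() below hashes a record once and picks one channel per
    # destination operator.
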
    def init(self):
        """Initialize the DataOutput, creating the underlying DataWriter."""
        channel_ids = [c.str_qid for c in self.channels]
        to_actors = []
        for c in self.channels:
            actor = self.env.execution_graph.get_actor(c.dst_operator_id,
                                                       c.dst_instance_index)
            to_actors.append(actor)
        logger.info("DataOutput output_actors %s", to_actors)

        conf = {
            Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context()
            .current_driver_id,
            Config.CHANNEL_TYPE: self.env.config.channel_type
        }
        self.writer = transfer.DataWriter(channel_ids, to_actors, conf)

    def close(self):
        """Close all output channels by propagating _CLOSE_FLAG.

        _CLOSE_FLAG is a special record that is propagated from sources to
        sinks to signal the end of data in a stream. DataInput.pull() above
        uses it to detect that an upstream channel has finished.
        """
        for c in self.channels:
            self.writer.write(c.qid, _CLOSE_FLAG)
        # Make sure the DataWriter delivers the close flag to the peer actors
        # before stopping.
        self.writer.stop()
    def push(self, record):
        target_channels = []
        # Forward and broadcast channels receive every record
        for c in self.forward_channels:
            logger.debug("[writer] Push record '{}' to channel {}".format(
                record, c))
            target_channels.append(c)
        # Round-robin dispatch: one channel per destination operator, cycling
        # through its instances
        index = 0
        for channels in self.round_robin_channels:
            self.round_robin_indexes[index] += 1
            if self.round_robin_indexes[index] == len(channels):
                self.round_robin_indexes[index] = 0  # Reset index
            c = channels[self.round_robin_indexes[index]]
            logger.debug("[writer] Push record '{}' to channel {}".format(
                record, c))
            target_channels.append(c)
            index += 1
        # Hash-based shuffling by key
        if self.shuffle_key_exists:
            key, _ = record
            h = _hash(key)
            for channels in self.shuffle_key_channels:
                num_instances = len(channels)  # Downstream instances
                c = channels[h % num_instances]
                logger.debug(
                    "[key_shuffle] Push record '{}' to channel {}".format(
                        record, c))
                target_channels.append(c)
        elif self.shuffle_exists:  # Hash-based shuffling per destination
            h = _hash(record)
            for channels in self.shuffle_channels:
                num_instances = len(channels)  # Downstream instances
                c = channels[h % num_instances]
                logger.debug("[shuffle] Push record '{}' to channel {}".format(
                    record, c))
                target_channels.append(c)
        else:  # TODO (john): Handle rescaling
            pass

        msg_data = pickle.dumps(record)
        for c in target_channels:
            # Send the serialized record to each selected channel
            self.writer.write(c.qid, msg_data)

    def push_all(self, records):
        for record in records:
            self.push(record)
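

# Producer-side sketch (illustration only; `env`, `channels`, `schemes`, and
# `more_records` are assumed to come from the surrounding program):
#
#     output_gate = DataOutput(env, channels, schemes)
#     output_gate.init()                  # creates the underlying DataWriter
#     output_gate.push(("key", "value"))  # routed by the partitioning scheme
#     output_gate.push_all(more_records)  # any iterable of records
#     output_gate.close()                 # propagates _CLOSE_FLAG downstream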