ray/rllib/execution/concurrency_ops.py

from typing import List, Optional, Any
import queue
from ray.util.iter import LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.utils.typing import SampleBatchType


def Concurrently(
ops: List[LocalIterator],
*,
mode: str = "round_robin",
output_indexes: Optional[List[int]] = None,
round_robin_weights: Optional[List[int]] = None
) -> LocalIterator[SampleBatchType]:
"""Operator that runs the given parent iterators concurrently.
2020-09-20 11:27:02 +02:00
    Args:
        ops (List[LocalIterator]): Parent iterators to run concurrently.
        mode (str): One of 'round_robin', 'async'. In 'round_robin' mode,
            we alternate between pulling items from each parent iterator in
            order deterministically. In 'async' mode, we pull from each parent
            iterator as fast as they are produced. This is non-deterministic.
        output_indexes (list): If specified, only output results from the
            given ops. For example, if ``output_indexes=[0]``, only results
            from the first op in ops will be returned.
        round_robin_weights (list): List of weights to use for round robin
            mode. For example, ``[2, 1]`` will cause the iterator to pull
            twice as many items from the first iterator as from the second.
            ``[2, 1, "*"]`` will cause as many items as possible to be pulled
            from the third iterator without blocking. This is only allowed in
            round robin mode.

    Returns:
        LocalIterator[SampleBatchType]: Iterator over the combined output of
            the given ops.

Examples:
>>> from ray.rllib.execution import ParallelRollouts
>>> sim_op = ParallelRollouts(...).for_each(...) # doctest: +SKIP
>>> replay_op = LocalReplay(...).for_each(...) # doctest: +SKIP
>>> combined_op = Concurrently( # doctest: +SKIP
... [sim_op, replay_op], mode="async")
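        >>> # A hedged sketch: round-robin mode that pulls twice as often from
        >>> # sim_op and only emits replay_op's results (op setup as above).
        >>> weighted_op = Concurrently( # doctest: +SKIP
        ...     [sim_op, replay_op], mode="round_robin",
        ...     round_robin_weights=[2, 1], output_indexes=[1])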
"""
if len(ops) < 2:
raise ValueError("Should specify at least 2 ops.")
if mode == "round_robin":
deterministic = True
elif mode == "async":
deterministic = False
if round_robin_weights:
raise ValueError("round_robin_weights cannot be specified in async mode")
else:
raise ValueError("Unknown mode {}".format(mode))
if round_robin_weights and all(r == "*" for r in round_robin_weights):
raise ValueError("Cannot specify all round robin weights = *")
if output_indexes:
for i in output_indexes:
assert i in range(len(ops)), ("Index out of range", i)
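
    # Tag each item with the index of the op that produced it, so results can
    # be filtered down to `output_indexes` after the union below.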
def tag(op, i):
return op.for_each(lambda x: (i, x))
ops = [tag(op, i) for i, op in enumerate(ops)]
output = ops[0].union(
*ops[1:], deterministic=deterministic, round_robin_weights=round_robin_weights
)
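
    # Keep only items from the ops listed in `output_indexes`, then strip the
    # index tag so callers receive the raw values.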
if output_indexes:
output = output.filter(lambda tup: tup[0] in output_indexes).for_each(
lambda tup: tup[1]
)
    return output


class Enqueue:
"""Enqueue data items into a queue.Queue instance.
Returns the input item as output.
The enqueue is non-blocking, so Enqueue operations can executed with
Dequeue via the Concurrently() operator.
Examples:
>>> import queue
>>> from ray.rllib.execution import ParallelRollouts
>>> queue = queue.Queue(100) # doctest: +SKIP
>>> write_op = ParallelRollouts(...).for_each(Enqueue(queue)) # doctest: +SKIP
>>> read_op = Dequeue(queue) # doctest: +SKIP
>>> combined_op = Concurrently( # doctest: +SKIP
... [write_op, read_op], mode="async")
>>> next(combined_op) # doctest: +SKIP
SampleBatch(...)
"""
def __init__(self, output_queue: queue.Queue):
if not isinstance(output_queue, queue.Queue):
raise ValueError("Expected queue.Queue, got {}".format(type(output_queue)))
        self.queue = output_queue

def __call__(self, x: Any) -> Any:
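        # Non-blocking put: if the queue is still full after a short timeout,
        # report that no value is ready instead of blocking the pipeline.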
try:
self.queue.put(x, timeout=0.001)
except queue.Full:
return _NextValueNotReady()
        return x


def Dequeue(
input_queue: queue.Queue, check=lambda: True
) -> LocalIterator[SampleBatchType]:
"""Dequeue data items from a queue.Queue instance.
The dequeue is non-blocking, so Dequeue operations can execute with
    Enqueue via the Concurrently() operator.

Args:
input_queue (Queue): queue to pull items from.
check (fn): liveness check. When this function returns false,
Dequeue() will raise an error to halt execution.
Examples:
        >>> import queue
        >>> from ray.rllib.execution import ParallelRollouts
        >>> q = queue.Queue(100) # doctest: +SKIP
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q)) # doctest: +SKIP
        >>> read_op = Dequeue(q) # doctest: +SKIP
        >>> combined_op = Concurrently( # doctest: +SKIP
        ...     [write_op, read_op], mode="async")
        >>> next(combined_op) # doctest: +SKIP
        SampleBatch(...)
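        >>> # A hedged sketch of the `check` hook: `stop` is a hypothetical
        >>> # flag; once it is set, check() returns False and Dequeue raises.
        >>> stop = False # doctest: +SKIP
        >>> read_op = Dequeue(q, check=lambda: not stop) # doctest: +SKIP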
"""
if not isinstance(input_queue, queue.Queue):
raise ValueError("Expected queue.Queue, got {}".format(type(input_queue)))
def base_iterator(timeout=None):
while check():
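            # Poll with a short timeout: yield the item if one arrives,
            # otherwise yield _NextValueNotReady() so the caller is not blocked.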
try:
item = input_queue.get(timeout=0.001)
yield item
except queue.Empty:
yield _NextValueNotReady()
raise RuntimeError(
"Dequeue `check()` returned False! "
"Exiting with Exception from Dequeue iterator."
)
return LocalIterator(base_iterator, SharedMetrics())