ray/rllib/execution/replay_ops.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

213 lines
7.8 KiB
Python
Raw Normal View History

from typing import List, Any, Optional
import random
from ray.actor import ActorHandle
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
)
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.utils.typing import SampleBatchType
class StoreToReplayBuffer:
    """Callable that stores data into replay buffer actors.

    If constructed with a local replay buffer, data will be stored into that
    buffer. If constructed with a list of replay actor handles, each batch is
    stored into one of those actors, chosen uniformly at random. Exactly one
    of the two must be provided.

    This should be used with the .for_each() operator on a rollouts iterator.
    The batch that was stored is returned.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> from ray.rllib.execution.replay_ops import StoreToReplayBuffer
        >>> from ray.rllib.execution import ParallelRollouts
        >>> actors = [ # doctest: +SKIP
        ...     multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
        >>> rollouts = ParallelRollouts(...) # doctest: +SKIP
        >>> store_op = rollouts.for_each( # doctest: +SKIP
        ...     StoreToReplayBuffer(actors=actors))
        >>> next(store_op) # doctest: +SKIP
        SampleBatch(...)
    """

    def __init__(
        self,
        *,
        local_buffer: Optional[MultiAgentReplayBuffer] = None,
        actors: Optional[List[ActorHandle]] = None,
    ):
        """
        Args:
            local_buffer: The local replay buffer to store the data into.
            actors: An optional list of replay actors to use instead of
                `local_buffer`.

        Raises:
            ValueError: If both `local_buffer` and `actors` are given, or if
                neither is given (`actors` missing or empty).
        """
        if local_buffer is not None and actors is not None:
            raise ValueError(
                "Either `local_buffer` or `replay_actors` must be given, not both!"
            )
        # Fail fast: without a local buffer, `__call__` would otherwise crash
        # later inside `random.choice` on a None or empty actor list.
        if local_buffer is None and not actors:
            raise ValueError(
                "Either `local_buffer` or a non-empty `actors` list must be given!"
            )
        if local_buffer is not None:
            self.local_actor = local_buffer
            self.replay_actors = None
        else:
            self.local_actor = None
            self.replay_actors = actors

    def __call__(self, batch: SampleBatchType):
        """Store `batch` in the local buffer or in a randomly chosen actor.

        Args:
            batch: The batch to store.

        Returns:
            The same `batch` object, so this can be chained in `.for_each()`.
        """
        if self.local_actor is not None:
            self.local_actor.add(batch)
        else:
            actor = random.choice(self.replay_actors)
            # The result of `.remote()` is discarded; the call is not waited on.
            actor.add.remote(batch)
        return batch
def Replay(
    *,
    local_buffer: Optional[MultiAgentReplayBuffer] = None,
    num_items_to_replay: int = 1,
    actors: Optional[List[ActorHandle]] = None,
    num_async: int = 4,
) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Exactly one of this and `actors`
            must be specified.
        num_items_to_replay: Number of items to sample from the buffer. It is
            passed to `local_buffer.sample()` or to each actor's
            `make_iterator()`.
        actors: List of replay actors. Exactly one of this and `local_buffer`
            must be specified.
        num_async: In async mode, the max number of async requests in flight
            per actor.

    Returns:
        An iterator of replayed items. In actor mode, `None` results are
        filtered out. In local mode, a `_NextValueNotReady` marker is yielded
        whenever `local_buffer.sample()` returns `None`.

    Raises:
        ValueError: If both or neither of `local_buffer` and `actors` are
            given.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> actors = [ # doctest: +SKIP
        ...     multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors, # doctest: +SKIP
        ...     num_items_to_replay=batch_size)
        >>> next(replay_op) # doctest: +SKIP
        SampleBatch(...)
    """
    # Enforce "exactly one": previously, passing neither silently built a
    # local iterator over `None` that only failed once it was iterated.
    if (local_buffer is None) == (actors is None):
        raise ValueError("Exactly one of local_buffer and replay_actors must be given.")

    if actors is not None:
        for actor in actors:
            actor.make_iterator.remote(num_items_to_replay=num_items_to_replay)
        replay = from_actors(actors)
        return replay.gather_async(num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        # The single positional argument is ignored.
        while True:
            item = local_buffer.sample(num_items_to_replay)
            if item is None:
                # Nothing was sampled this round; emit a not-ready marker
                # instead of an item.
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())
class WaitUntilTimestepsElapsed:
    """Predicate that turns True once the sampled-steps counter passes a target.

    The counter is read from the shared metrics under
    ``STEPS_SAMPLED_COUNTER``. The comparison is strict, so a counter exactly
    equal to the target still yields False.
    """

    def __init__(self, target_num_timesteps: int):
        """
        Args:
            target_num_timesteps: Value the sampled-steps counter must exceed
                before this predicate returns True.
        """
        self.target_num_timesteps = target_num_timesteps

    def __call__(self, item: Any) -> bool:
        """Report whether the sampled-steps counter exceeds the target.

        Args:
            item: Ignored.

        Returns:
            True iff the shared ``STEPS_SAMPLED_COUNTER`` value is greater
            than `target_num_timesteps`.
        """
        counters = _get_shared_metrics().counters
        return counters[STEPS_SAMPLED_COUNTER] > self.target_num_timesteps
# TODO(ekl) deprecate this in favor of the replay_sequence_length option.
class SimpleReplayBuffer:
    """Bounded replay buffer whose unit of storage is a whole batch.

    Holds at most `num_slots` batches. Until that capacity is reached, new
    batches are appended. Afterwards, each new batch overwrites the slot at
    `replay_index`, which advances cyclically, so the oldest batch is replaced
    first. `replay()` draws one stored batch uniformly at random.
    """

    def __init__(self, num_slots: int, replay_proportion: Optional[float] = None):
        """Initialize SimpleReplayBuffer.

        Args:
            num_slots: Number of batches to store in total. If it is not > 0,
                `add_batch()` stores nothing.
            replay_proportion: Accepted but not used by this class.
        """
        self.num_slots = num_slots
        self.replay_batches = []
        # Next slot to overwrite once the buffer has filled up.
        self.replay_index = 0

    def add_batch(self, sample_batch: SampleBatchType) -> None:
        """Store `sample_batch`, evicting the oldest batch when full.

        `warn_replay_capacity` is always called first, even when storage is
        disabled (`num_slots` not > 0).
        """
        warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
        if self.num_slots > 0:
            is_full = len(self.replay_batches) >= self.num_slots
            if is_full:
                self.replay_batches[self.replay_index] = sample_batch
                self.replay_index = (self.replay_index + 1) % self.num_slots
            else:
                self.replay_batches.append(sample_batch)

    def replay(self) -> SampleBatchType:
        """Return one stored batch chosen uniformly at random.

        Raises:
            IndexError: If no batch has been stored yet.
        """
        return random.choice(self.replay_batches)

    def __len__(self):
        """Return the number of batches currently stored."""
        return len(self.replay_batches)
class MixInReplay:
    """This operator adds replay to a stream of experiences.

    It takes input batches, and returns a list of batches that include replayed
    data as well. The number of replayed batches is determined by the
    configured replay proportion. The max age of a batch is determined by the
    number of replay slots.
    """

    def __init__(self, num_slots: int, replay_proportion: float):
        """Initialize MixInReplay.

        Args:
            num_slots: Number of batches to store in total. Must be > 0 when
                `replay_proportion` > 0.
            replay_proportion: The input batch will be returned
                and an additional number of batches proportional to this value
                will be added as well. For a non-negative value, the expected
                number of replayed batches per call equals this value.

        Raises:
            ValueError: If `replay_proportion > 0` but `num_slots <= 0`. Such a
                buffer never stores anything, so replaying from it would fail.

        Examples:
            # replay proportion 2:1
            >>> from ray.rllib.execution.replay_ops import MixInReplay
            >>> rollouts = ... # doctest: +SKIP
            >>> replay_op = rollouts.for_each( # doctest: +SKIP
            ...     MixInReplay(100, replay_proportion=2))
            >>> print(next(replay_op)) # doctest: +SKIP
            [SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]
            # replay proportion 0:1, replay disabled
            >>> replay_op = rollouts.for_each( # doctest: +SKIP
            ...     MixInReplay(100, replay_proportion=0))
            >>> print(next(replay_op)) # doctest: +SKIP
            [SampleBatch(<input>)]
        """
        # `<= 0` rather than `== 0`: negative slot counts also disable
        # storage, which would make `replay()` fail on an empty buffer.
        if replay_proportion > 0 and num_slots <= 0:
            raise ValueError("You must set num_slots > 0 if replay_proportion > 0.")
        self.replay_buffer = SimpleReplayBuffer(num_slots)
        self.replay_proportion = replay_proportion

    def __call__(self, sample_batch: SampleBatchType) -> List[SampleBatchType]:
        """Store `sample_batch` and return it along with replayed batches.

        The input batch is added to the replay buffer before replaying, so it
        may itself appear among the replayed batches.

        Args:
            sample_batch: The incoming batch.

        Returns:
            A list whose first element is `sample_batch`, followed by zero or
            more batches drawn from the replay buffer.
        """
        # Put in replay buffer if enabled.
        self.replay_buffer.add_batch(sample_batch)

        # Proportional replay: for a non-negative proportion, floor(f) replays
        # always happen, plus one more with probability equal to the
        # fractional part of f.
        output_batches = [sample_batch]
        f = self.replay_proportion
        while random.random() < f:
            f -= 1
            output_batches.append(self.replay_buffer.replay())
        return output_batches