ray/rllib/execution/buffers/mixin_replay_buffer.py

import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
"""This buffer adds replayed samples to a stream of new experiences.
- Any newly added batch (`add_batch()`) is immediately returned upon
the next `replay` call (close to on-policy) as well as being moved
into the buffer.
- Additionally, a certain number of old samples is mixed into the
returned sample according to a given "replay ratio".
- If >1 calls to `add_batch()` are made without any `replay()` calls
in between, all newly added batches are returned (plus some older samples
according to the "replay ratio").
Examples:
        >>> from ray.rllib.execution.buffers.mixin_replay_buffer import (
        ...     MixInMultiAgentReplayBuffer)  # doctest: +SKIP
        >>> # Replay ratio 0.66 (2/3 replayed, 1/3 new samples):
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,  # doctest: +SKIP
        ...     replay_ratio=0.66)  # doctest: +SKIP
>>> A, B, C, D = ... # doctest: +SKIP
>>> buffer.add_batch(A) # doctest: +SKIP
>>> buffer.add_batch(B) # doctest: +SKIP
>>> buffer.replay() # doctest: +SKIP
[A, B, B]
>>> buffer.add_batch(C) # doctest: +SKIP
>>> buffer.replay() # doctest: +SKIP
[C, A, B]
>>> # or: [C, A, A] or [C, B, B], but always C as it
>>> # is the newest sample
>>> buffer.add_batch(D) # doctest: +SKIP
>>> buffer.replay() # doctest: +SKIP
[D, A, C]
        >>> # Replay ratio 0.0 -> replay disabled:
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,  # doctest: +SKIP
        ...     replay_ratio=0.0)  # doctest: +SKIP
>>> buffer.add_batch(A) # doctest: +SKIP
>>> buffer.replay() # doctest: +SKIP
[A]
>>> buffer.add_batch(B) # doctest: +SKIP
>>> buffer.replay() # doctest: +SKIP
[B]
"""

    def __init__(self, capacity: int, replay_ratio: float):
        """Initializes a MixInMultiAgentReplayBuffer instance.

        Args:
            capacity: Number of batches to store in total.
            replay_ratio: Ratio of replayed samples in the returned
                batches. E.g. a ratio of 0.0 means only return new samples
                (no replay), a ratio of 0.5 means return the newest sample
                plus one old (replayed) one (1:1), and a ratio of 0.66
                means return the newest sample plus two old ones (1:2).
        """
self.capacity = capacity
self.replay_ratio = replay_ratio
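        # Convert the replay ratio (old / [old + new]) into a replay
        # proportion (old / new): e.g. replay_ratio=0.66 gives a proportion
        # of ~2.0, i.e. two replayed batches per new batch. The conversion
        # is undefined for replay_ratio=1.0 (replay only), which `replay()`
        # handles as a special case.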
self.replay_proportion = None
if self.replay_ratio != 1.0:
self.replay_proportion = self.replay_ratio / (1.0 - self.replay_ratio)

        def new_buffer():
            return SimpleReplayBuffer(num_slots=capacity)

        self.replay_buffers = collections.defaultdict(new_buffer)

# Metrics.
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()
# Added timesteps over lifetime.
self.num_added = 0
# Last added batch(es).
self.last_added_batches = collections.defaultdict(list)

    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

Args:
batch: The batch to be added.
"""
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
batch = batch.as_multi_agent()
with self.add_batch_timer:
for policy_id, sample_batch in batch.policy_batches.items():
self.replay_buffers[policy_id].add_batch(sample_batch)
self.last_added_batches[policy_id].append(sample_batch)
self.num_added += batch.count

    def replay(
        self, policy_id: PolicyID = DEFAULT_POLICY_ID
    ) -> Optional[SampleBatchType]:
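        """Returns a mix of new and replayed samples for the given policy.

        Args:
            policy_id: The ID of the policy whose buffer to sample from.

        Returns:
            A single batch, concatenated from all batches newly added since
            the last `replay()` call plus a number of older (replayed)
            batches according to `replay_ratio`, or None if no samples are
            available yet.
        """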
buffer = self.replay_buffers[policy_id]
# Return None, if:
# - Buffer empty or
# - `replay_ratio` < 1.0 (new samples required in returned batch)
# and no new samples to mix with replayed ones.
if len(buffer) == 0 or (
len(self.last_added_batches[policy_id]) == 0 and self.replay_ratio < 1.0
):
return None
# Mix buffer's last added batches with older replayed batches.
with self.replay_timer:
output_batches = self.last_added_batches[policy_id]
self.last_added_batches[policy_id] = []
# No replay desired -> Return here.
if self.replay_ratio == 0.0:
return SampleBatch.concat_samples(output_batches)
# Only replay desired -> Return a (replayed) sample from the
# buffer.
elif self.replay_ratio == 1.0:
return buffer.replay()
# Replay ratio = old / [old + new]
# Replay proportion: old / new
num_new = len(output_batches)
replay_proportion = self.replay_proportion
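            # Mix in replayed batches via stochastic rounding: each loop
            # iteration draws one old batch and consumes 1.0 of the
            # remaining proportion (for a single new batch, the common
            # case, this yields `replay_proportion` old batches in
            # expectation).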
while random.random() < num_new * replay_proportion:
replay_proportion -= 1
output_batches.append(buffer.replay())
return SampleBatch.concat_samples(output_batches)

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
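

# A minimal usage sketch (illustrative, not part of the original module),
# assuming `SampleBatch` accepts a dict of equal-length columns and that
# `replay()` returns a single concatenated batch, as implemented above.
if __name__ == "__main__":
    buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.5)
    # Add a dummy one-timestep batch, then replay: the newly added batch
    # is always returned, mixed with replayed ones (ratio 0.5 -> ~1:1).
    buffer.add_batch(SampleBatch({"obs": [1], "rewards": [0.1]}))
    print(buffer.replay())
    buffer.add_batch(SampleBatch({"obs": [2], "rewards": [0.2]}))
    print(buffer.replay())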