# ray/rllib/utils/replay_buffers/utils.py

import logging
import psutil
from typing import Optional, Any
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers import (
    MultiAgentPrioritizedReplayBuffer,
    ReplayBuffer,
    MultiAgentReplayBuffer,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.typing import ResultDict, SampleBatchType, TrainerConfigDict
from ray.util import log_once

logger = logging.getLogger(__name__)


def update_priorities_in_replay_buffer(
    replay_buffer: ReplayBuffer,
    config: TrainerConfigDict,
    train_batch: SampleBatchType,
    train_results: ResultDict,
) -> None:
"""Updates the priorities in a prioritized replay buffer, given training results.
The `abs(TD-error)` from the loss (inside `train_results`) is used as new
priorities for the row-indices that were sampled for the train batch.
Don't do anything if the given buffer does not support prioritized replay.
Args:
replay_buffer: The replay buffer, whose priority values to update. This may also
be a buffer that does not support priorities.
config: The Trainer's config dict.
train_batch: The batch used for the training update.
train_results: A train results dict, generated by e.g. the `train_one_step()`
utility.
"""
# Only update priorities if buffer supports them.
if isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
# Go through training results for the different policies (maybe multi-agent).
prio_dict = {}
for policy_id, info in train_results.items():
# TODO(sven): This is currently structured differently for
# torch/tf. Clean up these results/info dicts across
# policies (note: fixing this in torch_policy.py will
# break e.g. DDPPO!).
td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
# Set the get_interceptor to None in order to be able to access the numpy
# arrays directly (instead of e.g. a torch array).
train_batch.policy_batches[policy_id].set_get_interceptor(None)
# Get the replay buffer row indices that make up the `train_batch`.
batch_indices = train_batch.policy_batches[policy_id].get("batch_indexes")
if td_error is None:
if log_once(
"no_td_error_in_train_results_from_policy_{}".format(policy_id)
):
logger.warning(
"Trying to update priorities for policy with id `{}` in "
"prioritized replay buffer without providing td_errors in "
"train_results. Priority update for this policy is being "
"skipped.".format(policy_id)
)
continue
if batch_indices is None:
if log_once(
"no_batch_indices_in_train_result_for_policy_{}".format(policy_id)
):
logger.warning(
"Trying to update priorities for policy with id `{}` in "
"prioritized replay buffer without providing batch_indices in "
"train_batch. Priority update for this policy is being "
"skipped.".format(policy_id)
)
continue
# Try to transform batch_indices to td_error dimensions
if len(batch_indices) != len(td_error):
T = replay_buffer.replay_sequence_length
assert (
len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
)
batch_indices = batch_indices.reshape([-1, T])[:, 0]
assert len(batch_indices) == len(td_error)
prio_dict[policy_id] = (batch_indices, td_error)
# Make the actual buffer API call to update the priority weights on all
# policies.
replay_buffer.update_priorities(prio_dict)
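

# Illustrative usage sketch. The `_example_*` function and the hand-built
# `train_results` dict are assumptions for demonstration, not part of RLlib's API:
# after a gradient update, the per-policy `td_error` and the `batch_indexes` column
# of the sampled train batch are all this utility needs to refresh priorities.
def _example_update_priorities_usage():  # pragma: no cover - illustration only
    import numpy as np

    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

    buffer = MultiAgentPrioritizedReplayBuffer(capacity=1000)
    buffer.add(
        SampleBatch(
            {
                SampleBatch.OBS: np.zeros((4, 2), dtype=np.float32),
                SampleBatch.ACTIONS: np.zeros(4, dtype=np.int64),
                SampleBatch.REWARDS: np.zeros(4, dtype=np.float32),
            }
        )
    )
    # Sampling from a prioritized buffer attaches a `batch_indexes` column.
    train_batch = buffer.sample(num_items=2)
    # In a real Trainer, `train_results` comes from e.g. `train_one_step()`; here
    # we fake the per-policy info dict with a td_error of matching length.
    n = train_batch.policy_batches[DEFAULT_POLICY_ID].count
    train_results = {
        DEFAULT_POLICY_ID: {"td_error": np.ones(n), LEARNER_STATS_KEY: {}}
    }
    # `config` is not used by the current implementation, so an empty dict suffices.
    update_priorities_in_replay_buffer(buffer, {}, train_batch, train_results)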


def sample_min_n_steps_from_buffer(
    replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
) -> Optional[SampleBatchType]:
"""Samples a minimum of n timesteps from a given replay buffer.
This utility method is primarily used by the QMIX algorithm and helps with
sampling a given number of time steps which has stored samples in units
of sequences or complete episodes. Samples n batches from replay buffer
until the total number of timesteps reaches `train_batch_size`.
Args:
replay_buffer: The replay buffer to sample from
num_timesteps: The number of timesteps to sample
count_by_agent_steps: Whether to count agent steps or env steps
Returns:
A concatenated SampleBatch or MultiAgentBatch with samples from the
buffer.
"""
    train_batch_size = 0
    train_batches = []
    while train_batch_size < min_steps:
        batch = replay_buffer.sample(num_items=1)
        if batch is None:
            return None
        train_batches.append(batch)
        train_batch_size += (
            train_batches[-1].agent_steps()
            if count_by_agent_steps
            else train_batches[-1].env_steps()
        )
    # All batches are of the same type, hence we can use any concat_samples().
    train_batch = SampleBatch.concat_samples(train_batches)
    return train_batch
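

# Illustrative sketch (the `_example_*` wrapper is hypothetical): request at least
# 32 env steps from a buffer that stores sequences or whole episodes. Each
# `sample(num_items=1)` call may return more than one timestep, which is why the
# helper above loops until the minimum is reached.
def _example_sample_min_n_steps(
    replay_buffer: ReplayBuffer,
) -> Optional[SampleBatchType]:  # pragma: no cover - illustration only
    train_batch = sample_min_n_steps_from_buffer(
        replay_buffer,
        min_steps=32,
        count_by_agent_steps=False,  # Count env steps instead of agent steps.
    )
    # `None` means the buffer could not provide a sample (e.g. it is still empty).
    return train_batch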


@ExperimentalAPI
def validate_buffer_config(config: dict):
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}
if config.get("worker_side_prioritization", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config['worker_side_prioritization']",
new="config['replay_buffer_config']['worker_side_prioritization']",
error=True,
)
prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
if prioritized_replay != DEPRECATED_VALUE:
deprecation_warning(
old="config['prioritized_replay'] or config['replay_buffer_config']["
"'prioritized_replay']",
help="Replay prioritization specified by config key. RLlib's new replay "
"buffer API requires setting `config["
"'replay_buffer_config']['type']`, e.g. `config["
"'replay_buffer_config']['type'] = "
"'MultiAgentPrioritizedReplayBuffer'` to change the default "
"behaviour.",
error=True,
)
capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity == DEPRECATED_VALUE:
capacity = config["replay_buffer_config"].get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
deprecation_warning(
old="config['buffer_size'] or config['replay_buffer_config']["
"'buffer_size']",
new="config['replay_buffer_config']['capacity']",
error=True,
)
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
if replay_burn_in != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
deprecation_warning(
old="config['burn_in']",
help="config['replay_buffer_config']['replay_burn_in']",
)
    # Deprecation of old-style replay buffer args.
    # Warn before checking if we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    keys_with_deprecated_positions = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "no_local_replay_buffer",
        "replay_batch_size",
        "replay_zero_init_states",
        "learning_starts",
        "replay_buffer_shards_colocated_with_driver",
    ]
    for k in keys_with_deprecated_positions:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['{}']".format(k),
                help="config['replay_buffer_config']['{}']".format(k),
                error=False,
            )
            # Copy values over to the new location in the config to support
            # both the new and the old configuration style.
            if config.get("replay_buffer_config") is not None:
                config["replay_buffer_config"][k] = config[k]
replay_mode = config["multiagent"].get("replay_mode", DEPRECATED_VALUE)
if replay_mode != DEPRECATED_VALUE:
deprecation_warning(
old="config['multiagent']['replay_mode']",
help="config['replay_buffer_config']['replay_mode']",
error=False,
)
config["replay_buffer_config"]["replay_mode"] = replay_mode
# Can't use DEPRECATED_VALUE here because this is also a deliberate
# value set for some algorithms
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
replay_sequence_length = config.get("replay_sequence_length", None)
if replay_sequence_length is not None:
config["replay_buffer_config"][
"replay_sequence_length"
] = replay_sequence_length
deprecation_warning(
old="config['replay_sequence_length']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_sequence_length'] will be overwritten.",
error=False,
)
    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated.
    buffer_type = config["replay_buffer_config"]["type"]
    if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
        # Create valid full [module].[class] string for from_config.
        config["replay_buffer_config"]["type"] = (
            "ray.rllib.utils.replay_buffers." + buffer_type
        )

    if config["replay_buffer_config"].get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided.
        logger.info(
            "No value for key `replay_batch_size` in replay_buffer_config. "
            "config['replay_buffer_config']['replay_batch_size'] will be "
            "automatically set to config['train_batch_size']"
        )
        config["replay_buffer_config"]["replay_batch_size"] = config[
            "train_batch_size"
        ]
    # Instantiate a dummy buffer to fail early on misconfiguration and find out
    # about the inferred buffer class.
    dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])
    config["replay_buffer_config"]["type"] = type(dummy_buffer)

    if hasattr(dummy_buffer, "update_priorities"):
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif config["replay_buffer_config"].get("replay_sequence_length", 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    else:
        if config["replay_buffer_config"].get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False."
            )
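

# Illustrative sketch (the `_example_*` function and the minimal config dict are
# assumptions; a real Trainer config carries many more keys): `validate_buffer_config`
# mutates the config in place. A short buffer type name is expanded to a full
# `[module].[class]` path (and, after the dummy instantiation above, replaced by the
# class itself), and `replay_batch_size` falls back to `train_batch_size` if unset.
def _example_validate_buffer_config():  # pragma: no cover - illustration only
    config = {
        "replay_buffer_config": {"type": "MultiAgentReplayBuffer", "capacity": 1000},
        "train_batch_size": 32,
        "multiagent": {},
    }
    validate_buffer_config(config)
    # The type string has been resolved to the actual buffer class ...
    assert config["replay_buffer_config"]["type"] is MultiAgentReplayBuffer
    # ... and the replay batch size has been back-filled from train_batch_size.
    assert config["replay_buffer_config"]["replay_batch_size"] == 32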


def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
    """Warn if the configured replay buffer capacity is too large for machine's memory.

    Args:
        item: An example item that's supposed to be added to the buffer.
            This is used to compute the overall memory footprint estimate for the
            buffer.
        capacity: The capacity value of the buffer. This is interpreted as the
            number of items (such as the given `item`) that will eventually be
            stored in the buffer.

    Raises:
        ValueError: If the computed memory footprint for the buffer exceeds the
            machine's RAM.
    """
if log_once("warn_replay_buffer_capacity"):
item_size = item.size_bytes()
psutil_mem = psutil.virtual_memory()
total_gb = psutil_mem.total / 1e9
mem_size = capacity * item_size / 1e9
msg = (
"Estimated max memory usage for replay buffer is {} GB "
"({} batches of size {}, {} bytes each), "
"available system memory is {} GB".format(
mem_size, capacity, item.count, item_size, total_gb
)
)
if mem_size > total_gb:
raise ValueError(msg)
elif mem_size > 0.2 * total_gb:
logger.warning(msg)
else:
logger.info(msg)
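

# Illustrative sketch (hypothetical numbers): the footprint estimate is a simple
# product, `capacity * item.size_bytes() / 1e9` GB, compared against total system
# memory. A batch of roughly 3.6 MB stored 100 times stays around 0.36 GB, which on
# typical machines only logs an info message.
def _example_warn_replay_buffer_capacity():  # pragma: no cover - illustration only
    import numpy as np

    item = SampleBatch({SampleBatch.OBS: np.zeros((32, 84, 84, 4), dtype=np.float32)})
    warn_replay_buffer_capacity(item=item, capacity=100)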


def patch_buffer_with_fake_sampling_method(
    buffer: ReplayBuffer, fake_sample_output: SampleBatchType
) -> None:
"""Patch a ReplayBuffer such that we always sample fake_sample_output.
Transforms fake_sample_output into a MultiAgentBatch if it is not a
MultiAgentBatch and the buffer is a MultiAgentBuffer. This is useful for testing
purposes if we need deterministic sampling.
Args:
buffer: The buffer to be patched
fake_sample_output: The output to be sampled
"""
if isinstance(buffer, MultiAgentReplayBuffer) and not isinstance(
fake_sample_output, MultiAgentBatch
):
fake_sample_output = SampleBatch(fake_sample_output).as_multi_agent()
def fake_sample(_: Any, __: Any = None, **kwargs) -> Optional[SampleBatchType]:
"""Always returns a predefined batch.
Args:
_: dummy arg to match signature of sample() method
__: dummy arg to match signature of sample() method
**kwargs: dummy args to match signature of sample() method
Returns:
Predefined MultiAgentBatch fake_sample_output
"""
return fake_sample_output
buffer.sample = fake_sample
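

# Illustrative test sketch (the `_example_*` function is hypothetical): patching a
# multi-agent buffer makes every `sample()` call return the same, predefined batch,
# wrapped as a MultiAgentBatch under the default policy id.
def _example_patch_buffer_for_testing():  # pragma: no cover - illustration only
    import numpy as np

    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

    buffer = MultiAgentReplayBuffer(capacity=100)
    fixed_batch = SampleBatch({SampleBatch.OBS: np.arange(4)})
    patch_buffer_with_fake_sampling_method(buffer, fixed_batch)
    # Regardless of the (arbitrary) num_items argument, the fixed batch comes back.
    sampled = buffer.sample(123)
    obs = sampled.policy_batches[DEFAULT_POLICY_ID][SampleBatch.OBS]
    assert list(obs) == [0, 1, 2, 3]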