ray/rllib/utils/replay_buffers/utils.py

363 lines
15 KiB
Python

import logging
import psutil
from typing import Optional
from ray.rllib.execution import MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers import (
MultiAgentPrioritizedReplayBuffer,
ReplayBuffer,
)
from ray.rllib.utils.typing import ResultDict, SampleBatchType, TrainerConfigDict
from ray.util import log_once
# Module-level logger (e.g. used by `warn_replay_buffer_capacity` to report the
# estimated replay buffer memory footprint).
logger = logging.getLogger(__name__)
def update_priorities_in_replay_buffer(
    replay_buffer: ReplayBuffer,
    config: TrainerConfigDict,
    train_batch: SampleBatchType,
    train_results: ResultDict,
) -> None:
    """Updates the priorities in a prioritized replay buffer, given training results.

    The TD-errors from the loss (inside `train_results`) are handed to the
    buffer's `update_priorities()` as the new priorities for the row-indices
    that were sampled for the train batch.

    Don't do anything if the given buffer does not support prioritized replay.
    Policies whose results contain no TD-error, or whose train batch carries
    no "batch_indexes", are skipped (with a one-time warning per policy)
    instead of failing the whole update.

    Args:
        replay_buffer: The replay buffer, whose priority values to update. This
            may also be a buffer that does not support priorities.
        config: The Trainer's config dict.
        train_batch: The batch used for the training update.
        train_results: A train results dict, generated by e.g. the
            `train_one_step()` utility.
    """
    # Only update priorities if buffer supports them.
    supports_priorities = (
        type(replay_buffer) is LegacyMultiAgentReplayBuffer
        and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0) > 0.0
    ) or isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer)
    if not supports_priorities:
        return

    # Go through training results for the different policies (maybe multi-agent).
    prio_dict = {}
    for policy_id, info in train_results.items():
        # TODO(sven): This is currently structured differently for
        #  torch/tf. Clean up these results/info dicts across
        #  policies (note: fixing this in torch_policy.py will
        #  break e.g. DDPPO!).
        # Prefer a top-level "td_error" and only then look into the learner
        # stats. The fallback must not be passed as `dict.get()`'s default:
        # that would evaluate `info[LEARNER_STATS_KEY]` eagerly and raise a
        # KeyError even when "td_error" is present at the top level.
        td_error = info.get("td_error")
        if td_error is None:
            td_error = info.get(LEARNER_STATS_KEY, {}).get("td_error")

        policy_batch = train_batch.policy_batches[policy_id]
        # Set the get_interceptor to None in order to be able to access the
        # numpy arrays directly (instead of e.g. a torch array).
        policy_batch.set_get_interceptor(None)
        # Get the replay buffer row indices that make up the `train_batch`.
        batch_indices = policy_batch.get("batch_indexes")

        if td_error is None:
            if log_once(
                "no_td_error_in_train_results_from_policy_{}".format(policy_id)
            ):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing td_errors in "
                    "train_results. Priority update for this policy is being "
                    "skipped.".format(policy_id)
                )
            continue

        if batch_indices is None:
            if log_once(
                "no_batch_indices_in_train_batch_for_policy_{}".format(policy_id)
            ):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing batch_indexes "
                    "in the train batch. Priority update for this policy is "
                    "being skipped.".format(policy_id)
                )
            continue

        # In case the buffer stores sequences, TD-error could
        # already be calculated per sequence chunk.
        if len(batch_indices) != len(td_error):
            seq_len = replay_buffer.replay_sequence_length
            assert (
                len(batch_indices) > len(td_error)
                and len(batch_indices) % seq_len == 0
            )
            # Keep only the first row index of each sequence chunk.
            batch_indices = batch_indices.reshape([-1, seq_len])[:, 0]
            assert len(batch_indices) == len(td_error)
        prio_dict[policy_id] = (batch_indices, td_error)

    # Make the actual buffer API call to update the priority weights on all
    # policies.
    replay_buffer.update_priorities(prio_dict)
def sample_min_n_steps_from_buffer(
    replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
) -> Optional[SampleBatchType]:
    """Samples a minimum of n timesteps from a given replay buffer.

    This utility method is primarily used by the QMIX algorithm and helps with
    sampling a given number of time steps which has stored samples in units
    of sequences or complete episodes. Samples batches (one item at a time)
    from the replay buffer until the total number of timesteps reaches
    `min_steps`.

    Args:
        replay_buffer: The replay buffer to sample from.
        min_steps: The minimum number of timesteps to sample.
        count_by_agent_steps: Whether to count agent steps (True) or env
            steps (False).

    Returns:
        A concatenated SampleBatch or MultiAgentBatch with samples from the
        buffer, or None if the buffer returned None, or returned a batch
        without any timesteps (in which case `min_steps` could never be
        reached).
    """
    train_batches = []
    num_steps = 0
    while num_steps < min_steps:
        batch = replay_buffer.sample(num_items=1)
        if batch is None:
            return None
        batch_steps = (
            batch.agent_steps() if count_by_agent_steps else batch.env_steps()
        )
        if batch_steps <= 0:
            # A sample without timesteps can never bring us closer to
            # `min_steps`; bail out instead of looping forever.
            return None
        train_batches.append(batch)
        num_steps += batch_steps
    # All batch types are the same type, hence we can use any concat_samples()
    return SampleBatch.concat_samples(train_batches)
@ExperimentalAPI
def validate_buffer_config(config: dict):
    """Validates and normalizes the replay buffer settings of a Trainer config.

    Moves deprecated top-level replay settings (e.g. `prioritized_replay`,
    `buffer_size`, `learning_starts`) into `config["replay_buffer_config"]`
    while emitting deprecation warnings. It also resolves the buffer "type"
    into a full `[module].[class]` string. It checks that the prioritization
    settings are compatible with the rest of the config. Finally, it falls
    back to `train_batch_size` if no `replay_batch_size` is given.

    `config` is modified in place. If `no_local_replay_buffer` is set, the
    function returns right after the deprecation handling.

    Args:
        config: The Trainer's config dict.

    Raises:
        AssertionError: If `replay_buffer_config` has no "type" key, or if a
            buffer other than MultiAgentReplayBuffer is requested without the
            ReplayBuffer API.
        ValueError: If prioritized replay is combined with
            `replay_mode="lockstep"` or `replay_sequence_length > 1`, or if
            `worker_side_prioritization` is set for a non-prioritized buffer.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}

    # Default to DEPRECATED_VALUE like all other deprecated keys below. With a
    # `None` default, a config without this key would trigger a spurious
    # deprecation warning and clobber any value already set at
    # config["replay_buffer_config"]["prioritized_replay"] with `None`.
    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay']",
            help="Replay prioritization specified at new location config["
            "'replay_buffer_config']["
            "'prioritized_replay'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size']",
            help="Buffer size specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"]["capacity"] = capacity

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    deprecated_replay_buffer_keys = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "learning_starts",
    ]
    for k in deprecated_replay_buffer_keys:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config[{}]".format(k),
                help="config['replay_buffer_config'][{}] should be used "
                "for Q-Learning algorithms. Ignore this warning if "
                "you are not using a Q-Learning algorithm and still "
                "provide {}."
                "".format(k, k),
                error=False,
            )
            # Copy values over to new location in config to support new and
            # old configuration style. (`replay_buffer_config` is always a
            # dict here; it is created at the top of this function if needed.)
            config["replay_buffer_config"][k] = config[k]

    # Old Ape-X configs may contain no_local_replay_buffer.
    no_local_replay_buffer = config.get("no_local_replay_buffer", False)
    if no_local_replay_buffer:
        deprecation_warning(
            old="config['no_local_replay_buffer']",
            help="no_local_replay_buffer specified at new location config["
            "'replay_buffer_config']["
            "'no_local_replay_buffer'] will be overwritten.",
            error=False,
        )
        config["replay_buffer_config"][
            "no_local_replay_buffer"
        ] = no_local_replay_buffer

    # TODO (Artur):
    # Without a local buffer, nothing beyond the deprecation handling is done.
    if config["replay_buffer_config"].get("no_local_replay_buffer", False):
        return

    replay_buffer_config = config["replay_buffer_config"]
    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated.
    buffer_type = replay_buffer_config["type"]
    if not replay_buffer_config.get("_enable_replay_buffer_api", False):
        if isinstance(buffer_type, str) and "." not in buffer_type:
            # Prepend old-style buffers' path.
            assert (
                buffer_type == "MultiAgentReplayBuffer"
            ), "Without ReplayBuffer API, only MultiAgentReplayBuffer is supported!"
            # Create valid full [module].[class] string for from_config.
            buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
        else:
            assert buffer_type in [
                "ray.rllib.execution.MultiAgentReplayBuffer",
                Legacy_MultiAgentReplayBuffer,
            ], "Without ReplayBuffer API, only MultiAgentReplayBuffer is supported!"

        replay_buffer_config["type"] = buffer_type

        # Remove from config, so it's not passed into the buffer c'tor.
        replay_buffer_config.pop("_enable_replay_buffer_api", None)

        # We need to deprecate the old-style location of the following
        # buffer arguments and make users put them into the
        # "replay_buffer_config" field of their config.
        replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
        if replay_batch_size != DEPRECATED_VALUE:
            replay_buffer_config["replay_batch_size"] = replay_batch_size
            deprecation_warning(
                old="config['replay_batch_size']",
                help="Replay batch size specified at new "
                "location config['replay_buffer_config']["
                "'replay_batch_size'] will be overwritten.",
                error=False,
            )

        replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
        if replay_mode != DEPRECATED_VALUE:
            replay_buffer_config["replay_mode"] = replay_mode
            # NOTE(review): The value is read from the top-level
            # config["replay_mode"], but the `old` label names
            # config['multiagent']['replay_mode'] (which is still read below
            # for the lockstep check). Confirm which key is meant.
            deprecation_warning(
                old="config['multiagent']['replay_mode']",
                help="Replay mode specified at new "
                "location config['replay_buffer_config']["
                "'replay_mode'] will be overwritten.",
                error=False,
            )

        # Can't use DEPRECATED_VALUE here because this is also a deliberate
        # value set for some algorithms
        # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
        replay_sequence_length = config.get("replay_sequence_length", None)
        if replay_sequence_length is not None:
            replay_buffer_config["replay_sequence_length"] = replay_sequence_length
            deprecation_warning(
                old="config['replay_sequence_length']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_sequence_length'] will be overwritten.",
                error=False,
            )

        replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
        if replay_burn_in != DEPRECATED_VALUE:
            replay_buffer_config["replay_burn_in"] = replay_burn_in
            deprecation_warning(
                old="config['burn_in']",
                help="Burn in specified at new location config["
                "'replay_buffer_config']["
                "'replay_burn_in'] will be overwritten.",
                error=False,
            )

        replay_zero_init_states = config.get(
            "replay_zero_init_states", DEPRECATED_VALUE
        )
        if replay_zero_init_states != DEPRECATED_VALUE:
            replay_buffer_config["replay_zero_init_states"] = replay_zero_init_states
            deprecation_warning(
                old="config['replay_zero_init_states']",
                help="Replay zero init states specified at new location "
                "config["
                "'replay_buffer_config']["
                "'replay_zero_init_states'] will be overwritten.",
                error=False,
            )

        # TODO (Artur): Move this logic into config objects
        is_prioritized_buffer = bool(
            replay_buffer_config.get("prioritized_replay", False)
        )
        if not is_prioritized_buffer:
            # This triggers non-prioritization in old-style replay buffer.
            replay_buffer_config["prioritized_replay_alpha"] = 0.0
    else:
        if isinstance(buffer_type, str) and "." not in buffer_type:
            # Create valid full [module].[class] string for from_config.
            replay_buffer_config["type"] = (
                "ray.rllib.utils.replay_buffers." + buffer_type
            )
        # Instantiate a throw-away buffer to find out whether the configured
        # buffer class supports priority updates.
        test_buffer = from_config(buffer_type, replay_buffer_config)
        is_prioritized_buffer = hasattr(test_buffer, "update_priorities")

    if is_prioritized_buffer:
        if config["multiagent"]["replay_mode"] == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        elif replay_buffer_config.get("replay_sequence_length", 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    elif config.get("worker_side_prioritization"):
        raise ValueError(
            "Worker side prioritization is not supported when "
            "prioritized_replay=False."
        )

    if replay_buffer_config.get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided.
        replay_buffer_config["replay_batch_size"] = config["train_batch_size"]

    # Pop prioritized replay because it's not a valid parameter for older
    # replay buffers.
    replay_buffer_config.pop("prioritized_replay", None)
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
    """Warn if the configured replay buffer capacity is too large for machine's memory.

    The estimate is `capacity * item.size_bytes()`, i.e. it assumes every stored
    item is about as large as the given `item`. The estimate is compared against
    the machine's *total* RAM, as reported by `psutil.virtual_memory().total`.
    The check is guarded by `log_once("warn_replay_buffer_capacity")`, so it is
    presumably only performed on the first call per process.

    Args:
        item: A (example) item that's supposed to be added to the buffer.
            This is used to compute the overall memory footprint estimate for
            the buffer.
        capacity: The capacity value of the buffer. This is interpreted as the
            number of items (such as given `item`) that will eventually be
            stored in the buffer.

    Raises:
        ValueError: If computed memory footprint for the buffer exceeds the
            machine's RAM.
    """
    if not log_once("warn_replay_buffer_capacity"):
        return

    item_size = item.size_bytes()
    total_gb = psutil.virtual_memory().total / 1e9
    mem_size = capacity * item_size / 1e9
    # The value reported is the machine's total RAM (not the currently free
    # amount), so the message labels it "total" to match.
    msg = (
        "Estimated max memory usage for replay buffer is {} GB "
        "({} batches of size {}, {} bytes each), "
        "total system memory is {} GB".format(
            mem_size, capacity, item.count, item_size, total_gb
        )
    )
    if mem_size > total_gb:
        raise ValueError(msg)
    if mem_size > 0.2 * total_gb:
        # More than 20% of total RAM: warn loudly.
        logger.warning(msg)
    else:
        logger.info(msg)