mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
363 lines
15 KiB
Python
363 lines
15 KiB
Python
import logging
|
|
import psutil
|
|
from typing import Optional
|
|
|
|
from ray.rllib.execution import MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer
|
|
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
|
|
MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
|
|
)
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils import deprecation_warning
|
|
from ray.rllib.utils.annotations import ExperimentalAPI
|
|
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
|
from ray.rllib.utils.from_config import from_config
|
|
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
|
from ray.rllib.utils.replay_buffers import (
|
|
MultiAgentPrioritizedReplayBuffer,
|
|
ReplayBuffer,
|
|
)
|
|
from ray.rllib.utils.typing import ResultDict, SampleBatchType, TrainerConfigDict
|
|
from ray.util import log_once
|
|
|
|
# Module-level logger for the replay-buffer utilities in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def update_priorities_in_replay_buffer(
    replay_buffer: ReplayBuffer,
    config: TrainerConfigDict,
    train_batch: SampleBatchType,
    train_results: ResultDict,
) -> None:
    """Updates the priorities in a prioritized replay buffer, given training results.

    The `abs(TD-error)` from the loss (inside `train_results`) is used as new
    priorities for the row-indices that were sampled for the train batch.

    Don't do anything if the given buffer does not support prioritized replay.
    Policies whose results carry no TD-error, or whose train batch carries no
    "batch_indexes" column, are skipped (with a one-time warning) instead of
    crashing the whole update.

    Args:
        replay_buffer: The replay buffer, whose priority values to update. This may also
            be a buffer that does not support priorities.
        config: The Trainer's config dict.
        train_batch: The batch used for the training update.
        train_results: A train results dict, generated by e.g. the `train_one_step()`
            utility.
    """
    # Only update priorities if buffer supports them.
    supports_priorities = (
        type(replay_buffer) is LegacyMultiAgentReplayBuffer
        and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0) > 0.0
    ) or isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer)
    if not supports_priorities:
        return

    # Go through training results for the different policies (maybe multi-agent).
    prio_dict = {}
    for policy_id, info in train_results.items():
        # TODO(sven): This is currently structured differently for
        #  torch/tf. Clean up these results/info dicts across
        #  policies (note: fixing this in torch_policy.py will
        #  break e.g. DDPPO!).
        # Use `.get()` for the learner-stats fallback so that a missing
        # LEARNER_STATS_KEY does not raise even when "td_error" is present.
        td_error = info.get("td_error", info.get(LEARNER_STATS_KEY, {}).get("td_error"))

        policy_batch = train_batch.policy_batches[policy_id]
        # Set the get_interceptor to None in order to be able to access the numpy
        # arrays directly (instead of e.g. a torch array).
        policy_batch.set_get_interceptor(None)
        # Get the replay buffer row indices that make up the `train_batch`.
        batch_indices = policy_batch.get("batch_indexes")

        # Without TD-errors there is nothing to update for this policy; skip it
        # rather than failing on `len(None)` below.
        if td_error is None:
            if log_once(
                "no_td_error_in_train_results_from_policy_{}".format(policy_id)
            ):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing td_errors in "
                    "train_results. Priority update for this policy is being "
                    "skipped.".format(policy_id)
                )
            continue

        # Without row indices we can't tell which buffer rows to update.
        if batch_indices is None:
            if log_once(
                "no_batch_indices_in_train_result_for_policy_{}".format(policy_id)
            ):
                logger.warning(
                    "Trying to update priorities for policy with id `{}` in "
                    "prioritized replay buffer without providing batch_indices in "
                    "train_batch. Priority update for this policy is being "
                    "skipped.".format(policy_id)
                )
            continue

        # In case the buffer stores sequences, TD-error could
        # already be calculated per sequence chunk.
        if len(batch_indices) != len(td_error):
            T = replay_buffer.replay_sequence_length
            assert (
                len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
            )
            # Keep only the first row index of each length-T sequence.
            batch_indices = batch_indices.reshape([-1, T])[:, 0]
            assert len(batch_indices) == len(td_error)
        prio_dict[policy_id] = (batch_indices, td_error)

    # Make the actual buffer API call to update the priority weights on all
    # policies.
    replay_buffer.update_priorities(prio_dict)
|
|
|
|
|
|
def sample_min_n_steps_from_buffer(
    replay_buffer: ReplayBuffer, min_steps: int, count_by_agent_steps: bool
) -> Optional[SampleBatchType]:
    """Samples a minimum of n timesteps from a given replay buffer.

    This utility method is primarily used by the QMIX algorithm and helps with
    sampling a given number of time steps which has stored samples in units
    of sequences or complete episodes. Repeatedly samples one item at a time
    from the replay buffer until the total number of timesteps reaches
    `min_steps`.

    Args:
        replay_buffer: The replay buffer to sample from.
        min_steps: The minimum number of timesteps to sample. The result may
            contain more timesteps, because every sampled item is kept whole.
        count_by_agent_steps: Whether to count agent steps (True) or env steps
            (False) towards `min_steps`.

    Returns:
        A concatenated SampleBatch or MultiAgentBatch with samples from the
        buffer, or None if any call to `replay_buffer.sample()` returned None.
    """
    num_steps_sampled = 0
    sampled_batches = []
    while num_steps_sampled < min_steps:
        batch = replay_buffer.sample(num_items=1)
        # The buffer could not provide a sample; give up on the whole request.
        if batch is None:
            return None
        sampled_batches.append(batch)
        num_steps_sampled += (
            batch.agent_steps() if count_by_agent_steps else batch.env_steps()
        )

    # All batch types are the same type, hence we can use any concat_samples()
    return SampleBatch.concat_samples(sampled_batches)
|
|
|
|
|
|
@ExperimentalAPI
def validate_buffer_config(config: dict) -> None:
    """Checks and fixes the replay buffer settings of a Trainer config, in place.

    Moves deprecated top-level replay-buffer settings (e.g. `buffer_size`,
    `prioritized_replay_alpha`, `replay_batch_size`, `burn_in`) into
    `config["replay_buffer_config"]`, emitting a deprecation warning for each.
    Then it resolves the buffer `type` to a full `[module].[class]` string and
    checks that the chosen buffer is compatible with the rest of the config.
    If `no_local_replay_buffer` is set, validation stops after the deprecation
    handling.

    Args:
        config: The Trainer's config dict. Modified in place.

    Raises:
        ValueError: If a prioritized buffer is combined with
            `replay_mode=lockstep` or `replay_sequence_length > 1`, or if
            `worker_side_prioritization` is requested for a non-prioritized
            buffer.
        AssertionError: If `replay_buffer_config` has no `type` key, or if an
            unsupported buffer type is requested without the ReplayBuffer API.
    """
    if config.get("replay_buffer_config", None) is None:
        config["replay_buffer_config"] = {}
    replay_buffer_config = config["replay_buffer_config"]

    # Default to DEPRECATED_VALUE like every other deprecated key below. A missing
    # key must neither trigger the warning nor overwrite a value the user set at
    # the new location with `None`.
    prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
    if prioritized_replay != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['prioritized_replay']",
            help="Replay prioritization specified at new location config["
            "'replay_buffer_config']["
            "'prioritized_replay'] will be overwritten.",
            error=False,
        )
        replay_buffer_config["prioritized_replay"] = prioritized_replay

    capacity = config.get("buffer_size", DEPRECATED_VALUE)
    if capacity != DEPRECATED_VALUE:
        deprecation_warning(
            old="config['buffer_size']",
            help="Buffer size specified at new location config["
            "'replay_buffer_config']["
            "'capacity'] will be overwritten.",
            error=False,
        )
        replay_buffer_config["capacity"] = capacity

    # Deprecation of old-style replay buffer args.
    # Warn before checking whether we need a local buffer, so that algorithms
    # without a local buffer also get warned.
    deprecated_replay_buffer_keys = [
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "learning_starts",
    ]
    for k in deprecated_replay_buffer_keys:
        if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
            deprecation_warning(
                old="config[{}]".format(k),
                help="config['replay_buffer_config'][{}] should be used "
                "for Q-Learning algorithms. Ignore this warning if "
                "you are not using a Q-Learning algorithm and still "
                "provide {}."
                "".format(k, k),
                error=False,
            )
            # Copy values over to new location in config to support new
            # and old configuration style.
            replay_buffer_config[k] = config[k]

    # Old Ape-X configs may contain no_local_replay_buffer.
    no_local_replay_buffer = config.get("no_local_replay_buffer", False)
    if no_local_replay_buffer:
        deprecation_warning(
            old="config['no_local_replay_buffer']",
            help="no_local_replay_buffer specified at new location config["
            "'replay_buffer_config']["
            "'no_local_replay_buffer'] will be overwritten.",
            error=False,
        )
        replay_buffer_config["no_local_replay_buffer"] = no_local_replay_buffer

    # TODO (Artur):
    # Without a local buffer there is nothing further to validate here.
    if replay_buffer_config.get("no_local_replay_buffer", False):
        return

    assert (
        "type" in replay_buffer_config
    ), "Can not instantiate ReplayBuffer from config without 'type' key."

    # Check if old replay buffer should be instantiated.
    buffer_type = replay_buffer_config["type"]
    if not replay_buffer_config.get("_enable_replay_buffer_api", False):
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Prepend old-style buffers' path.
            assert buffer_type == "MultiAgentReplayBuffer", (
                "Without ReplayBuffer API, only MultiAgentReplayBuffer is supported!"
            )
            # Create valid full [module].[class] string for from_config.
            buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
        else:
            assert buffer_type in [
                "ray.rllib.execution.MultiAgentReplayBuffer",
                Legacy_MultiAgentReplayBuffer,
            ], "Without ReplayBuffer API, only MultiAgentReplayBuffer is supported!"

        replay_buffer_config["type"] = buffer_type

        # Remove from config, so it's not passed into the buffer c'tor.
        replay_buffer_config.pop("_enable_replay_buffer_api", None)

        # We need to deprecate the old-style location of the following
        # buffer arguments and make users put them into the
        # "replay_buffer_config" field of their config.
        replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
        if replay_batch_size != DEPRECATED_VALUE:
            replay_buffer_config["replay_batch_size"] = replay_batch_size
            deprecation_warning(
                old="config['replay_batch_size']",
                help="Replay batch size specified at new "
                "location config['replay_buffer_config']["
                "'replay_batch_size'] will be overwritten.",
                error=False,
            )

        replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
        if replay_mode != DEPRECATED_VALUE:
            replay_buffer_config["replay_mode"] = replay_mode
            deprecation_warning(
                old="config['multiagent']['replay_mode']",
                help="Replay mode specified at new "
                "location config['replay_buffer_config']["
                "'replay_mode'] will be overwritten.",
                error=False,
            )

        # Can't use DEPRECATED_VALUE here because this is also a deliberate
        # value set for some algorithms
        # TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
        replay_sequence_length = config.get("replay_sequence_length", None)
        if replay_sequence_length is not None:
            replay_buffer_config["replay_sequence_length"] = replay_sequence_length
            deprecation_warning(
                old="config['replay_sequence_length']",
                help="Replay sequence length specified at new "
                "location config['replay_buffer_config']["
                "'replay_sequence_length'] will be overwritten.",
                error=False,
            )

        replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
        if replay_burn_in != DEPRECATED_VALUE:
            replay_buffer_config["replay_burn_in"] = replay_burn_in
            deprecation_warning(
                old="config['burn_in']",
                help="Burn in specified at new location config["
                "'replay_buffer_config']["
                "'replay_burn_in'] will be overwritten.",
                error=False,
            )

        replay_zero_init_states = config.get(
            "replay_zero_init_states", DEPRECATED_VALUE
        )
        if replay_zero_init_states != DEPRECATED_VALUE:
            replay_buffer_config["replay_zero_init_states"] = replay_zero_init_states
            deprecation_warning(
                old="config['replay_zero_init_states']",
                help="Replay zero init states specified at new location "
                "config["
                "'replay_buffer_config']["
                "'replay_zero_init_states'] will be overwritten.",
                error=False,
            )

        # TODO (Artur): Move this logic into config objects
        is_prioritized_buffer = bool(
            replay_buffer_config.get("prioritized_replay", False)
        )
        if not is_prioritized_buffer:
            # This triggers non-prioritization in old-style replay buffer.
            replay_buffer_config["prioritized_replay_alpha"] = 0.0
    else:
        if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
            # Create valid full [module].[class] string for from_config.
            buffer_type = "ray.rllib.utils.replay_buffers." + buffer_type
            replay_buffer_config["type"] = buffer_type
        # NOTE(review): this instantiates a throwaway buffer only to probe
        # whether it supports `update_priorities`.
        test_buffer = from_config(buffer_type, replay_buffer_config)
        is_prioritized_buffer = hasattr(test_buffer, "update_priorities")

    if is_prioritized_buffer:
        if config.get("multiagent", {}).get("replay_mode") == "lockstep":
            raise ValueError(
                "Prioritized replay is not supported when replay_mode=lockstep."
            )
        # `or 0` guards against an explicit `None` sequence length.
        elif (replay_buffer_config.get("replay_sequence_length") or 0) > 1:
            raise ValueError(
                "Prioritized replay is not supported when "
                "replay_sequence_length > 1."
            )
    else:
        if config.get("worker_side_prioritization"):
            raise ValueError(
                "Worker side prioritization is not supported when "
                "prioritized_replay=False."
            )

    if replay_buffer_config.get("replay_batch_size", None) is None:
        # Fall back to train batch size if no replay batch size was provided.
        replay_buffer_config["replay_batch_size"] = config["train_batch_size"]

    # Pop prioritized replay because it's not a valid parameter for older
    # replay buffers.
    replay_buffer_config.pop("prioritized_replay", None)
|
|
|
|
|
|
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
    """Warn if the configured replay buffer capacity is too large for machine's memory.

    The estimate is `capacity * item.size_bytes()`, compared against the total
    system memory reported by `psutil`. The check only runs when
    `log_once("warn_replay_buffer_capacity")` returns True, presumably only on
    the first call.

    If the estimate exceeds total memory, a ValueError is raised. If it exceeds
    20% of total memory, a warning is logged. Otherwise the estimate is logged
    at info level.

    Args:
        item: A (example) item that's supposed to be added to the buffer.
            This is used to compute the overall memory footprint estimate for the
            buffer.
        capacity: The capacity value of the buffer. This is interpreted as the
            number of items (such as given `item`) that will eventually be stored in
            the buffer.

    Raises:
        ValueError: If computed memory footprint for the buffer exceeds the machine's
            RAM.
    """
    if not log_once("warn_replay_buffer_capacity"):
        return

    bytes_per_item = item.size_bytes()
    available_gb = psutil.virtual_memory().total / 1e9
    estimated_gb = capacity * bytes_per_item / 1e9
    msg = (
        f"Estimated max memory usage for replay buffer is {estimated_gb} GB "
        f"({capacity} batches of size {item.count}, {bytes_per_item} bytes each), "
        f"available system memory is {available_gb} GB"
    )

    if estimated_gb > available_gb:
        raise ValueError(msg)

    # Escalate to a warning once the buffer would claim more than 20% of RAM.
    log_fn = logger.warning if estimated_gb > 0.2 * available_gb else logger.info
    log_fn(msg)
|