ray/rllib/utils/replay_buffers/utils.py

from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers import (
    MultiAgentPrioritizedReplayBuffer,
    ReplayBuffer,
)
from ray.rllib.utils.typing import ResultDict, SampleBatchType, TrainerConfigDict


def update_priorities_in_replay_buffer(
    replay_buffer: ReplayBuffer,
    config: TrainerConfigDict,
    train_batch: SampleBatchType,
    train_results: ResultDict,
) -> None:
"""Updates the priorities in a prioritized replay buffer, given training results.
The `abs(TD-error)` from the loss (inside `train_results`) is used as new
priorities for the row-indices that were sampled for the train batch.
Don't do anything if the given buffer does not support prioritized replay.
Args:
replay_buffer: The replay buffer, whose priority values to update. This may also
be a buffer that does not support priorities.
config: The Trainer's config dict.
train_batch: The batch used for the training update.
train_results: A train results dict, generated by e.g. the `train_one_step()`
utility.
"""
    # Only update priorities if the buffer supports them (the legacy buffer
    # only prioritizes if `prioritized_replay_alpha` > 0.0).
    if (
        type(replay_buffer) is LegacyMultiAgentReplayBuffer
        and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0) > 0.0
    ) or isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
        # Go through training results for the different policies (maybe multi-agent).
        prio_dict = {}
        for policy_id, info in train_results.items():
            # TODO(sven): This is currently structured differently for
            #  torch/tf. Clean up these results/info dicts across
            #  policies (note: fixing this in torch_policy.py will
            #  break e.g. DDPPO!).
            td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
            # Set the get_interceptor to None in order to be able to access
            # the numpy arrays directly (instead of e.g. a torch tensor).
            train_batch.policy_batches[policy_id].set_get_interceptor(None)
            # Get the replay buffer row indices that make up the `train_batch`.
            batch_indices = train_batch.policy_batches[policy_id].get("batch_indexes")
            # In case the buffer stores sequences, TD-error could
            # already be calculated per sequence chunk.
            if len(batch_indices) != len(td_error):
                T = replay_buffer.replay_sequence_length
                assert (
                    len(batch_indices) > len(td_error) and len(batch_indices) % T == 0
                )
                batch_indices = batch_indices.reshape([-1, T])[:, 0]
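                # `batch_indices` now holds one row index per stored sequence
                # (its first timestep), matching the one TD-error per chunk.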
                assert len(batch_indices) == len(td_error)
            prio_dict[policy_id] = (batch_indices, td_error)
        # Make the actual buffer API call to update the priority weights on all
        # policies.
        replay_buffer.update_priorities(prio_dict)
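

# Example usage (sketch): a typical call site inside a Trainer's training
# step that samples from a prioritized replay buffer. The names
# `local_replay_buffer`, `trainer`, and `config` below are placeholders for
# whatever the surrounding Trainer defines; `train_one_step()` is the utility
# referenced in the docstring above (from `ray.rllib.execution.train_ops`).
#
#     from ray.rllib.execution.train_ops import train_one_step
#
#     train_batch = local_replay_buffer.sample(config["train_batch_size"])
#     train_results = train_one_step(trainer, train_batch)
#     update_priorities_in_replay_buffer(
#         replay_buffer=local_replay_buffer,
#         config=config,
#         train_batch=train_batch,
#         train_results=train_results,
#     )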