diff --git a/rllib/utils/replay_buffers/utils.py b/rllib/utils/replay_buffers/utils.py
index b5d15f98e..2f92031df 100644
--- a/rllib/utils/replay_buffers/utils.py
+++ b/rllib/utils/replay_buffers/utils.py
@@ -1,6 +1,7 @@
 import logging
 import psutil
 from typing import Optional, Any
+import numpy as np
 
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import deprecation_warning
@@ -52,11 +53,34 @@ def update_priorities_in_replay_buffer(
             # policies (note: fixing this in torch_policy.py will
             #  break e.g. DDPPO!).
             td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
+
+            policy_batch = train_batch.policy_batches[policy_id]
             # Set the get_interceptor to None in order to be able to access the numpy
             # arrays directly (instead of e.g. a torch array).
-            train_batch.policy_batches[policy_id].set_get_interceptor(None)
+            policy_batch.set_get_interceptor(None)
             # Get the replay buffer row indices that make up the `train_batch`.
-            batch_indices = train_batch.policy_batches[policy_id].get("batch_indexes")
+            batch_indices = policy_batch.get("batch_indexes")
+
+            if SampleBatch.SEQ_LENS in policy_batch:
+                # Batch indices are given per timestep; to update priorities,
+                # we need one index per td_error (i.e. one per sequence).
+                _batch_indices = []
+
+                # Sequenced batches have been zero-padded to max_seq_len.
+                # Depending on how batches are split during learning, not all
+                # sequences have an associated td_error (trailing ones missing).
+                if policy_batch.zero_padded:
+                    seq_lens = len(td_error) * [policy_batch.max_seq_len]
+                else:
+                    seq_lens = policy_batch[SampleBatch.SEQ_LENS][: len(td_error)]
+
+                # Go through the indices sequence by sequence and shrink them
+                # down to one index per sequence.
+                sequence_sum = 0
+                for seq_len in seq_lens:
+                    _batch_indices.append(batch_indices[sequence_sum])
+                    sequence_sum += seq_len
+                batch_indices = np.array(_batch_indices)
 
             if td_error is None:
                 if log_once(
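
For reviewers, a minimal standalone sketch of what the new `SampleBatch.SEQ_LENS` branch does (the data values and variable names below are made up for illustration and are not part of the patch): it keeps only the first replay-buffer row index of each sequence, so that `batch_indices` ends up with exactly one entry per `td_error` value before priorities are updated.

```python
import numpy as np

# Hypothetical sample: 3 sequences of lengths 2, 3, and 1 drawn from the buffer.
batch_indices = np.array([10, 11, 20, 21, 22, 30])  # one replay index per timestep
seq_lens = [2, 3, 1]                                 # one length per sequence
td_error = np.array([0.5, 0.1, 0.9])                 # one TD-error per sequence

# Collapse per-timestep indices to one index per sequence, mirroring the
# loop added in the diff above.
collapsed = []
sequence_sum = 0
for seq_len in seq_lens:
    collapsed.append(batch_indices[sequence_sum])  # first index of this sequence
    sequence_sum += seq_len
collapsed = np.array(collapsed)

assert len(collapsed) == len(td_error)  # now aligned one-to-one with td_error
print(collapsed)  # -> [10 20 30]
```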