[RLlib] Fix priority update for sequenced batches. (#27544)

Artur Niederfahrenhorst 2022-08-10 12:48:25 +02:00 committed by GitHub
parent a1d80dc195
commit 04bc845360


@@ -1,6 +1,7 @@
 import logging
 import psutil
 from typing import Optional, Any
+import numpy as np
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils import deprecation_warning
@@ -52,11 +53,34 @@ def update_priorities_in_replay_buffer(
             # policies (note: fixing this in torch_policy.py will
             # break e.g. DDPPO!).
             td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
+            policy_batch = train_batch.policy_batches[policy_id]
             # Set the get_interceptor to None in order to be able to access the numpy
             # arrays directly (instead of e.g. a torch array).
-            train_batch.policy_batches[policy_id].set_get_interceptor(None)
+            policy_batch.set_get_interceptor(None)
             # Get the replay buffer row indices that make up the `train_batch`.
-            batch_indices = train_batch.policy_batches[policy_id].get("batch_indexes")
+            batch_indices = policy_batch.get("batch_indexes")
+            if SampleBatch.SEQ_LENS in policy_batch:
+                # Batch indices are represented per column; in order to update
+                # priorities, we need one index per td_error entry.
+                _batch_indices = []
+                # Sequenced batches have been zero-padded to max_seq_len.
+                # Depending on how batches are split during learning, not all
+                # sequences have an associated td_error (trailing ones missing).
+                if policy_batch.zero_padded:
+                    seq_lens = len(td_error) * [policy_batch.max_seq_len]
+                else:
+                    seq_lens = policy_batch[SampleBatch.SEQ_LENS][: len(td_error)]
+                # Go through all indices sequence by sequence and shrink them
+                # to one index per sequence.
+                sequence_sum = 0
+                for seq_len in seq_lens:
+                    _batch_indices.append(batch_indices[sequence_sum])
+                    sequence_sum += seq_len
+                batch_indices = np.array(_batch_indices)
             if td_error is None:
                 if log_once(
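
For illustration, here is a minimal, self-contained sketch of the per-sequence index reduction this patch adds. It mirrors the new loop above in isolation; the buffer indices, td_error values, and max_seq_len below are made-up stand-ins, not values taken from RLlib.

import numpy as np

# Replay buffer row indices, one per timestep of the zero-padded train batch
# (three sequences of max_seq_len=4 rows each; values are invented).
batch_indices = np.array([10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33])

# One td_error per sequence; here only the first two sequences were trained
# on, so the trailing sequence has no associated error.
td_error = np.array([0.5, 1.2])

# Zero-padded case: every sequence occupies exactly max_seq_len rows.
max_seq_len = 4
seq_lens = len(td_error) * [max_seq_len]  # -> [4, 4]

# Shrink the per-timestep indices to one index per sequence by keeping the
# first buffer index of each sequence.
_batch_indices = []
sequence_sum = 0
for seq_len in seq_lens:
    _batch_indices.append(batch_indices[sequence_sum])
    sequence_sum += seq_len

batch_indices = np.array(_batch_indices)
print(batch_indices)  # [10 20] -- one buffer index per td_error entry

After this reduction, batch_indices and td_error have the same length, so the buffer's priority update can pair each error with the start index of its sequence.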