mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Fix priority update for sequenced batches. (#27544)
This commit is contained in:
parent
a1d80dc195
commit
04bc845360
1 changed files with 26 additions and 2 deletions
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import psutil
|
||||
from typing import Optional, Any
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import deprecation_warning
|
||||
|
@ -52,11 +53,34 @@ def update_priorities_in_replay_buffer(
|
|||
# policies (note: fixing this in torch_policy.py will
|
||||
# break e.g. DDPPO!).
|
||||
td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
|
||||
|
||||
policy_batch = train_batch.policy_batches[policy_id]
|
||||
# Set the get_interceptor to None in order to be able to access the numpy
|
||||
# arrays directly (instead of e.g. a torch array).
|
||||
train_batch.policy_batches[policy_id].set_get_interceptor(None)
|
||||
policy_batch.set_get_interceptor(None)
|
||||
# Get the replay buffer row indices that make up the `train_batch`.
|
||||
batch_indices = train_batch.policy_batches[policy_id].get("batch_indexes")
|
||||
batch_indices = policy_batch.get("batch_indexes")
|
||||
|
||||
if SampleBatch.SEQ_LENS in policy_batch:
|
||||
# Batch_indices are represented per column, in order to update
|
||||
# priorities, we need one index per td_error
|
||||
_batch_indices = []
|
||||
|
||||
# Sequenced batches have been zero padded to max_seq_len.
|
||||
# Depending on how batches are split during learning, not all
|
||||
# sequences have an associated td_error (trailing ones missing).
|
||||
if policy_batch.zero_padded:
|
||||
seq_lens = len(td_error) * [policy_batch.max_seq_len]
|
||||
else:
|
||||
seq_lens = policy_batch[SampleBatch.SEQ_LENS][: len(td_error)]
|
||||
|
||||
# Go through all indices by sequence that they represent and shrink
|
||||
# them to one index per sequences
|
||||
sequence_sum = 0
|
||||
for seq_len in seq_lens:
|
||||
_batch_indices.append(batch_indices[sequence_sum])
|
||||
sequence_sum += seq_len
|
||||
batch_indices = np.array(_batch_indices)
|
||||
|
||||
if td_error is None:
|
||||
if log_once(
|
||||
|
|
Loading…
Add table
Reference in a new issue