mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Fix broken test_r2d2.py. (#19017)
This commit is contained in:
parent
301312e77f
commit
16ad46a654
1 changed files with 5 additions and 1 deletions
|
@ -510,7 +510,11 @@ def check_train_results(train_results):
|
|||
if "td_error" in policy_stats:
|
||||
configured_b = train_results["config"]["train_batch_size"]
|
||||
actual_b = policy_stats["td_error"].shape[0]
|
||||
assert (configured_b - actual_b) / actual_b <= 0.1
|
||||
# R2D2 case.
|
||||
if (configured_b - actual_b) / actual_b > 0.1:
|
||||
assert configured_b / (
|
||||
train_results["config"]["model"]["max_seq_len"] +
|
||||
train_results["config"]["burn_in"]) == actual_b
|
||||
|
||||
# Make sure each policy has the LEARNER_STATS_KEY under it.
|
||||
assert LEARNER_STATS_KEY in policy_stats
|
||||
|
|
Loading…
Add table
Reference in a new issue