[RLlib] Fix broken test_r2d2.py. (#19017)

This commit is contained in:
Sven Mika 2021-09-30 21:19:37 +02:00 committed by GitHub
parent 301312e77f
commit 16ad46a654
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -510,7 +510,11 @@ def check_train_results(train_results):
if "td_error" in policy_stats:
configured_b = train_results["config"]["train_batch_size"]
actual_b = policy_stats["td_error"].shape[0]
assert (configured_b - actual_b) / actual_b <= 0.1
# R2D2 case.
if (configured_b - actual_b) / actual_b > 0.1:
assert configured_b / (
train_results["config"]["model"]["max_seq_len"] +
train_results["config"]["burn_in"]) == actual_b
# Make sure each policy has the LEARNER_STATS_KEY under it.
assert LEARNER_STATS_KEY in policy_stats