[rllib] Fix truncate episodes mode in central critic example (#8073)

This commit is contained in:
Eric Liang 2020-04-20 12:58:01 -07:00 committed by GitHub
parent 3812bfedda
commit 17e3c545d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -96,8 +96,6 @@ def centralized_critic_postprocessing(policy,
other_agent_batches=None,
episode=None):
if policy.loss_initialized():
assert sample_batch["dones"][-1], \
"Not implemented for train_batch_mode=truncate_episodes"
assert other_agent_batches is not None
[(_, opponent_batch)] = list(other_agent_batches.values())
@@ -116,11 +114,17 @@ def centralized_critic_postprocessing(policy,
sample_batch[OPPONENT_ACTION] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS])
sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
sample_batch[SampleBatch.REWARDS], dtype=np.float32)
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
last_r = sample_batch[SampleBatch.VF_PREDS][-1]
train_batch = compute_advantages(
sample_batch,
0.0,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])