diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py
index 56b25f33e..d4b2aa2fd 100644
--- a/rllib/examples/centralized_critic.py
+++ b/rllib/examples/centralized_critic.py
@@ -96,8 +96,6 @@ def centralized_critic_postprocessing(policy,
                                       other_agent_batches=None,
                                       episode=None):
     if policy.loss_initialized():
-        assert sample_batch["dones"][-1], \
-            "Not implemented for train_batch_mode=truncate_episodes"
         assert other_agent_batches is not None
         [(_, opponent_batch)] = list(other_agent_batches.values())
@@ -116,11 +114,17 @@ def centralized_critic_postprocessing(policy,
         sample_batch[OPPONENT_ACTION] = np.zeros_like(
             sample_batch[SampleBatch.ACTIONS])
         sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
-            sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
+            sample_batch[SampleBatch.REWARDS], dtype=np.float32)
+
+    completed = sample_batch["dones"][-1]
+    if completed:
+        last_r = 0.0
+    else:
+        last_r = sample_batch[SampleBatch.VF_PREDS][-1]
 
     train_batch = compute_advantages(
         sample_batch,
-        0.0,
+        last_r,
         policy.config["gamma"],
         policy.config["lambda"],
         use_gae=policy.config["use_gae"])