Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[rllib] Fix truncate episodes mode in central critic example (#8073)
commit 17e3c545d9 (parent 3812bfedda)
1 changed file with 8 additions and 4 deletions
@@ -96,8 +96,6 @@ def centralized_critic_postprocessing(policy,
                                       other_agent_batches=None,
                                       episode=None):
     if policy.loss_initialized():
-        assert sample_batch["dones"][-1], \
-            "Not implemented for train_batch_mode=truncate_episodes"
         assert other_agent_batches is not None
         [(_, opponent_batch)] = list(other_agent_batches.values())
 
@@ -116,11 +114,17 @@ def centralized_critic_postprocessing(policy,
         sample_batch[OPPONENT_ACTION] = np.zeros_like(
             sample_batch[SampleBatch.ACTIONS])
         sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
-            sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
+            sample_batch[SampleBatch.REWARDS], dtype=np.float32)
 
+    completed = sample_batch["dones"][-1]
+    if completed:
+        last_r = 0.0
+    else:
+        last_r = sample_batch[SampleBatch.VF_PREDS][-1]
+
     train_batch = compute_advantages(
         sample_batch,
-        0.0,
+        last_r,
         policy.config["gamma"],
         policy.config["lambda"],
         use_gae=policy.config["use_gae"])
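The substance of the fix: instead of asserting that every batch ends an episode, the postprocessor now bootstraps. When the rollout was truncated mid-episode (dones[-1] is False), the critic's value prediction for the last state stands in for the unobserved remainder of the return; the dummy VF_PREDS are also shaped like the per-step rewards rather than the actions, since last_r is read as a single float from VF_PREDS[-1]. Below is a minimal standalone sketch of that bootstrapping idea, assuming plain NumPy arrays and a hypothetical helper name; it is not RLlib's compute_advantages, just the truncation-handling pattern in isolation.

import numpy as np

def bootstrapped_returns(rewards, vf_preds, done, gamma=0.99):
    # If the rollout ended mid-episode (done is False), estimate the
    # tail of the return with the critic's prediction for the last
    # state -- the same role last_r plays in the commit above.
    last_r = 0.0 if done else float(vf_preds[-1])
    returns = np.zeros_like(rewards, dtype=np.float32)
    running = last_r
    # Accumulate discounted rewards backwards from the bootstrap value.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Truncated 3-step rollout: the value head's estimate for the last
# state stands in for the unseen rest of the episode.
rewards = np.array([1.0, 1.0, 1.0], dtype=np.float32)
vf_preds = np.array([2.5, 1.8, 0.9], dtype=np.float32)
print(bootstrapped_returns(rewards, vf_preds, done=False))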