ray/rllib/algorithms/pg/utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

39 lines
1.3 KiB
Python
Raw Normal View History

from typing import List, Optional
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
def post_process_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[List[SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Return `sample_batch` with an ADVANTAGES field computed from its rewards.

    The work is delegated to `compute_advantages` with both GAE and the critic
    disabled. The discount factor comes from `policy.config["gamma"]`.

    Args:
        policy: The Policy being post-processed for; only its config's
            "gamma" entry is read.
        sample_batch: The sampled trajectory chunk to augment.
        other_agent_batches: Optional batches from other agents. Accepted
            for signature compatibility; not used here.
        episode: The multi-agent episode `sample_batch` came from. Accepted
            for signature compatibility; not used here.

    Returns:
        SampleBatch: The result of `compute_advantages`, i.e. the batch
        enhanced with the ADVANTAGES field.
    """
    discount = policy.config["gamma"]

    # Without a value function there is nothing to bootstrap the end of the
    # sampled chunk with, so the value of the final observation is taken as 0.0.
    bootstrap_value = 0.0

    return compute_advantages(
        rollout=sample_batch,
        last_r=bootstrap_value,
        gamma=discount,
        use_gae=False,
        use_critic=False,
    )