from typing import List, Optional

from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch


def post_process_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[List[SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Adds the "advantages" column to `sample_batch`.

    Args:
        policy (Policy): The Policy object to do post-processing for.
        sample_batch (SampleBatch): The actual sample batch to post-process.
        other_agent_batches (Optional[List[SampleBatch]]): Optional list of
            other agents' SampleBatch objects.
        episode (Episode): The multi-agent episode object, from which
            `sample_batch` was generated.

    Returns:
        SampleBatch: The SampleBatch enhanced by the added ADVANTAGES field.
    """

    # Calculate advantage values based on the rewards in the sample batch.
    # Since both `use_gae` and `use_critic` are False, the advantages are
    # simply the discounted cumulative rewards (plain REINFORCE-style
    # returns). The value of the last observation is assumed to be 0.0
    # (no value function estimation at the end of the sampled chunk).
    return compute_advantages(
        rollout=sample_batch,
        last_r=0.0,
        gamma=policy.config["gamma"],
        use_gae=False,
        use_critic=False,
    )
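

# Illustrative usage sketch (an assumption for demonstration, not part of the
# original module): builds a tiny SampleBatch and runs it through
# `post_process_advantages`. `_DummyPolicy` is a hypothetical stand-in that
# only carries the `config` dict this function reads; real Policy objects are
# constructed by RLlib itself. With `use_gae=False` and `use_critic=False`,
# `compute_advantages` should only require the rewards column.
if __name__ == "__main__":
    import numpy as np

    class _DummyPolicy:
        # Only the "gamma" entry is consulted by `post_process_advantages`.
        config = {"gamma": 0.99}

    # A 3-step rollout chunk containing only rewards.
    batch = SampleBatch(
        {SampleBatch.REWARDS: np.array([1.0, 0.0, 2.0], dtype=np.float32)}
    )

    out = post_process_advantages(_DummyPolicy(), batch)
    # Expected discounted returns with gamma=0.99 and last_r=0.0:
    # [1.0 + 0.99**2 * 2.0, 0.99 * 2.0, 2.0] = [2.9602, 1.98, 2.0]
    print(out["advantages"])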