ray/rllib/agents/pg/pg_tf_policy.py

37 lines
1.2 KiB
Python

import ray
from ray.rllib.evaluation.postprocessing import Postprocessing, \
compute_advantages
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def post_process_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
"""This adds the "advantages" column to the sample train_batch."""
return compute_advantages(
sample_batch,
0.0,
policy.config["gamma"],
use_gae=False,
use_critic=False)
def pg_tf_loss(policy, model, dist_class, train_batch):
"""The basic policy gradients loss."""
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
return -tf.reduce_mean(
action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
train_batch[Postprocessing.ADVANTAGES])
PGTFPolicy = build_tf_policy(
name="PGTFPolicy",
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
postprocess_fn=post_process_advantages,
loss_fn=pg_tf_loss)