Mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00

37 lines · 1.2 KiB · Python
"""TensorFlow policy class for the vanilla Policy Gradient (PG) algorithm."""

import ray
from ray.rllib.evaluation.postprocessing import Postprocessing, \
    compute_advantages
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


def post_process_advantages(policy,
                            sample_batch,
                            other_agent_batches=None,
                            episode=None):
    """Adds the "advantages" column to the sample batch (the train batch)."""
    # With use_gae=False and use_critic=False, compute_advantages fills the
    # "advantages" column with plain discounted cumulative rewards
    # (Monte Carlo returns), bootstrapping from a last-step value of 0.0.
    return compute_advantages(
        sample_batch,
        0.0,
        policy.config["gamma"],
        use_gae=False,
        use_critic=False)


def pg_tf_loss(policy, model, dist_class, train_batch):
    """The basic policy gradient loss: -E[log pi(a|s) * advantages]."""
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -tf.reduce_mean(
        action_dist.logp(train_batch[SampleBatch.ACTIONS]) *
        train_batch[Postprocessing.ADVANTAGES])


PGTFPolicy = build_tf_policy(
    name="PGTFPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    postprocess_fn=post_process_advantages,
    loss_fn=pg_tf_loss)
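
For context, a minimal usage sketch (not part of the original file): in Ray releases from this era, PGTFPolicy is built internally by the PG trainer when the TF framework is selected. The environment name, config values, and iteration count below are illustrative assumptions, not values taken from this file, and exact config keys may differ between Ray versions.

import ray
from ray.rllib.agents.pg import PGTrainer

ray.init()

# PGTrainer instantiates PGTFPolicy when "framework" is set to "tf".
trainer = PGTrainer(
    env="CartPole-v0",       # assumed example environment
    config={
        "framework": "tf",   # select the TF policy defined above
        "num_workers": 0,    # run sampling in the local process
    })

for _ in range(3):
    result = trainer.train()
    print(result["episode_reward_mean"])

Because post_process_advantages calls compute_advantages with use_gae=False and use_critic=False, the "advantages" column fed into pg_tf_loss reduces to plain discounted returns, which is what the vanilla policy gradient loss expects.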