2018-06-09 00:21:35 -07:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
2018-06-25 22:33:57 -07:00
|
|
|
import ray
|
2019-03-29 12:44:23 -07:00
|
|
|
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
|
|
|
Postprocessing
|
2019-05-20 16:46:05 -07:00
|
|
|
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2019-05-10 20:36:18 -07:00
|
|
|
from ray.rllib.utils import try_import_tf
|
|
|
|
|
|
|
|
tf = try_import_tf()
|
2018-06-09 00:21:35 -07:00
|
|
|
|
|
|
|
|
2019-05-18 00:23:11 -07:00
|
|
|
# The basic (vanilla) policy-gradient loss.
|
2019-08-23 02:21:11 -04:00
|
|
|
def policy_gradient_loss(policy, model, dist_class, train_batch):
    """Vanilla policy-gradient loss: negative mean of log-prob * advantage.

    Args:
        policy: The policy object. It is unused here and is accepted only
            because it is part of the ``loss_fn`` signature handed to
            ``build_tf_policy``.
        model: Model whose ``from_batch`` output (first element) is used
            as the input to the action distribution.
        dist_class: Action-distribution class, constructed from the model
            output and the model itself.
        train_batch: Batch providing the taken actions
            (``SampleBatch.ACTIONS``) and their advantages
            (``Postprocessing.ADVANTAGES``).

    Returns:
        A scalar tensor (``tf.reduce_mean`` over all elements). Minimizing
        it raises the log-likelihood of each action in proportion to its
        advantage.
    """
    dist_inputs, _state = model.from_batch(train_batch)
    action_dist = dist_class(dist_inputs, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    weighted_log_probs = log_probs * train_batch[Postprocessing.ADVANTAGES]
    # Optimizers minimize, but we want to maximize the advantage-weighted
    # log-likelihood, so the mean is negated.
    return -tf.reduce_mean(weighted_log_probs)
|
2018-10-29 19:37:27 -07:00
|
|
|
|
2018-06-26 13:17:15 -07:00
|
|
|
|
2019-08-23 02:21:11 -04:00
|
|
|
# This adds the "advantages" column to the sample batch.
|
2019-05-18 00:23:11 -07:00
|
|
|
def postprocess_advantages(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Attach discounted-return advantages to a sample batch.

    Args:
        policy: Policy whose ``config["gamma"]`` supplies the discount
            factor.
        sample_batch: The trajectory batch to postprocess.
        other_agent_batches: Unused; accepted to match the
            ``postprocess_fn`` signature expected by ``build_tf_policy``.
        episode: Unused; accepted for the same reason.

    Returns:
        Whatever ``compute_advantages`` returns for this batch. The loss in
        this module reads ``Postprocessing.ADVANTAGES`` from the training
        batch, so this is presumably the batch with that column added.
    """
    discount = policy.config["gamma"]
    # GAE is disabled, so advantages are plain discounted returns. The
    # positional 0.0 is the bootstrap value for the end of the batch
    # (``last_r`` in compute_advantages). NOTE(review): this presumably
    # assumes each batch ends at an episode boundary; confirm against the
    # configured batch_mode.
    return compute_advantages(sample_batch, 0.0, discount, use_gae=False)
|
2018-06-09 00:21:35 -07:00
|
|
|
|
2019-03-29 12:44:23 -07:00
|
|
|
|
2019-05-18 00:23:11 -07:00
|
|
|
def _get_pg_default_config():
    # Looked up through the ``ray`` package at call time instead of being
    # imported at module load. NOTE(review): presumably this avoids a
    # circular import with ray.rllib.agents.pg.pg; verify before changing.
    return ray.rllib.agents.pg.pg.DEFAULT_CONFIG


# TF policy for vanilla policy gradients, assembled from the loss and
# postprocessing functions defined in this module.
PGTFPolicy = build_tf_policy(
    name="PGTFPolicy",
    loss_fn=policy_gradient_loss,
    postprocess_fn=postprocess_advantages,
    get_default_config=_get_pg_default_config)
|