from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.process_rollout import compute_advantages
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph


class PGPolicyGraph(TFPolicyGraph):
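    """Policy graph for a simple (vanilla) policy gradient algorithm."""
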
    def __init__(self, obs_space, action_space, config):
        config = dict(ray.rllib.pg.pg.DEFAULT_CONFIG, **config)
        self.config = config

        # setup policy
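        # Observation placeholder, action distribution class, and the model
        # whose outputs parameterize that distribution.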
        self.x = tf.placeholder(
            tf.float32, shape=[None] + list(obs_space.shape))
        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(
            self.x, self.logit_dim, options=self.config["model"])
        self.dist = dist_class(self.model.outputs)  # logit for each action

        # setup policy loss
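        # REINFORCE-style objective: the negated mean of log pi(a|s) weighted
        # by the advantage, so minimizing it raises the probability of
        # actions that led to higher-than-expected returns.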
        self.ac = ModelCatalog.get_action_placeholder(action_space)
        self.adv = tf.placeholder(tf.float32, [None], name="adv")
        self.loss = -tf.reduce_mean(self.dist.logp(self.ac) * self.adv)

        # initialize TFPolicyGraph
        self.sess = tf.get_default_session()
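        # Each (name, placeholder) pair maps a sample batch column to the
        # corresponding loss input.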
        self.loss_in = [
            ("obs", self.x),
            ("actions", self.ac),
            ("advantages", self.adv),
        ]
        self.is_training = tf.placeholder_with_default(True, ())
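        # Register inputs, action sampler, and loss with the TFPolicyGraph
        # base class, then initialize all TF variables.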
        TFPolicyGraph.__init__(
            self, obs_space, action_space, self.sess, obs_input=self.x,
            action_sampler=self.dist.sample(), loss=self.loss,
            loss_inputs=self.loss_in, is_training=self.is_training)
        self.sess.run(tf.global_variables_initializer())

    def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
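        # Advantages are plain discounted returns (use_gae=False); 0.0 is the
        # value estimate used beyond the end of the rollout.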
        return compute_advantages(
            sample_batch, 0.0, self.config["gamma"], use_gae=False)