from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
    LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

logger = logging.getLogger(__name__)

# Frozen logits of the policy that computed the action
BEHAVIOUR_LOGITS = "behaviour_logits"


class PPOLoss(object):
    def __init__(self,
                 action_space,
                 value_targets,
                 advantages,
                 actions,
                 logits,
                 vf_preds,
                 curr_action_dist,
                 value_fn,
                 cur_kl_coeff,
                 valid_mask,
                 entropy_coeff=0,
                 clip_param=0.1,
                 vf_clip_param=0.1,
                 vf_loss_coeff=1.0,
                 use_gae=True):
        """Constructs the loss for Proximal Policy Objective.

        Arguments:
            action_space: Environment action space specification.
            value_targets (Placeholder): Placeholder for target values; used
                for GAE.
            actions (Placeholder): Placeholder for actions taken
                from previous model evaluation.
            advantages (Placeholder): Placeholder for calculated advantages
                from previous model evaluation.
            logits (Placeholder): Placeholder for logits output from
                previous model evaluation.
            vf_preds (Placeholder): Placeholder for value function output
                from previous model evaluation.
            curr_action_dist (ActionDistribution): ActionDistribution
                of the current model.
            value_fn (Tensor): Current value function output Tensor.
            cur_kl_coeff (Variable): Variable holding the current PPO KL
                coefficient.
            valid_mask (Tensor): A bool mask of valid input elements (#2992).
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_param (float): Clip parameter.
            vf_clip_param (float): Clip parameter for the value function.
            vf_loss_coeff (float): Coefficient of the value function loss.
            use_gae (bool): If true, use the Generalized Advantage Estimator.
        """

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, valid_mask))

        dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
        prev_dist = dist_cls(logits)
        # Make loss functions.
        logp_ratio = tf.exp(
            curr_action_dist.logp(actions) - prev_dist.logp(actions))
        action_kl = prev_dist.kl(curr_action_dist)
        self.mean_kl = reduce_mean_valid(action_kl)

        curr_entropy = curr_action_dist.entropy()
        self.mean_entropy = reduce_mean_valid(curr_entropy)

        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                          1 + clip_param))
        self.mean_policy_loss = reduce_mean_valid(-surrogate_loss)

        if use_gae:
            vf_loss1 = tf.square(value_fn - value_targets)
            vf_clipped = vf_preds + tf.clip_by_value(
                value_fn - vf_preds, -vf_clip_param, vf_clip_param)
            vf_loss2 = tf.square(vf_clipped - value_targets)
            vf_loss = tf.maximum(vf_loss1, vf_loss2)
            self.mean_vf_loss = reduce_mean_valid(vf_loss)
            loss = reduce_mean_valid(
                -surrogate_loss + cur_kl_coeff * action_kl +
                vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy)
        else:
            self.mean_vf_loss = tf.constant(0.0)
            loss = reduce_mean_valid(-surrogate_loss +
                                     cur_kl_coeff * action_kl -
                                     entropy_coeff * curr_entropy)
        self.loss = loss
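
        # For reference, the objective assembled above (in the GAE case) is
        # the standard clipped PPO loss:
        #   L = E[-min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)
        #        + c_kl * KL(pi_old, pi) + c_vf * L_vf - c_ent * H(pi)],
        # where r_t = exp(logp_new(a_t) - logp_old(a_t)) is the likelihood
        # ratio between the current and the behaviour policy.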


class PPOPostprocessing(object):
    """Adds the policy logits, VF preds, and advantages to the trajectory."""

    @override(TFPolicyGraph)
    def extra_compute_action_fetches(self):
        return dict(
            TFPolicyGraph.extra_compute_action_fetches(self), **{
                SampleBatch.VF_PREDS: self.value_function,
                BEHAVIOUR_LOGITS: self.logits
            })

    @override(PolicyGraph)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        completed = sample_batch["dones"][-1]
        if completed:
            last_r = 0.0
        else:
            next_state = []
            for i in range(len(self.model.state_in)):
                next_state.append([sample_batch["state_out_{}".format(i)][-1]])
            last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
                                 sample_batch[SampleBatch.ACTIONS][-1],
                                 sample_batch[SampleBatch.REWARDS][-1],
                                 *next_state)
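
        # compute_advantages() (from rllib.evaluation.postprocessing) appends
        # the advantage and value-target columns; with use_gae=True these are
        # GAE(gamma, lambda) estimates bootstrapped from last_r.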
        batch = compute_advantages(
            sample_batch,
            last_r,
            self.config["gamma"],
            self.config["lambda"],
            use_gae=self.config["use_gae"])
        return batch


class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph):
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 existing_inputs=None):
        """
        Arguments:
            observation_space: Environment observation space specification.
            action_space: Environment action space specification.
            config (dict): Configuration values for PPO graph.
            existing_inputs (list): Optional list of tuples that specify the
                placeholders upon which the graph should be built.
        """
        config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
        self.sess = tf.get_default_session()
        self.action_space = action_space
        self.config = config
        self.kl_coeff_val = self.config["kl_coeff"]
        self.kl_target = self.config["kl_target"]
        dist_cls, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])

        if existing_inputs:
            obs_ph, value_targets_ph, adv_ph, act_ph, \
                logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \
                existing_inputs[:8]
            existing_state_in = existing_inputs[8:-1]
            existing_seq_lens = existing_inputs[-1]
        else:
            obs_ph = tf.placeholder(
                tf.float32,
                name="obs",
                shape=(None, ) + observation_space.shape)
            adv_ph = tf.placeholder(
                tf.float32, name="advantages", shape=(None, ))
            act_ph = ModelCatalog.get_action_placeholder(action_space)
            logits_ph = tf.placeholder(
                tf.float32, name="logits", shape=(None, logit_dim))
            vf_preds_ph = tf.placeholder(
                tf.float32, name="vf_preds", shape=(None, ))
            value_targets_ph = tf.placeholder(
                tf.float32, name="value_targets", shape=(None, ))
            prev_actions_ph = ModelCatalog.get_action_placeholder(action_space)
            prev_rewards_ph = tf.placeholder(
                tf.float32, [None], name="prev_reward")
            existing_state_in = None
            existing_seq_lens = None
        self.observations = obs_ph
        self.prev_actions = prev_actions_ph
        self.prev_rewards = prev_rewards_ph

        self.loss_in = [
            (SampleBatch.CUR_OBS, obs_ph),
            (Postprocessing.VALUE_TARGETS, value_targets_ph),
            (Postprocessing.ADVANTAGES, adv_ph),
            (SampleBatch.ACTIONS, act_ph),
            (BEHAVIOUR_LOGITS, logits_ph),
            (SampleBatch.VF_PREDS, vf_preds_ph),
            (SampleBatch.PREV_ACTIONS, prev_actions_ph),
            (SampleBatch.PREV_REWARDS, prev_rewards_ph),
        ]
        self.model = ModelCatalog.get_model(
            {
                "obs": obs_ph,
                "prev_actions": prev_actions_ph,
                "prev_rewards": prev_rewards_ph,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)

        # KL Coefficient
        self.kl_coeff = tf.get_variable(
            initializer=tf.constant_initializer(self.kl_coeff_val),
            name="kl_coeff",
            shape=(),
            trainable=False,
            dtype=tf.float32)

        self.logits = self.model.outputs
        curr_action_dist = dist_cls(self.logits)
        self.sampler = curr_action_dist.sample()
        if self.config["use_gae"]:
            if self.config["vf_share_layers"]:
                self.value_function = self.model.value_function()
            else:
                vf_config = self.config["model"].copy()
                # Do not split the last layer of the value function into
                # mean parameters and standard deviation parameters and
                # do not make the standard deviations free variables.
                vf_config["free_log_std"] = False
                if vf_config["use_lstm"]:
                    vf_config["use_lstm"] = False
                    logger.warning(
                        "It is not recommended to use an LSTM model with "
                        "vf_share_layers=False (consider setting it to True). "
                        "If you want to not share layers, you can implement "
                        "a custom LSTM model that overrides the "
                        "value_function() method.")
                with tf.variable_scope("value_function"):
                    self.value_function = ModelCatalog.get_model({
                        "obs": obs_ph,
                        "prev_actions": prev_actions_ph,
                        "prev_rewards": prev_rewards_ph,
                        "is_training": self._get_is_training_placeholder(),
                    }, observation_space, action_space, 1, vf_config).outputs
                    self.value_function = tf.reshape(self.value_function, [-1])
        else:
            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])

        if self.model.state_in:
            max_seq_len = tf.reduce_max(self.model.seq_lens)
            mask = tf.sequence_mask(self.model.seq_lens, max_seq_len)
            mask = tf.reshape(mask, [-1])
        else:
            mask = tf.ones_like(adv_ph, dtype=tf.bool)

        self.loss_obj = PPOLoss(
            action_space,
            value_targets_ph,
            adv_ph,
            act_ph,
            logits_ph,
            vf_preds_ph,
            curr_action_dist,
            self.value_function,
            self.kl_coeff,
            mask,
            entropy_coeff=self.config["entropy_coeff"],
            clip_param=self.config["clip_param"],
            vf_clip_param=self.config["vf_clip_param"],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            use_gae=self.config["use_gae"])

        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=obs_ph,
            action_sampler=self.sampler,
            action_prob=curr_action_dist.sampled_action_prob(),
            loss=self.loss_obj.loss,
            model=self.model,
            loss_inputs=self.loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions_ph,
            prev_reward_input=prev_rewards_ph,
            seq_lens=self.model.seq_lens,
            max_seq_len=config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())
        self.explained_variance = explained_variance(value_targets_ph,
                                                     self.value_function)
        self.stats_fetches = {
            "cur_kl_coeff": self.kl_coeff,
            "cur_lr": tf.cast(self.cur_lr, tf.float64),
            "total_loss": self.loss_obj.loss,
            "policy_loss": self.loss_obj.mean_policy_loss,
            "vf_loss": self.loss_obj.mean_vf_loss,
            "vf_explained_var": self.explained_variance,
            "kl": self.loss_obj.mean_kl,
            "entropy": self.loss_obj.mean_entropy
        }
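
        # These statistics are returned under LEARNER_STATS_KEY by
        # extra_compute_grad_fetches() below.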

    @override(TFPolicyGraph)
    def copy(self, existing_inputs):
        """Creates a copy of self using existing input placeholders."""
        return PPOPolicyGraph(
            self.observation_space,
            self.action_space,
            self.config,
            existing_inputs=existing_inputs)

    @override(TFPolicyGraph)
    def gradients(self, optimizer, loss):
        if self.config["grad_clip"] is not None:
            self.var_list = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                tf.get_variable_scope().name)
            grads = tf.gradients(loss, self.var_list)
            self.grads, _ = tf.clip_by_global_norm(grads,
                                                   self.config["grad_clip"])
            clipped_grads = list(zip(self.grads, self.var_list))
            return clipped_grads
        else:
            return optimizer.compute_gradients(
                loss, colocate_gradients_with_ops=True)

    @override(PolicyGraph)
    def get_initial_state(self):
        return self.model.state_init

    @override(TFPolicyGraph)
    def extra_compute_grad_fetches(self):
        return {LEARNER_STATS_KEY: self.stats_fetches}

    def update_kl(self, sampled_kl):
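        # Adaptive KL penalty: if the sampled KL divergence is more than
        # twice the target, increase the penalty coefficient; if it falls
        # below half the target, decrease it; otherwise leave it unchanged.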
        if sampled_kl > 2.0 * self.kl_target:
            self.kl_coeff_val *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.kl_coeff_val *= 0.5
        self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
        return self.kl_coeff_val

    def _value(self, ob, prev_action, prev_reward, *args):
        feed_dict = {
            self.observations: [ob],
            self.prev_actions: [prev_action],
            self.prev_rewards: [prev_reward],
            self.model.seq_lens: [1]
        }
        assert len(args) == len(self.model.state_in), \
            (args, self.model.state_in)
        for k, v in zip(self.model.state_in, args):
            feed_dict[k] = v
        vf = self.sess.run(self.value_function, feed_dict)
        return vf[0]
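

# A minimal usage sketch for illustration only (not part of this module, and
# `gym` is an assumed dependency here): the PPO trainer normally constructs
# this graph internally, but it can also be built directly from an
# environment's spaces and a config dict, e.g.:
#
#     env = gym.make("CartPole-v0")
#     with tf.Session():
#         graph = PPOPolicyGraph(env.observation_space, env.action_space,
#                                ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG)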
|