[rllib] Make the free_logstd param generic (#863)

* make free log std param generic

* fixes

* fixes
This commit is contained in:
Eric Liang 2017-08-24 12:43:51 -07:00 committed by Philipp Moritz
parent 46641a642f
commit 617bc4d239
6 changed files with 30 additions and 32 deletions

View file

@ -70,27 +70,27 @@ class DiagGaussian(ActionDistribution):
def __init__(self, inputs):
ActionDistribution.__init__(self, inputs)
mean, logstd = tf.split(inputs, 2, axis=1)
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.logstd = logstd
self.std = tf.exp(logstd)
self.log_std = log_std
self.std = tf.exp(log_std)
def logp(self, x):
return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std),
reduction_indices=[1]) -
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
tf.reduce_sum(self.logstd, reduction_indices=[1]))
tf.reduce_sum(self.log_std, reduction_indices=[1]))
def kl(self, other):
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(other.logstd - self.logstd +
return tf.reduce_sum(other.log_std - self.log_std +
(tf.square(self.std) +
tf.square(self.mean - other.mean)) /
(2.0 * tf.square(other.std)) - 0.5,
reduction_indices=[1])
def entropy(self):
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e),
return tf.reduce_sum(self.log_std + .5 * np.log(2.0 * np.pi * np.e),
reduction_indices=[1])
def sample(self):

View file

@ -48,7 +48,7 @@ class ModelCatalog(object):
"Unsupported args: {} {}".format(action_space, dist_type))
@staticmethod
def get_model(inputs, num_outputs, options=None):
def get_model(inputs, num_outputs, options=dict()):
"""Returns a suitable model conforming to given input and output specs.
Args:
@ -60,9 +60,6 @@ class ModelCatalog(object):
model (Model): Neural network model.
"""
if options is None:
options = {}
obs_rank = len(inputs.get_shape()) - 1
if obs_rank > 1:
@ -71,7 +68,7 @@ class ModelCatalog(object):
return FullyConnectedNetwork(inputs, num_outputs, options)
@staticmethod
def ConvolutionalNetwork(inputs, num_outputs, options=None):
def ConvolutionalNetwork(inputs, num_outputs, options=dict()):
return ConvolutionalNetwork(inputs, num_outputs, options)
@staticmethod

View file

@ -19,17 +19,7 @@ def normc_initializer(std=1.0):
class FullyConnectedNetwork(Model):
"""Generic fully connected network.
Options to construct the network are passed to the _init function.
If options["free_logstd"] is True, the last half of the
output layer will be free variables that are not dependent on
inputs. This is often used if the output of the network is used
to parametrize a probability distribution. In this case, the
first half of the parameters can be interpreted as a location
parameter (like a mean) and the second half can be interpreted as
a scale parameter (like a standard deviation).
"""
"""Generic fully connected network."""
def _init(self, inputs, num_outputs, options):
hiddens = options.get("fcnet_hiddens", [256, 256])
@ -40,9 +30,6 @@ class FullyConnectedNetwork(Model):
activation = tf.nn.relu
print("Constructing fcnet {} {}".format(hiddens, activation))
if options.get("free_logstd", False):
num_outputs = num_outputs // 2
with tf.name_scope("fc_net"):
i = 1
last_layer = inputs
@ -57,8 +44,4 @@ class FullyConnectedNetwork(Model):
last_layer, num_outputs,
weights_initializer=normc_initializer(0.01),
activation_fn=None, scope="fc_out")
if options.get("free_logstd", False):
logstd = tf.get_variable(name="logstd", shape=[num_outputs],
initializer=tf.zeros_initializer)
output = tf.concat([output, 0.0 * output + logstd], 1)
return output, last_layer

View file

@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class Model(object):
"""Defines an abstract network model for use with RLlib.
@ -13,6 +15,14 @@ class Model(object):
The last layer of the network can also be retrieved if the algorithm
needs to further post-processing (e.g. Actor and Critic networks in A3C).
If options["free_log_std"] is True, the last half of the
output layer will be free variables that are not dependent on
inputs. This is often used if the output of the network is used
to parametrize a probability distribution. In this case, the
first half of the parameters can be interpreted as a location
parameter (like a mean) and the second half can be interpreted as
a scale parameter (like a standard deviation).
Attributes:
inputs (Tensor): The input placeholder for this model.
outputs (Tensor): The output vector of this model.
@ -21,8 +31,16 @@ class Model(object):
def __init__(self, inputs, num_outputs, options):
self.inputs = inputs
if options.get("free_log_std", False):
assert num_outputs % 2 == 0
num_outputs = num_outputs // 2
self.outputs, self.last_layer = self._init(
inputs, num_outputs, options)
if options.get("free_log_std", False):
log_std = tf.get_variable(name="log_std", shape=[num_outputs],
initializer=tf.zeros_initializer)
self.outputs = tf.concat(
[self.outputs, 0.0 * self.outputs + log_std], 1)
def _init(self):
"""Builds and returns the output and last layer of the network."""

View file

@ -52,7 +52,7 @@ DEFAULT_CONFIG = {
"clip_param": 0.3,
# Target value for KL divergence
"kl_target": 0.01,
"model": {"free_logstd": False},
"model": {"free_log_std": False},
# Number of timesteps collected in each outer loop
"timesteps_per_batch": 40000,
# Each tasks performs rollouts until at least this

View file

@ -6,9 +6,9 @@ python train.py --env CartPole-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20
python train.py --env Walker2d-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64}' --alg PolicyGradient --upload-dir s3://bucketname/
python train.py --env Humanoid-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_logstd": true}, "use_gae": false}' --alg PolicyGradient --upload-dir s3://bucketname/
python train.py --env Humanoid-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_log_std": true}, "use_gae": false}' --alg PolicyGradient --upload-dir s3://bucketname/
python train.py --env Humanoid-v1 --config '{"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_logstd": true}, "write_logs": false}' --alg PolicyGradient --upload-dir s3://bucketname/
python train.py --env Humanoid-v1 --config '{"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_log_std": true}, "write_logs": false}' --alg PolicyGradient --upload-dir s3://bucketname/
python train.py --env PongNoFrameskip-v0 --alg DQN --upload-dir s3://bucketname/
python train.py --env PongDeterministic-v0 --alg A3C --upload-dir s3://bucketname/