Mirror of https://github.com/vale981/ray, synced 2025-03-07 02:51:39 -05:00
[rllib] Make the free_logstd param generic (#863)

* make free log std param generic
* fixes
* fixes

parent: 46641a642f
commit: 617bc4d239

6 changed files with 30 additions and 32 deletions
@@ -70,27 +70,27 @@ class DiagGaussian(ActionDistribution):
     def __init__(self, inputs):
         ActionDistribution.__init__(self, inputs)
-        mean, logstd = tf.split(inputs, 2, axis=1)
+        mean, log_std = tf.split(inputs, 2, axis=1)
         self.mean = mean
-        self.logstd = logstd
-        self.std = tf.exp(logstd)
+        self.log_std = log_std
+        self.std = tf.exp(log_std)

     def logp(self, x):
         return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std),
                                      reduction_indices=[1]) -
                 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
-                tf.reduce_sum(self.logstd, reduction_indices=[1]))
+                tf.reduce_sum(self.log_std, reduction_indices=[1]))

     def kl(self, other):
         assert isinstance(other, DiagGaussian)
-        return tf.reduce_sum(other.logstd - self.logstd +
+        return tf.reduce_sum(other.log_std - self.log_std +
                              (tf.square(self.std) +
                               tf.square(self.mean - other.mean)) /
                              (2.0 * tf.square(other.std)) - 0.5,
                              reduction_indices=[1])

     def entropy(self):
-        return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e),
+        return tf.reduce_sum(self.log_std + .5 * np.log(2.0 * np.pi * np.e),
                              reduction_indices=[1])

     def sample(self):
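For reference, the renamed log_std tensor enters the same diagonal-Gaussian formulas as before. Below is a standalone NumPy/SciPy sketch (not part of this commit) that checks the logp and entropy expressions above against SciPy's reference implementation.

# Hedged sketch (not part of the commit): verify the DiagGaussian log-density
# and entropy formulas from the hunk above against scipy's implementation
# for a single diagonal Gaussian.
import numpy as np
from scipy.stats import multivariate_normal

mean = np.array([0.5, -1.0])
log_std = np.array([0.1, -0.3])
std = np.exp(log_std)
x = np.array([0.2, -0.8])

logp = (-0.5 * np.sum(np.square((x - mean) / std))
        - 0.5 * np.log(2.0 * np.pi) * x.shape[0]
        - np.sum(log_std))
entropy = np.sum(log_std + 0.5 * np.log(2.0 * np.pi * np.e))

ref = multivariate_normal(mean=mean, cov=np.diag(std ** 2))
assert np.isclose(logp, ref.logpdf(x))
assert np.isclose(entropy, ref.entropy())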
@@ -48,7 +48,7 @@ class ModelCatalog(object):
             "Unsupported args: {} {}".format(action_space, dist_type))

     @staticmethod
-    def get_model(inputs, num_outputs, options=None):
+    def get_model(inputs, num_outputs, options=dict()):
         """Returns a suitable model conforming to given input and output specs.

         Args:

@@ -60,9 +60,6 @@ class ModelCatalog(object):
             model (Model): Neural network model.
         """

-        if options is None:
-            options = {}
-
         obs_rank = len(inputs.get_shape()) - 1

         if obs_rank > 1:

@@ -71,7 +68,7 @@ class ModelCatalog(object):
             return FullyConnectedNetwork(inputs, num_outputs, options)

     @staticmethod
-    def ConvolutionalNetwork(inputs, num_outputs, options=None):
+    def ConvolutionalNetwork(inputs, num_outputs, options=dict()):
         return ConvolutionalNetwork(inputs, num_outputs, options)

     @staticmethod
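A side note on the options=dict() defaults introduced in ModelCatalog above: a default dict is created once at function-definition time and shared across every call, unlike the options=None / options = {} pattern being removed. A minimal, self-contained illustration of the difference (not from the repo; function names are hypothetical):

# Illustration only: why options=dict() and options=None behave differently
# as defaults. The dict default is created once and shared across calls,
# so mutations persist between calls.
def with_dict_default(options=dict()):
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options

def with_none_default(options=None):
    if options is None:
        options = {}
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options

assert with_dict_default()["calls"] == 1
assert with_dict_default()["calls"] == 2   # shared state leaks across calls
assert with_none_default()["calls"] == 1
assert with_none_default()["calls"] == 1   # fresh dict on every call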
@@ -19,17 +19,7 @@ def normc_initializer(std=1.0):


 class FullyConnectedNetwork(Model):
-    """Generic fully connected network.
-
-    Options to construct the network are passed to the _init function.
-    If options["free_logstd"] is True, the last half of the
-    output layer will be free variables that are not dependent on
-    inputs. This is often used if the output of the network is used
-    to parametrize a probability distribution. In this case, the
-    first half of the parameters can be interpreted as a location
-    parameter (like a mean) and the second half can be interpreted as
-    a scale parameter (like a standard deviation).
-    """
+    """Generic fully connected network."""

     def _init(self, inputs, num_outputs, options):
         hiddens = options.get("fcnet_hiddens", [256, 256])

@@ -40,9 +30,6 @@ class FullyConnectedNetwork(Model):
             activation = tf.nn.relu
         print("Constructing fcnet {} {}".format(hiddens, activation))

-        if options.get("free_logstd", False):
-            num_outputs = num_outputs // 2
-
         with tf.name_scope("fc_net"):
             i = 1
             last_layer = inputs

@@ -57,8 +44,4 @@ class FullyConnectedNetwork(Model):
             last_layer, num_outputs,
             weights_initializer=normc_initializer(0.01),
             activation_fn=None, scope="fc_out")
-        if options.get("free_logstd", False):
-            logstd = tf.get_variable(name="logstd", shape=[num_outputs],
-                                     initializer=tf.zeros_initializer)
-            output = tf.concat([output, 0.0 * output + logstd], 1)
         return output, last_layer
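The 0.0 * output + logstd expression deleted above (and re-added in the Model class below) is a broadcasting trick: it tiles the free per-dimension log-std vector across the batch dimension so it can be concatenated with the per-sample network outputs. A NumPy-only sketch of the same shape manipulation, illustrative rather than the TensorFlow code:

# NumPy sketch (not part of the commit) of the "0.0 * output + log_std"
# broadcasting trick: tile a free per-dimension vector across the batch
# so it can be concatenated with the network's per-sample outputs.
import numpy as np

batch_size, num_outputs = 4, 3
output = np.random.randn(batch_size, num_outputs)    # stands in for the fc output
log_std = np.zeros(num_outputs)                      # free variable, shape [num_outputs]

tiled = 0.0 * output + log_std                       # broadcasts to [batch_size, num_outputs]
combined = np.concatenate([output, tiled], axis=1)   # shape [batch_size, 2 * num_outputs]
assert combined.shape == (batch_size, 2 * num_outputs)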
@@ -2,6 +2,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import tensorflow as tf
+

 class Model(object):
     """Defines an abstract network model for use with RLlib.

@@ -13,6 +15,14 @@ class Model(object):
     The last layer of the network can also be retrieved if the algorithm
     needs to further post-processing (e.g. Actor and Critic networks in A3C).

+    If options["free_log_std"] is True, the last half of the
+    output layer will be free variables that are not dependent on
+    inputs. This is often used if the output of the network is used
+    to parametrize a probability distribution. In this case, the
+    first half of the parameters can be interpreted as a location
+    parameter (like a mean) and the second half can be interpreted as
+    a scale parameter (like a standard deviation).
+
     Attributes:
         inputs (Tensor): The input placeholder for this model.
         outputs (Tensor): The output vector of this model.

@@ -21,8 +31,16 @@

     def __init__(self, inputs, num_outputs, options):
         self.inputs = inputs
+        if options.get("free_log_std", False):
+            assert num_outputs % 2 == 0
+            num_outputs = num_outputs // 2
         self.outputs, self.last_layer = self._init(
             inputs, num_outputs, options)
+        if options.get("free_log_std", False):
+            log_std = tf.get_variable(name="log_std", shape=[num_outputs],
+                                      initializer=tf.zeros_initializer)
+            self.outputs = tf.concat(
+                [self.outputs, 0.0 * self.outputs + log_std], 1)

     def _init(self):
         """Builds and returns the output and last layer of the network."""
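Taken together, the new Model.__init__ halves num_outputs before calling _init and then appends a free log-std vector to whatever outputs the subclass produced, so any model (not just the fully connected one) gets the behavior. Below is a NumPy-only sketch of that control flow; the function names are illustrative, not the rllib API.

# Illustrative NumPy-only sketch (not the rllib API) of the free_log_std
# wrapping now done once in Model.__init__: the subclass only ever sees the
# halved num_outputs, and the free log_std is appended to its outputs.
import numpy as np

def build_model_outputs(inputs, num_outputs, options, init_fn):
    """init_fn(inputs, num_outputs, options) -> (outputs, last_layer)."""
    if options.get("free_log_std", False):
        assert num_outputs % 2 == 0
        num_outputs = num_outputs // 2
    outputs, last_layer = init_fn(inputs, num_outputs, options)
    if options.get("free_log_std", False):
        log_std = np.zeros(num_outputs)              # stand-in for a trainable variable
        outputs = np.concatenate([outputs, 0.0 * outputs + log_std], axis=1)
    return outputs, last_layer

# Toy "network": a linear layer with whatever output width it is asked for.
def toy_init(inputs, num_outputs, options):
    weights = np.ones((inputs.shape[1], num_outputs))
    out = inputs @ weights
    return out, out

x = np.random.randn(5, 8)
outputs, _ = build_model_outputs(x, 6, {"free_log_std": True}, toy_init)
assert outputs.shape == (5, 6)   # 3 means + 3 free log-stds per sample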
@@ -52,7 +52,7 @@ DEFAULT_CONFIG = {
     "clip_param": 0.3,
     # Target value for KL divergence
     "kl_target": 0.01,
-    "model": {"free_logstd": False},
+    "model": {"free_log_std": False},
     # Number of timesteps collected in each outer loop
     "timesteps_per_batch": 40000,
     # Each tasks performs rollouts until at least this
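A minimal sketch (illustrative, not from the repo) of overriding the renamed key on top of defaults like the ones patched above:

# Sketch only: a trimmed-down stand-in for the defaults above, with the
# renamed free_log_std key switched on for a particular run.
DEFAULT_CONFIG = {
    "clip_param": 0.3,
    "kl_target": 0.01,
    "model": {"free_log_std": False},
    "timesteps_per_batch": 40000,
}

config = dict(DEFAULT_CONFIG)
config["model"] = {"free_log_std": True}   # enable the free log-std outputs
assert config["model"]["free_log_std"] is True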
@@ -6,9 +6,9 @@ python train.py --env CartPole-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20

 python train.py --env Walker2d-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64}' --alg PolicyGradient --upload-dir s3://bucketname/

-python train.py --env Humanoid-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_logstd": true}, "use_gae": false}' --alg PolicyGradient --upload-dir s3://bucketname/
+python train.py --env Humanoid-v1 --config '{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_log_std": true}, "use_gae": false}' --alg PolicyGradient --upload-dir s3://bucketname/

-python train.py --env Humanoid-v1 --config '{"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_logstd": true}, "write_logs": false}' --alg PolicyGradient --upload-dir s3://bucketname/
+python train.py --env Humanoid-v1 --config '{"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_agents": 64, "model": {"free_log_std": true}, "write_logs": false}' --alg PolicyGradient --upload-dir s3://bucketname/

 python train.py --env PongNoFrameskip-v0 --alg DQN --upload-dir s3://bucketname/
 python train.py --env PongDeterministic-v0 --alg A3C --upload-dir s3://bucketname/
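For readability, the --config argument in the Humanoid-v1 commands above is plain JSON. A sketch (abbreviated to the keys relevant to this commit) of parsing such a string into a Python dict:

# Sketch only: an abbreviated subset of the Humanoid-v1 --config JSON from
# above, parsed into a Python dict. JSON's lowercase true/false become
# Python booleans.
import json

config_json = ('{"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, '
               '"timesteps_per_batch": 320000, "num_agents": 64, '
               '"model": {"free_log_std": true}, "use_gae": false}')
config = json.loads(config_json)
assert config["model"]["free_log_std"] is True
assert config["use_gae"] is False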