[rllib] Add squash_to_range model option (#2239)

* sigmoid

* squash

* squash true

* git push

* Update catalog.py
This commit is contained in:
Eric Liang 2018-06-19 19:47:26 -07:00 committed by GitHub
parent 51744459f3
commit 46cc51ce0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 65 additions and 12 deletions

View file

@ -27,7 +27,8 @@ class SharedTorchPolicy(PolicyGraph):
self.lock = Lock()
def setup_graph(self, obs_space, action_space):
_, self.logit_dim = ModelCatalog.get_action_dist(action_space)
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self._model = ModelCatalog.get_torch_model(
self.registry, obs_space.shape, self.logit_dim,
self.config["model"])

View file

@ -16,7 +16,8 @@ class SharedModel(A3CTFPolicyGraph):
def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
ac_space, self.config["model"])
self._model = ModelCatalog.get_model(
self.registry, self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs

View file

@ -17,7 +17,8 @@ class SharedModelLSTM(A3CTFPolicyGraph):
def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
ac_space, self.config["model"])
self._model = LSTM(self.x, self.logit_dim, {})
self.state_in = self._model.state_in

View file

@ -22,7 +22,8 @@ class BCPolicy(object):
def _setup_graph(self, obs_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
ac_space, self.config["model"])
self._model = ModelCatalog.get_model(
self.registry, self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs

View file

@ -73,10 +73,19 @@ class DiagGaussian(ActionDistribution):
second half the gaussian standard deviations.
"""
def __init__(self, inputs):
def __init__(self, inputs, low=None, high=None):
ActionDistribution.__init__(self, inputs)
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.low = low
self.high = high
# Squash to range if specified.
# TODO(ekl) might make sense to use a beta distribution instead:
# http://proceedings.mlr.press/v70/chou17a/chou17a.pdf
if low is not None:
self.mean = low + tf.sigmoid(self.mean) * (high - low)
self.log_std = log_std
self.std = tf.exp(log_std)
@ -99,7 +108,10 @@ class DiagGaussian(ActionDistribution):
reduction_indices=[1])
def sample(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
out = self.mean + self.std * tf.random_normal(tf.shape(self.mean))
if self.low is not None:
out = tf.clip_by_value(out, self.low, self.high)
return out
class Deterministic(ActionDistribution):
@ -112,6 +124,34 @@ class Deterministic(ActionDistribution):
return self.inputs
def squash_to_range(dist_cls, low, high):
"""Squashes an action distribution to a range in (low, high).
Arguments:
dist_cls (class): ActionDistribution class to wrap.
low (float|array): Scalar value or array of values.
high (float|array): Scalar value or array of values.
"""
class SquashToRangeWrapper(dist_cls):
def __init__(self, inputs):
dist_cls.__init__(self, inputs, low=low, high=high)
def logp(self, x):
return dist_cls.logp(self, x)
def kl(self, other):
return dist_cls.kl(self, other)
def entropy(self):
return dist_cls.entropy(self)
def sample(self):
return dist_cls.sample(self)
return SquashToRangeWrapper
class MultiActionDistribution(ActionDistribution):
"""Action distribution that operates for list of actions.

View file

@ -11,7 +11,8 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
_default_registry
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian, MultiActionDistribution)
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
squash_to_range)
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
@ -29,6 +30,7 @@ MODEL_CONFIGS = [
"fcnet_hiddens", # Number of hidden layers for fully connected net
"free_log_std", # Documented in ray.rllib.models.Model
"channel_major", # Pytorch conv requires images to be channel-major
"squash_to_range", # Whether to squash the action output to space range
# === Options for custom models ===
"custom_preprocessor", # Name of a custom preprocessor to use
@ -51,11 +53,12 @@ class ModelCatalog(object):
"""
@staticmethod
def get_action_dist(action_space, dist_type=None):
def get_action_dist(action_space, config=None, dist_type=None):
"""Returns action distribution class and size for the given action space.
Args:
action_space (Space): Action space of the target gym env.
config (dict): Optional model config.
dist_type (str): Optional identifier of the action distribution.
Returns:
@ -66,10 +69,14 @@ class ModelCatalog(object):
# TODO(ekl) are list spaces valid?
if isinstance(action_space, list):
action_space = gym.spaces.Tuple(action_space)
config = config or {}
if isinstance(action_space, gym.spaces.Box):
if dist_type is None:
return DiagGaussian, action_space.shape[0] * 2
dist = DiagGaussian
if config.get("squash_to_range"):
dist = squash_to_range(
dist, action_space.low, action_space.high)
return dist, action_space.shape[0] * 2
elif dist_type == 'deterministic':
return Deterministic, action_space.shape[0]
elif isinstance(action_space, gym.spaces.Discrete):

View file

@ -16,7 +16,8 @@ class PGPolicyGraph(TFPolicyGraph):
# setup policy
self.x = tf.placeholder(tf.float32, shape=[None]+list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(action_space)
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_model(
registry, self.x, self.logit_dim, options=self.config["model"])
self.dist = dist_class(self.model.outputs) # logit for each action

View file

@ -54,7 +54,7 @@ class PPOEvaluator(TFMultiGPUSupport):
action_space = self.env.action_space
self.actions = ModelCatalog.get_action_placeholder(action_space)
self.distribution_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space)
action_space, config["model"])
# Log probabilities from the policy before the policy update.
self.prev_logits = tf.placeholder(
tf.float32, shape=(None, self.logit_dim))

View file

@ -12,3 +12,4 @@ pendulum-ppo:
num_sgd_iter: 10
model:
fcnet_hiddens: [64, 64]
squash_to_range: True