mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Add squash_to_range model option (#2239)
* sigmoid * squash * squash true * git push * Update catalog.py
This commit is contained in:
parent
51744459f3
commit
46cc51ce0c
9 changed files with 65 additions and 12 deletions
|
@ -27,7 +27,8 @@ class SharedTorchPolicy(PolicyGraph):
|
|||
self.lock = Lock()
|
||||
|
||||
def setup_graph(self, obs_space, action_space):
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(action_space)
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self._model = ModelCatalog.get_torch_model(
|
||||
self.registry, obs_space.shape, self.logit_dim,
|
||||
self.config["model"])
|
||||
|
|
|
@ -16,7 +16,8 @@ class SharedModel(A3CTFPolicyGraph):
|
|||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
ac_space, self.config["model"])
|
||||
self._model = ModelCatalog.get_model(
|
||||
self.registry, self.x, self.logit_dim, self.config["model"])
|
||||
self.logits = self._model.outputs
|
||||
|
|
|
@ -17,7 +17,8 @@ class SharedModelLSTM(A3CTFPolicyGraph):
|
|||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
ac_space, self.config["model"])
|
||||
self._model = LSTM(self.x, self.logit_dim, {})
|
||||
|
||||
self.state_in = self._model.state_in
|
||||
|
|
|
@ -22,7 +22,8 @@ class BCPolicy(object):
|
|||
|
||||
def _setup_graph(self, obs_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
ac_space, self.config["model"])
|
||||
self._model = ModelCatalog.get_model(
|
||||
self.registry, self.x, self.logit_dim, self.config["model"])
|
||||
self.logits = self._model.outputs
|
||||
|
|
|
@ -73,10 +73,19 @@ class DiagGaussian(ActionDistribution):
|
|||
second half the gaussian standard deviations.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs):
|
||||
def __init__(self, inputs, low=None, high=None):
    """Split ``inputs`` into mean / log-std halves and set up the gaussian.

    Arguments:
        inputs: Tensor that is split into two equal parts along axis 1;
            the first part is the mean, the second the log std.
        low: Optional lower bound. When given, the mean is squashed
            into the range between ``low`` and ``high`` with a sigmoid.
        high: Optional upper bound, used together with ``low``.
    """
    ActionDistribution.__init__(self, inputs)
    mean, log_std = tf.split(inputs, 2, axis=1)
    self.low = low
    self.high = high

    # Squash to range if specified.
    # TODO(ekl) might make sense to use a beta distribution instead:
    # http://proceedings.mlr.press/v70/chou17a/chou17a.pdf
    if low is not None:
        mean = low + tf.sigmoid(mean) * (high - low)
    self.mean = mean

    self.log_std = log_std
    self.std = tf.exp(log_std)
|
||||
|
||||
|
@ -99,7 +108,10 @@ class DiagGaussian(ActionDistribution):
|
|||
reduction_indices=[1])
|
||||
|
||||
def sample(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
out = self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
if self.low is not None:
|
||||
out = tf.clip_by_value(out, self.low, self.high)
|
||||
return out
|
||||
|
||||
|
||||
class Deterministic(ActionDistribution):
|
||||
|
@ -112,6 +124,34 @@ class Deterministic(ActionDistribution):
|
|||
return self.inputs
|
||||
|
||||
|
||||
def squash_to_range(dist_cls, low, high):
    """Squashes an action distribution to a range in (low, high).

    Returns a subclass of ``dist_cls`` whose constructor takes only
    ``inputs`` and forwards the ``low`` / ``high`` captured here to
    ``dist_cls.__init__`` as keyword arguments, so ``dist_cls`` must accept
    ``low`` and ``high`` keywords. ``logp``, ``kl``, ``entropy`` and
    ``sample`` are inherited from ``dist_cls`` unchanged.

    Arguments:
        dist_cls (class): ActionDistribution class to wrap.
        low (float|array): Scalar value or array of values.
        high (float|array): Scalar value or array of values.

    Returns:
        class: ``SquashToRangeWrapper``, a subclass of ``dist_cls``.

    Raises:
        ValueError: If exactly one of ``low`` / ``high`` is None.
    """
    # Fail fast on a half-specified range instead of deferring the problem
    # to whenever the wrapped distribution is built. Compare with
    # ``is None`` rather than truthiness because the bounds may be arrays.
    if (low is None) != (high is None):
        raise ValueError(
            "squash_to_range requires both low and high, got "
            "low={!r}, high={!r}".format(low, high))

    class SquashToRangeWrapper(dist_cls):
        def __init__(self, inputs):
            # Explicit base-class call keeps this Python 2 compatible.
            dist_cls.__init__(self, inputs, low=low, high=high)

    return SquashToRangeWrapper
|
||||
|
||||
|
||||
class MultiActionDistribution(ActionDistribution):
|
||||
"""Action distribution that operates for list of actions.
|
||||
|
||||
|
|
|
@ -11,7 +11,8 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
|||
_default_registry
|
||||
|
||||
from ray.rllib.models.action_dist import (
|
||||
Categorical, Deterministic, DiagGaussian, MultiActionDistribution)
|
||||
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
|
||||
squash_to_range)
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
|
@ -29,6 +30,7 @@ MODEL_CONFIGS = [
|
|||
"fcnet_hiddens", # Number of hidden layers for fully connected net
|
||||
"free_log_std", # Documented in ray.rllib.models.Model
|
||||
"channel_major", # Pytorch conv requires images to be channel-major
|
||||
"squash_to_range", # Whether to squash the action output to space range
|
||||
|
||||
# === Options for custom models ===
|
||||
"custom_preprocessor", # Name of a custom preprocessor to use
|
||||
|
@ -51,11 +53,12 @@ class ModelCatalog(object):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_action_dist(action_space, dist_type=None):
|
||||
def get_action_dist(action_space, config=None, dist_type=None):
|
||||
"""Returns action distribution class and size for the given action space.
|
||||
|
||||
Args:
|
||||
action_space (Space): Action space of the target gym env.
|
||||
config (dict): Optional model config.
|
||||
dist_type (str): Optional identifier of the action distribution.
|
||||
|
||||
Returns:
|
||||
|
@ -66,10 +69,14 @@ class ModelCatalog(object):
|
|||
# TODO(ekl) are list spaces valid?
|
||||
if isinstance(action_space, list):
|
||||
action_space = gym.spaces.Tuple(action_space)
|
||||
|
||||
config = config or {}
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
if dist_type is None:
|
||||
return DiagGaussian, action_space.shape[0] * 2
|
||||
dist = DiagGaussian
|
||||
if config.get("squash_to_range"):
|
||||
dist = squash_to_range(
|
||||
dist, action_space.low, action_space.high)
|
||||
return dist, action_space.shape[0] * 2
|
||||
elif dist_type == 'deterministic':
|
||||
return Deterministic, action_space.shape[0]
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
|
|
|
@ -16,7 +16,8 @@ class PGPolicyGraph(TFPolicyGraph):
|
|||
|
||||
# setup policy
|
||||
self.x = tf.placeholder(tf.float32, shape=[None]+list(obs_space.shape))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(action_space)
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self.model = ModelCatalog.get_model(
|
||||
registry, self.x, self.logit_dim, options=self.config["model"])
|
||||
self.dist = dist_class(self.model.outputs) # logit for each action
|
||||
|
|
|
@ -54,7 +54,7 @@ class PPOEvaluator(TFMultiGPUSupport):
|
|||
action_space = self.env.action_space
|
||||
self.actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
self.distribution_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space)
|
||||
action_space, config["model"])
|
||||
# Log probabilities from the policy before the policy update.
|
||||
self.prev_logits = tf.placeholder(
|
||||
tf.float32, shape=(None, self.logit_dim))
|
||||
|
|
|
@ -12,3 +12,4 @@ pendulum-ppo:
|
|||
num_sgd_iter: 10
|
||||
model:
|
||||
fcnet_hiddens: [64, 64]
|
||||
squash_to_range: True
|
||||
|
|
Loading…
Add table
Reference in a new issue