[rllib] Document ModelV2 and clean up the models/ directory (#5277)

This commit is contained in:
Eric Liang 2019-07-27 02:08:16 -07:00 committed by GitHub
parent 9c00616cdc
commit a62c5f40f6
60 changed files with 1107 additions and 950 deletions

View file

@ -325,6 +325,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_keras_model.py --run=DQN --stop=50
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_keras_rnn_model.py --run=PPO --stop=50
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=PG --stop=50

View file

@ -101,10 +101,6 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
**APPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. warning::
Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
.. literalinclude:: ../../python/ray/rllib/agents/ppo/appo.py
:language: python
:start-after: __sphinx_doc_begin__

View file

@ -1,5 +1,5 @@
RLlib Concepts and Building Custom Algorithms
=============================================
RLlib Concepts and Custom Algorithms
====================================
This page describes the internal concepts used to implement algorithms in RLlib. You might find this useful if modifying or adding new algorithms to RLlib.

View file

@ -30,6 +30,10 @@ Custom Envs and Models
- `Registering a custom env and model <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__:
Example of defining and registering a gym env and model for use with RLlib.
- `Custom Keras model <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_keras_model.py>`__:
Example of using a custom Keras model.
- `Custom Keras RNN model <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_keras_rnn_model.py>`__:
Example of using a custom Keras RNN model.
- `Registering a custom model with supervised loss <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_loss.py>`__:
Example of defining and registering a custom model with a supervised loss.
- `Subprocess environment <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tests/test_env_with_subprocess.py>`__:

View file

@ -14,9 +14,9 @@ Default Behaviours
Built-in Models and Preprocessors
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RLlib picks default models based on a simple heuristic: a `vision network <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/visionnet.py>`__ for image observations, and a `fully connected network <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/fcnet.py>`__ for everything else. These models can be configured via the ``model`` config key, documented in the model `catalog <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/catalog.py>`__. Note that you'll probably have to configure ``conv_filters`` if your environment observations have custom sizes, e.g., ``"model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]}`` for 42x42 observations.
RLlib picks default models based on a simple heuristic: a `vision network <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/tf/visionnet_v1.py>`__ for image observations, and a `fully connected network <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/tf/fcnet_v1.py>`__ for everything else. These models can be configured via the ``model`` config key, documented in the model `catalog <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/catalog.py>`__. Note that you'll probably have to configure ``conv_filters`` if your environment observations have custom sizes, e.g., ``"model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]}`` for 42x42 observations.
In addition, if you set ``"model": {"use_lstm": true}``, then the model output will be further processed by a `LSTM cell <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/lstm.py>`__. More generally, RLlib supports the use of recurrent models for its policy gradient algorithms (A3C, PPO, PG, IMPALA), and RNN support is built into its policy evaluation utilities.
In addition, if you set ``"model": {"use_lstm": true}``, then the model output will be further processed by a `LSTM cell <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/tf/lstm.py>`__. More generally, RLlib supports the use of recurrent models for its policy gradient algorithms (A3C, PPO, PG, IMPALA), and RNN support is built into its policy evaluation utilities.
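For reference, a minimal sketch of passing these model options to a trainer (the env name and stopping criterion are illustrative; the filter values are the ones quoted above for 42x42 observations):

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "PPO",
        stop={"timesteps_total": 100000},
        config={
            "env": "BreakoutNoFrameskip-v4",  # any image-observation env
            "model": {
                # Downscale observations to 42x42 and use matching conv filters.
                "dim": 42,
                "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2],
                                 [512, [11, 11], 1]],
                # Optionally post-process the CNN output with an LSTM cell.
                "use_lstm": True,
            },
        })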
For preprocessors, RLlib tries to pick one of its built-in preprocessors based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations are downscaled, and Tuple and Dict observations are flattened (these are unflattened and accessible via the ``input_dict`` parameter in custom models). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/atari_wrappers.py>`__, which are also used by the OpenAI baselines library.
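As a sketch of the un-flattening behaviour, a custom model's ``forward()`` can index into the nested observation directly (this fragment assumes a ``TFModelV2`` subclass, ``tf`` imported via ``try_import_tf()``, and hypothetical ``sensors``/``position``/``velocity`` keys coming from the env's Dict space):

.. code-block:: python

    def forward(self, input_dict, state, seq_lens):
        # Dict/Tuple spaces arrive unflattened here; the fully flattened
        # tensor is still available under input_dict["obs_flat"].
        pos = input_dict["obs"]["sensors"]["position"]  # [BATCH, 3]
        vel = input_dict["obs"]["sensors"]["velocity"]  # [BATCH, 3]
        features = tf.concat([pos, vel], axis=1)
        logits = tf.layers.dense(features, self.num_outputs)
        return logits, state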
@ -30,103 +30,40 @@ The following is a list of the built-in model hyperparameters:
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
Custom Models (TensorFlow)
--------------------------
TensorFlow Models
-----------------
Custom TF models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, ``prev_reward``, and ``is_training``), and returns a feature layer and a float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. Additional supervised / self-supervised losses can be added via the ``custom_loss`` method. The model can then be registered and used in place of a built-in model:
.. note::
.. warning::
TFModelV2 replaces the previous ``rllib.models.Model`` class, which did not support Keras-style reuse of variables. The ``rllib.models.Model`` class is deprecated and should not be used.
Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
Custom TF models should subclass `TFModelV2 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/tf/tf_modelv2.py>`__ and implement the ``__init__()`` and ``forward()`` methods. ``forward()`` takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, ``prev_reward``, and ``is_training``), an optional RNN state, and returns the model output of size ``num_outputs`` along with the new state. You can also override extra methods of the model such as ``value_function`` to implement a custom value branch. Additional supervised / self-supervised losses can be added via the ``custom_loss`` method:
.. autoclass:: ray.rllib.models.tf.tf_modelv2.TFModelV2
.. automethod:: __init__
.. automethod:: forward
.. automethod:: value_function
.. automethod:: custom_loss
.. automethod:: metrics
.. automethod:: update_ops
.. automethod:: register_variables
.. automethod:: variables
.. automethod:: trainable_variables
Once implemented, the model can then be registered and used in place of a built-in model:
.. code-block:: python
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog, Model
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
class MyModelClass(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
"""Define the layers of a custom model.
Arguments:
input_dict (dict): Dictionary of input tensors, including "obs",
"prev_action", "prev_reward", "is_training".
num_outputs (int): Output tensor must be of size
[BATCH_SIZE, num_outputs].
options (dict): Model options.
Returns:
(outputs, feature_layer): Tensors of size [BATCH_SIZE, num_outputs]
and [BATCH_SIZE, desired_feature_size].
When using dict or tuple observation spaces, you can access
the nested sub-observation batches here as well:
Examples:
>>> print(input_dict)
{'prev_actions': <tf.Tensor shape=(?,) dtype=int64>,
'prev_rewards': <tf.Tensor shape=(?,) dtype=float32>,
'is_training': <tf.Tensor shape=(), dtype=bool>,
'obs': OrderedDict([
('sensors', OrderedDict([
('front_cam', [
<tf.Tensor shape=(?, 10, 10, 3) dtype=float32>,
<tf.Tensor shape=(?, 10, 10, 3) dtype=float32>]),
('position', <tf.Tensor shape=(?, 3) dtype=float32>),
('velocity', <tf.Tensor shape=(?, 3) dtype=float32>)]))])}
"""
layer1 = slim.fully_connected(input_dict["obs"], 64, ...)
layer2 = slim.fully_connected(layer1, 64, ...)
...
return layerN, layerN_minus_1
def value_function(self):
"""Builds the value function output.
This method can be overridden to customize the implementation of the
value function (e.g., not sharing hidden layers).
Returns:
Tensor of size [BATCH_SIZE] for the value function.
"""
return tf.reshape(
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
def custom_loss(self, policy_loss, loss_inputs):
"""Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining
a loss over existing input and output tensors of this model), and
supervised losses (by defining losses over a variable-sharing copy of
this model's layers).
You can find a runnable example in examples/custom_loss.py.
Arguments:
policy_loss (Tensor): scalar policy loss from the policy.
loss_inputs (dict): map of input placeholders for rollout data.
Returns:
Scalar tensor for the customized loss for this model.
"""
return policy_loss
def custom_stats(self):
"""Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e.,
info:
learner:
model:
key1: metric1
key2: metric2
Returns:
Dict of string keys to scalar tensors.
"""
return {}
class MyModelClass(TFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
def forward(self, input_dict, state, seq_lens): ...
def value_function(self): ...
ModelCatalog.register_custom_model("my_model", MyModelClass)
@ -138,97 +75,55 @@ Custom TF models should subclass the common RLlib `model class <https://github.c
},
})
For a full example of a custom model in code, see the `custom env example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__. You can also reference the `unit tests <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tests/test_nested_spaces.py>`__ for Tuple and Dict spaces, which show how to access nested observation fields.
For a full example of a custom model in code, see the `keras model example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_keras_model.py>`__. You can also reference the `unit tests <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tests/test_nested_spaces.py>`__ for Tuple and Dict spaces, which show how to access nested observation fields.
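To make the snippet above concrete, here is a minimal runnable sketch of such a model, patterned on the referenced Keras model example (it assumes a flat ``Box`` observation space; the layer sizes, model name, and stopping criterion are illustrative):

.. code-block:: python

    import ray
    from ray import tune
    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.tf.tf_modelv2 import TFModelV2
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()


    class MyTFModel(TFModelV2):
        """Toy model: one hidden layer shared by the policy and value heads."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super(MyTFModel, self).__init__(obs_space, action_space,
                                            num_outputs, model_config, name)
            inputs = tf.keras.layers.Input(shape=obs_space.shape)
            hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
            logits = tf.keras.layers.Dense(num_outputs)(hidden)
            value = tf.keras.layers.Dense(1)(hidden)
            self.base_model = tf.keras.Model(inputs, [logits, value])
            self.register_variables(self.base_model.variables)

        def forward(self, input_dict, state, seq_lens):
            logits, self._value_out = self.base_model(input_dict["obs"])
            return logits, state

        def value_function(self):
            return tf.reshape(self._value_out, [-1])


    ModelCatalog.register_custom_model("my_tf_model", MyTFModel)
    ray.init()
    tune.run(
        "PPO",
        stop={"episode_reward_mean": 150},
        config={
            "env": "CartPole-v0",
            "model": {"custom_model": "my_tf_model"},
            # Use the model's value head (see the custom_env example).
            "vf_share_layers": True,
        })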
Custom Recurrent Models
~~~~~~~~~~~~~~~~~~~~~~~
Recurrent Models
~~~~~~~~~~~~~~~~
Instead of using the ``use_lstm: True`` option, it can be preferable to use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. The only difference from a normal custom model is that you have to define ``self.state_init``, ``self.state_in``, and ``self.state_out``. You can refer to the existing `lstm.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/lstm.py>`__ model as an example to implement your own model:
Instead of using the ``use_lstm: True`` option, it can be preferable to use a custom recurrent model. This provides more control over postprocessing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For an RNN model, it is preferred to subclass ``RecurrentTFModelV2`` and implement ``__init__()``, ``get_initial_state()``, and ``forward_rnn()``. You can check out the `custom_keras_rnn_model.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_keras_rnn_model.py>`__ example, or the sketch after the class reference below, when implementing your own model:
.. code-block:: python
.. autoclass:: ray.rllib.models.tf.recurrent_tf_modelv2.RecurrentTFModelV2
class MyCustomLSTM(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
# Some initial layers to process inputs, shape [BATCH, OBS...].
features = some_hidden_layers(input_dict["obs"])
# Add back the nested time dimension for tf.dynamic_rnn, new shape
# will be [BATCH, MAX_SEQ_LEN, OBS...].
last_layer = add_time_dimension(features, self.seq_lens)
# Setup the LSTM cell (see lstm.py for an example)
lstm = rnn.BasicLSTMCell(256, state_is_tuple=True)
self.state_init = ...
self.state_in = ...
lstm_out, lstm_state = tf.nn.dynamic_rnn(
lstm,
last_layer,
initial_state=...,
sequence_length=self.seq_lens,
time_major=False,
dtype=tf.float32)
self.state_out = list(lstm_state)
# Drop the time dimension again so back to shape [BATCH, OBS...].
# Note that we retain the zero padding (see issue #2992).
last_layer = tf.reshape(lstm_out, [-1, cell_size])
logits = linear(last_layer, num_outputs, "action",
normc_initializer(0.01))
return logits, last_layer
.. automethod:: __init__
.. automethod:: forward_rnn
.. automethod:: get_initial_state
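A minimal sketch of such a recurrent model, patterned on the referenced example (it assumes a flat ``Box`` observation space; the cell and layer sizes are illustrative):

.. code-block:: python

    import numpy as np

    from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()


    class MyRNNModel(RecurrentTFModelV2):
        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name, cell_size=64):
            super(MyRNNModel, self).__init__(obs_space, action_space,
                                             num_outputs, model_config, name)
            self.cell_size = cell_size

            # Inputs are [BATCH, MAX_SEQ_LEN, OBS_SIZE] inside forward_rnn().
            inputs = tf.keras.layers.Input(
                shape=(None, obs_space.shape[0]), name="inputs")
            state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
            state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
            seq_in = tf.keras.layers.Input(
                shape=(), name="seq_in", dtype=tf.int32)

            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                cell_size, return_sequences=True, return_state=True)(
                    inputs=inputs,
                    mask=tf.sequence_mask(seq_in),
                    initial_state=[state_in_h, state_in_c])
            logits = tf.keras.layers.Dense(self.num_outputs)(lstm_out)
            values = tf.keras.layers.Dense(1)(lstm_out)

            self.rnn_model = tf.keras.Model(
                inputs=[inputs, seq_in, state_in_h, state_in_c],
                outputs=[logits, values, state_h, state_c])
            self.register_variables(self.rnn_model.variables)

        def forward_rnn(self, inputs, state, seq_lens):
            model_out, self._value_out, h, c = self.rnn_model(
                [inputs, seq_lens] + state)
            return model_out, [h, c]

        def get_initial_state(self):
            return [
                np.zeros(self.cell_size, np.float32),
                np.zeros(self.cell_size, np.float32),
            ]

        def value_function(self):
            return tf.reshape(self._value_out, [-1])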
Batch Normalization
~~~~~~~~~~~~~~~~~~~
You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/batch_norm_model.py>`__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/policy/tf_policy.py>`__ and `multi_gpu_impl.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/optimizers/multi_gpu_impl.py>`__ for the exact handling of these updates).
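For illustration, the relevant fragment from such a model, in the style of the referenced ``batch_norm_model.py`` example (layer sizes are arbitrary), looks roughly like:

.. code-block:: python

    def _build_layers_v2(self, input_dict, num_outputs, options):
        last_layer = tf.layers.dense(
            input_dict["obs"], 256, activation=tf.nn.relu, name="fc1")
        # `training` switches batch norm between batch statistics (training)
        # and moving averages (inference); RLlib runs the resulting update
        # ops during optimization.
        last_layer = tf.layers.batch_normalization(
            last_layer, training=input_dict["is_training"], name="bn1")
        output = tf.layers.dense(
            last_layer, num_outputs, activation=None, name="out")
        return output, last_layer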
Custom Models (PyTorch)
-----------------------
In case RLlib does not properly detect the update ops for your custom model, you can override the ``update_ops()`` method to return the list of ops to run for updates.
PyTorch Models
--------------
Similarly, you can create and register custom PyTorch models for use with PyTorch-based algorithms (e.g., A2C, PG, QMIX). See these examples of `fully connected <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/torch/fcnet.py>`__, `convolutional <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/torch/visionnet.py>`__, and `recurrent <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/model.py>`__ torch models.
.. autoclass:: ray.rllib.models.torch.torch_modelv2.TorchModelV2
.. automethod:: __init__
.. automethod:: forward
.. automethod:: value_function
.. automethod:: custom_loss
.. automethod:: metrics
.. automethod:: get_initial_state
Once implemented, the model can then be registered and used in place of a built-in model:
.. code-block:: python
import torch.nn as nn
import ray
from ray.rllib.agents import a3c
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
class CustomTorchModel(TorchModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(CustomTorchModel, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
... # setup hidden layers
def forward(self, input_dict, state, seq_lens):
"""Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by
__call__ before being passed to forward(). To access the flattened
observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution,
each call to forward() will eagerly evaluate the model. In symbolic
execution, each call to forward creates a computation graph that
operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Arguments:
input_dict (dict): dictionary of input tensors, including "obs",
"obs_flat", "prev_action", "prev_reward", "is_training"
state (list): list of state tensors with sizes matching those
returned by get_initial_state + the batch dimension
seq_lens (Tensor): 1d tensor holding input sequence lengths
Returns:
(outputs, state): The model output tensor of size
[BATCH, num_outputs]
"""
obs = input_dict["obs"]
...
return logits, state
class CustomTorchModel(nn.Module, TorchModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
def forward(self, input_dict, state, seq_lens): ...
def value_function(self): ...
ModelCatalog.register_custom_model("my_model", CustomTorchModel)
@ -274,7 +169,7 @@ Supervised Model Losses
You can mix supervised losses into any RLlib algorithm through custom models. For example, you can add an imitation learning loss on expert experiences, or a self-supervised autoencoder loss within the model. These losses can be defined over either policy evaluation inputs, or data read from `offline storage <rllib-offline.html#input-pipeline-for-supervised-losses>`__.
**TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``custom_metrics()`` method. Here is a `runnable example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_loss.py>`__ of adding an imitation loss to CartPole training that is defined over an `offline dataset <rllib-offline.html#input-pipeline-for-supervised-losses>`__.
**TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``metrics()`` method. Here is a `runnable example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_loss.py>`__ of adding an imitation loss to CartPole training that is defined over an `offline dataset <rllib-offline.html#input-pipeline-for-supervised-losses>`__.
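As an illustrative, hypothetical fragment from a ``TFModelV2`` subclass (it assumes ``tf`` was imported via ``try_import_tf()`` and that ``forward()`` stored its logits as ``self._last_logits``), a simple self-supervised regularizer could be mixed in like this; the runnable example linked above shows the full imitation-loss pattern with an offline reader:

.. code-block:: python

    def custom_loss(self, policy_loss, loss_inputs):
        # Hypothetical auxiliary term: an L2 penalty on the logits produced
        # during the rollout forward pass. An imitation loss would instead
        # compare model outputs against expert actions read from an offline
        # dataset (see examples/custom_loss.py).
        self.aux_loss = 0.01 * tf.reduce_mean(tf.square(self._last_logits))
        return policy_loss + self.aux_loss

    def metrics(self):
        # Reported as part of the learner stats, e.g. info/learner/model/.
        return {"aux_loss": self.aux_loss}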
**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass.
@ -301,79 +196,43 @@ Custom models can be used to work with environments where (1) the set of valid a
.. code-block:: python
class MyParamActionModel(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
class ParametricActionsModel(TFModelV2):
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
true_obs_shape=(4,),
action_embed_size=2):
super(ParametricActionsModel, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
self.action_embed_model = FullyConnectedNetwork(...)
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
avail_actions = input_dict["obs"]["avail_actions"]
action_mask = input_dict["obs"]["action_mask"]
output = FullyConnectedNetwork(
input_dict["obs"]["real_obs"], num_outputs=action_embedding_sz)
# Compute the predicted action embedding
action_embed, _ = self.action_embed_model({
"obs": input_dict["obs"]["cart"]
})
# Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
# avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
intent_vector = tf.expand_dims(output, 1)
intent_vector = tf.expand_dims(action_embed, 1)
# Shape of logits is [BATCH, MAX_ACTIONS].
# Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)
# Mask out invalid actions (use tf.float32.min for stability)
inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
masked_logits = inf_mask + action_logits
return masked_logits, last_layer
return action_logits + inf_mask, state
Depending on your use case, it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms.
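For reference, a sketch of the per-algorithm settings described above (the config values are taken from the cartpole example; the stopping criterion is illustrative, and the env/model names assume the registrations done in that example):

.. code-block:: python

    from ray import tune

    run = "DQN"  # or "PPO", "PG", ...

    if run == "DQN":
        cfg = {
            # Don't postprocess the action scores with extra hidden layers,
            # and disable dueling, so the masked logits pass through cleanly.
            "hiddens": [],
            "dueling": False,
        }
    elif run == "PPO":
        cfg = {
            "observation_filter": "NoFilter",  # don't filter the action list
            "vf_share_layers": True,  # don't create a duplicate value model
        }
    else:
        cfg = {}  # PG, IMPALA, A2C, etc.

    # Assumes "pa_cartpole" and "pa_model" were registered as in the example.
    tune.run(
        run,
        stop={"episode_reward_mean": 150},
        config=dict(
            {
                "env": "pa_cartpole",
                "model": {"custom_model": "pa_model"},
            }, **cfg))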
Customizing Policies
-------------------------
For deeper customization of algorithms, you can modify the policies of the trainer classes. Here's an example of extending the DDPG policy to specify custom sub-network modules:
.. code-block:: python
from ray.rllib.models import ModelCatalog
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy as BaseDDPGTFPolicy
class CustomPNetwork(object):
def __init__(self, dim_actions, hiddens, activation):
action_out = ...
# Use sigmoid layer to bound values within (0, 1)
# shape of action_scores is [batch_size, dim_actions]
self.action_scores = layers.fully_connected(
action_out, num_outputs=dim_actions, activation_fn=tf.nn.sigmoid)
class CustomQNetwork(object):
def __init__(self, action_inputs, hiddens, activation):
q_out = ...
self.value = layers.fully_connected(
q_out, num_outputs=1, activation_fn=None)
class CustomDDPGTFPolicy(BaseDDPGTFPolicy):
def _build_p_network(self, obs):
return CustomPNetwork(
self.dim_actions,
self.config["actor_hiddens"],
self.config["actor_hidden_activation"]).action_scores
def _build_q_network(self, obs, actions):
return CustomQNetwork(
actions,
self.config["critic_hiddens"],
self.config["critic_hidden_activation"]).value
Then, you can create a trainer with your custom policy by:
.. code-block:: python
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer
from custom_policy import CustomDDPGTFPolicy
DDPGTrainer._policy = CustomDDPGTFPolicy
trainer = DDPGTrainer(...)
In this example we overrode existing methods of the DDPG policy (`_build_q_network` and `_build_p_network`); other methods such as `_build_action_network` and `_build_actor_critic_loss` can be overridden in the same way, or you can replace the policy class entirely.
Model-Based Rollouts
~~~~~~~~~~~~~~~~~~~~

View file

@ -44,12 +44,11 @@ Environments
Models and Preprocessors
------------------------
* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
* `Custom Models (TensorFlow) <rllib-models.html#custom-models-tensorflow>`__
* `Custom Models (PyTorch) <rllib-models.html#custom-models-pytorch>`__
* `TensorFlow Models <rllib-models.html#tensorflow-models>`__
* `PyTorch Models <rllib-models.html#pytorch-models>`__
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Customizing Policies <rllib-models.html#customizing-policies>`__
Algorithms
----------
@ -95,8 +94,8 @@ Offline Datasets
* `Input API <rllib-offline.html#input-api>`__
* `Output API <rllib-offline.html#output-api>`__
Concepts and Building Custom Algorithms
---------------------------------------
Concepts and Custom Algorithms
------------------------------
* `Policies <rllib-concepts.html>`__
- `Building Policies in TensorFlow <rllib-concepts.html#building-policies-in-tensorflow>`__

View file

@ -11,7 +11,8 @@ from ray.rllib.agents.dqn.distributional_q_model import DistributionalQModel
from ray.rllib.agents.dqn.simple_q_policy import ExplorationStateMixin, \
TargetNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy

View file

@ -117,7 +117,6 @@ def make_aggregators_and_optimizer(workers, config):
optimizer = AsyncSamplesOptimizer(
workers,
lr=config["lr"],
num_envs_per_worker=config["num_envs_per_worker"],
num_gpus=config["num_gpus"],
sample_batch_size=config["sample_batch_size"],
train_batch_size=config["train_batch_size"],

View file

@ -34,7 +34,7 @@ from __future__ import print_function
import collections
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.utils import try_import_tf
tf = try_import_tf()

View file

@ -12,7 +12,7 @@ import gym
import ray
from ray.rllib.agents.impala import vtrace
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule, \

View file

@ -14,7 +14,7 @@ from ray.rllib.agents.impala import vtrace
from ray.rllib.agents.impala.vtrace_policy import _make_time_major, \
BEHAVIOUR_LOGITS, VTraceTFPolicy
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.utils import try_import_tf

View file

@ -59,9 +59,6 @@ DEFAULT_CONFIG = with_common_config({
# Uses the sync samples optimizer instead of the multi-gpu one. This does
# not support minibatches.
"simple_optimizer": False,
# (Deprecated) Use the sampling behavior as of 0.6, which launches extra
# sampling tasks for performance but can waste a large portion of samples.
"straggler_mitigation": False,
})
# __sphinx_doc_end__
# yapf: enable
@ -83,7 +80,6 @@ def choose_policy_optimizer(workers, config):
num_envs_per_worker=config["num_envs_per_worker"],
train_batch_size=config["train_batch_size"],
standardize_fields=["advantages"],
straggler_mitigation=config["straggler_mitigation"],
shuffle_sequences=config["shuffle_sequences"])

View file

@ -6,7 +6,7 @@ import unittest
import numpy as np
from numpy.testing import assert_allclose
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.agents.ppo.utils import flatten, concatenate
from ray.rllib.utils import try_import_tf

View file

@ -10,13 +10,14 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
class RNNModel(TorchModelV2):
class RNNModel(TorchModelV2, nn.Module):
"""The default RNN model for QMIX."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
self.obs_size = _get_size(obs_space)
self.rnn_hidden_dim = model_config["lstm_cell_size"]
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)

View file

@ -14,11 +14,10 @@ import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy import Policy, TupleActions
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.models.model import _unpack_obs
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.annotations import override

View file

@ -143,7 +143,8 @@ COMMON_CONFIG = {
"train_batch_size": 200,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "truncate_episodes",
# (Deprecated) Use a background thread for sampling (slightly off-policy)
# Use a background thread for sampling (slightly off-policy, usually not
# advisable to turn on unless your env specifically requires it)
"sample_async": False,
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
"observation_filter": "NoFilter",

View file

@ -13,10 +13,10 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.policy.policy import TupleActions
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import log_once, summarize

View file

@ -8,7 +8,7 @@ import argparse
import ray
from ray import tune
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.utils import try_import_tf
tf = try_import_tf()

View file

@ -14,13 +14,18 @@ from __future__ import print_function
import numpy as np
import gym
from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from gym.spaces import Discrete, Box
import ray
from ray import tune
from ray.rllib.utils import try_import_tf
from ray.tune import grid_search
tf = try_import_tf()
class SimpleCorridor(gym.Env):
"""Example of a custom env in which you have to walk down a corridor.
@ -48,18 +53,22 @@ class SimpleCorridor(gym.Env):
return [self.cur_pos], 1 if done else 0, done, {}
class CustomModel(Model):
"""Example of a custom model.
class CustomModel(TFModelV2):
"""Example of a custom model that just delegates to a fc-net."""
This model just delegates to the built-in fcnet.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(CustomModel, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
self.model = FullyConnectedNetwork(obs_space, action_space,
num_outputs, model_config, name)
self.register_variables(self.model.variables())
def _build_layers_v2(self, input_dict, num_outputs, options):
self.obs_in = input_dict["obs"]
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
self.action_space, num_outputs,
options)
return self.fcnet.outputs, self.fcnet.last_layer
def forward(self, input_dict, state, seq_lens):
return self.model.forward(input_dict, state, seq_lens)
def value_function(self):
return self.model.value_function()
if __name__ == "__main__":
@ -77,6 +86,7 @@ if __name__ == "__main__":
"model": {
"custom_model": "my_model",
},
"vf_share_layers": True,
"lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs
"num_workers": 1, # parallelism
"env_config": {

View file

@ -9,7 +9,7 @@ import argparse
import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.agents.dqn.distributional_q_model import DistributionalQModel
from ray.rllib.utils import try_import_tf
@ -17,7 +17,7 @@ from ray.rllib.utils import try_import_tf
tf = try_import_tf()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="SimpleQ") # Try PG, PPO, DQN
parser.add_argument("--run", type=str, default="DQN") # Try PG, PPO, DQN
parser.add_argument("--stop", type=int, default=200)
@ -49,7 +49,6 @@ class MyKerasModel(TFModelV2):
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
self.prev_input = input_dict
model_out, self._value_out = self.base_model(input_dict["obs"])
return model_out, state
@ -84,7 +83,6 @@ class MyKerasQModel(DistributionalQModel):
# Implement the core forward method
def forward(self, input_dict, state, seq_lens):
self.prev_input = input_dict
model_out = self.base_model(input_dict["obs"])
return model_out, state

View file

@ -18,8 +18,9 @@ import os
import ray
from ray import tune
from ray.rllib.models import (Categorical, FullyConnectedNetwork, Model,
ModelCatalog)
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.offline import JsonReader
from ray.rllib.utils import try_import_tf

View file

@ -8,7 +8,8 @@ import random
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

View file

@ -26,8 +26,10 @@ from gym.spaces import Box, Discrete, Dict
import ray
from ray import tune
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.rllib.agents.dqn.distributional_q_model import DistributionalQModel
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf
@ -111,7 +113,7 @@ class ParametricActionCartpole(gym.Env):
return obs, rew, done, info
class ParametricActionsModel(Model):
class ParametricActionsModel(DistributionalQModel, TFModelV2):
"""Parametric action model that handles the dot product and masking.
This assumes the outputs are logits for a single Categorical action dist.
@ -120,46 +122,45 @@ class ParametricActionsModel(Model):
exercise to the reader.
"""
def _build_layers_v2(self, input_dict, num_outputs, options):
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
true_obs_shape=(4, ),
action_embed_size=2,
**kw):
super(ParametricActionsModel, self).__init__(
obs_space, action_space, num_outputs, model_config, name, **kw)
self.action_embed_model = FullyConnectedNetwork(
Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
model_config, name + "_action_embed")
self.register_variables(self.action_embed_model.variables())
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
avail_actions = input_dict["obs"]["avail_actions"]
action_mask = input_dict["obs"]["action_mask"]
action_embed_size = avail_actions.shape[2].value
if num_outputs != avail_actions.shape[1].value:
raise ValueError(
"This model assumes num outputs is equal to max avail actions",
num_outputs, avail_actions)
# Standard FC net component.
last_layer = input_dict["obs"]["cart"]
hiddens = [256, 256]
for i, size in enumerate(hiddens):
label = "fc{}".format(i)
last_layer = tf.layers.dense(
last_layer,
size,
kernel_initializer=normc_initializer(1.0),
activation=tf.nn.tanh,
name=label)
output = tf.layers.dense(
last_layer,
action_embed_size,
kernel_initializer=normc_initializer(0.01),
activation=None,
name="fc_out")
# Compute the predicted action embedding
action_embed, _ = self.action_embed_model({
"obs": input_dict["obs"]["cart"]
})
# Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
# avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
intent_vector = tf.expand_dims(output, 1)
intent_vector = tf.expand_dims(action_embed, 1)
# Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)
# Mask out invalid actions (use tf.float32.min for stability)
inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
masked_logits = inf_mask + action_logits
return action_logits + inf_mask, state
return masked_logits, last_layer
def value_function(self):
return self.action_embed_model.value_function()
if __name__ == "__main__":
@ -168,22 +169,17 @@ if __name__ == "__main__":
ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
register_env("pa_cartpole", lambda _: ParametricActionCartpole(10))
if args.run == "PPO":
if args.run == "DQN":
cfg = {
"observation_filter": "NoFilter", # don't filter the action list
"vf_share_layers": True, # don't create duplicate value model
}
elif args.run in ["SimpleQ", "DQN"]:
cfg = {
"hiddens": [], # important: don't postprocess the action scores
# TODO(ekl) we could support dueling if the model in this example
# was ModelV2 and only emitted -inf values on get_q_values().
# The problem with ModelV1 is that the model outputs
# are used as state scores and hence cause blowup to inf.
# TODO(ekl) we need to set these to prevent the masked values
# from being further processed in DistributionalQModel, which
# would mess up the masking. It is possible to support these if we
# defined a custom DistributionalQModel that is aware of masking.
"hiddens": [],
"dueling": False,
}
else:
cfg = {} # PG, IMPALA, A2C, etc.
cfg = {}
tune.run(
args.run,
stop={

View file

@ -1,23 +1,12 @@
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.models.extra_spaces import Simplex
from ray.rllib.models.action_dist import (
ActionDistribution, Categorical, DiagGaussian, Deterministic, Dirichlet)
from ray.rllib.models.model import Model
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.lstm import LSTM
__all__ = [
"ActionDistribution",
"Categorical",
"DiagGaussian",
"Deterministic",
"Dirichlet",
"ModelCatalog",
"Model",
"Preprocessor",
"FullyConnectedNetwork",
"LSTM",
"MODEL_DEFAULTS",
"Simplex",
]

View file

@ -2,24 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import distutils.version
import numpy as np
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
if tf:
if hasattr(tf, "__version__"):
version = tf.__version__
else:
version = tf.VERSION
use_tf150_api = (distutils.version.LooseVersion(version) >=
distutils.version.LooseVersion("1.5.0"))
else:
use_tf150_api = False
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
@ -33,7 +16,11 @@ class ActionDistribution(object):
@DeveloperAPI
def __init__(self, inputs):
self.inputs = inputs
self.sample_op = self._build_sample_op()
@DeveloperAPI
def sample(self):
"""Draw a sample from the action distribution."""
raise NotImplementedError
@DeveloperAPI
def logp(self, x):
@ -50,25 +37,6 @@ class ActionDistribution(object):
"""The entropy of the action distribution."""
raise NotImplementedError
@DeveloperAPI
def _build_sample_op(self):
"""Implement this instead of sample(), to enable op reuse.
This is needed since the sample op is non-deterministic and is shared
between sample() and sampled_action_prob().
"""
raise NotImplementedError
@DeveloperAPI
def sample(self):
"""Draw a sample from the action distribution."""
return self.sample_op
@DeveloperAPI
def sampled_action_prob(self):
"""Returns the log probability of the sampled action."""
return tf.exp(self.logp(self.sample_op))
def multi_kl(self, other):
"""The KL-divergence between two action distributions.
@ -84,262 +52,3 @@ class ActionDistribution(object):
MultiDiscrete. TODO(ekl) consider removing this.
"""
return self.entropy()
class Categorical(ActionDistribution):
"""Categorical distribution for discrete action spaces."""
@override(ActionDistribution)
def logp(self, x):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=self.inputs, labels=tf.cast(x, tf.int32))
@override(ActionDistribution)
def entropy(self):
if use_tf150_api:
a0 = self.inputs - tf.reduce_max(
self.inputs, reduction_indices=[1], keepdims=True)
else:
a0 = self.inputs - tf.reduce_max(
self.inputs, reduction_indices=[1], keep_dims=True)
ea0 = tf.exp(a0)
if use_tf150_api:
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keepdims=True)
else:
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])
@override(ActionDistribution)
def kl(self, other):
if use_tf150_api:
a0 = self.inputs - tf.reduce_max(
self.inputs, reduction_indices=[1], keepdims=True)
a1 = other.inputs - tf.reduce_max(
other.inputs, reduction_indices=[1], keepdims=True)
else:
a0 = self.inputs - tf.reduce_max(
self.inputs, reduction_indices=[1], keep_dims=True)
a1 = other.inputs - tf.reduce_max(
other.inputs, reduction_indices=[1], keep_dims=True)
ea0 = tf.exp(a0)
ea1 = tf.exp(a1)
if use_tf150_api:
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keepdims=True)
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keepdims=True)
else:
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(
p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1])
@override(ActionDistribution)
def _build_sample_op(self):
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
class MultiCategorical(ActionDistribution):
"""Categorical distribution for discrete action spaces."""
def __init__(self, inputs, input_lens):
self.cats = [
Categorical(input_)
for input_ in tf.split(inputs, input_lens, axis=1)
]
self.sample_op = self._build_sample_op()
@override(ActionDistribution)
def logp(self, actions):
# If tensor is provided, unstack it into list
if isinstance(actions, tf.Tensor):
actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
logps = tf.stack(
[cat.logp(act) for cat, act in zip(self.cats, actions)])
return tf.reduce_sum(logps, axis=0)
@override(ActionDistribution)
def multi_entropy(self):
return tf.stack([cat.entropy() for cat in self.cats], axis=1)
@override(ActionDistribution)
def entropy(self):
return tf.reduce_sum(self.multi_entropy(), axis=1)
@override(ActionDistribution)
def multi_kl(self, other):
return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)]
@override(ActionDistribution)
def kl(self, other):
return tf.reduce_sum(self.multi_kl(other), axis=1)
@override(ActionDistribution)
def _build_sample_op(self):
return tf.stack([cat.sample() for cat in self.cats], axis=1)
class DiagGaussian(ActionDistribution):
"""Action distribution where each vector element is a gaussian.
The first half of the input vector defines the gaussian means, and the
second half the gaussian standard deviations.
"""
def __init__(self, inputs):
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.log_std = log_std
self.std = tf.exp(log_std)
ActionDistribution.__init__(self, inputs)
@override(ActionDistribution)
def logp(self, x):
return (-0.5 * tf.reduce_sum(
tf.square((x - self.mean) / self.std), reduction_indices=[1]) -
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
tf.reduce_sum(self.log_std, reduction_indices=[1]))
@override(ActionDistribution)
def kl(self, other):
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(
other.log_std - self.log_std +
(tf.square(self.std) + tf.square(self.mean - other.mean)) /
(2.0 * tf.square(other.std)) - 0.5,
reduction_indices=[1])
@override(ActionDistribution)
def entropy(self):
return tf.reduce_sum(
.5 * self.log_std + .5 * np.log(2.0 * np.pi * np.e),
reduction_indices=[1])
@override(ActionDistribution)
def _build_sample_op(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
class Deterministic(ActionDistribution):
"""Action distribution that returns the input values directly.
This is similar to DiagGaussian with standard deviation zero.
"""
@override(ActionDistribution)
def sampled_action_prob(self):
return 1.0
@override(ActionDistribution)
def _build_sample_op(self):
return self.inputs
class MultiActionDistribution(ActionDistribution):
"""Action distribution that operates for list of actions.
Args:
inputs (Tensor list): A list of tensors from which to compute samples.
"""
def __init__(self, inputs, action_space, child_distributions, input_lens):
self.input_lens = input_lens
split_inputs = tf.split(inputs, self.input_lens, axis=1)
child_list = []
for i, distribution in enumerate(child_distributions):
child_list.append(distribution(split_inputs[i]))
self.child_distributions = child_list
@override(ActionDistribution)
def logp(self, x):
split_indices = []
for dist in self.child_distributions:
if isinstance(dist, Categorical):
split_indices.append(1)
else:
split_indices.append(tf.shape(dist.sample())[1])
split_list = tf.split(x, split_indices, axis=1)
for i, distribution in enumerate(self.child_distributions):
# Remove extra categorical dimension
if isinstance(distribution, Categorical):
split_list[i] = tf.cast(
tf.squeeze(split_list[i], axis=-1), tf.int32)
log_list = np.asarray([
distribution.logp(split_x) for distribution, split_x in zip(
self.child_distributions, split_list)
])
return np.sum(log_list)
@override(ActionDistribution)
def kl(self, other):
kl_list = np.asarray([
distribution.kl(other_distribution)
for distribution, other_distribution in zip(
self.child_distributions, other.child_distributions)
])
return np.sum(kl_list)
@override(ActionDistribution)
def entropy(self):
entropy_list = np.array(
[s.entropy() for s in self.child_distributions])
return np.sum(entropy_list)
@override(ActionDistribution)
def sample(self):
return TupleActions([s.sample() for s in self.child_distributions])
@override(ActionDistribution)
def sampled_action_prob(self):
p = self.child_distributions[0].sampled_action_prob()
for c in self.child_distributions[1:]:
p *= c.sampled_action_prob()
return p
TupleActions = namedtuple("TupleActions", ["batches"])
class Dirichlet(ActionDistribution):
"""Dirichlet distribution for continuous actions that are between
[0,1] and sum to 1.
e.g. actions that represent resource allocation."""
def __init__(self, inputs):
"""Input is a tensor of logits. The exponential of logits is used to
parametrize the Dirichlet distribution as all parameters need to be
positive. An arbitrarily small epsilon is added to the concentration
parameters to prevent them from being zero due to numerical error.
See issue #4440 for more details.
"""
self.epsilon = 1e-7
concentration = tf.exp(inputs) + self.epsilon
self.dist = tf.distributions.Dirichlet(
concentration=concentration,
validate_args=True,
allow_nan_stats=False,
)
ActionDistribution.__init__(self, concentration)
@override(ActionDistribution)
def logp(self, x):
# The support of the Dirichlet distribution is positive real numbers.
# x should already be an array of positive numbers, but we clip it to
# avoid zeros due to numerical errors.
x = tf.maximum(x, self.epsilon)
x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
return self.dist.log_prob(x)
@override(ActionDistribution)
def entropy(self):
return self.dist.entropy()
@override(ActionDistribution)
def kl(self, other):
return self.dist.kl_divergence(other.dist)
@override(ActionDistribution)
def _build_sample_op(self):
return self.dist.sample()

View file

@ -11,18 +11,18 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
_global_registry
from ray.rllib.models.extra_spaces import Simplex
from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
Deterministic, DiagGaussian,
MultiActionDistribution, Dirichlet)
from ray.rllib.models.torch_action_dist import (TorchCategorical,
TorchDiagGaussian)
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
TorchDiagGaussian)
from ray.rllib.models.tf.tf_action_dist import (
Categorical, MultiCategorical, Deterministic, DiagGaussian,
MultiActionDistribution, Dirichlet)
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
from ray.rllib.models.lstm import LSTM
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.error import UnsupportedSpaceException
@ -204,12 +204,13 @@ class ModelCatalog(object):
" not supported".format(action_space))
@staticmethod
@DeveloperAPI
def get_model_v2(obs_space,
action_space,
num_outputs,
model_config,
framework,
name=None,
name="default_model",
model_interface=None,
default_model=None,
**model_kwargs):
@ -289,126 +290,6 @@ class ModelCatalog(object):
raise NotImplementedError(
"Framework must be 'tf' or 'torch': {}".format(framework))
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
assert issubclass(model_cls, TFModelV2)
if not model_interface or issubclass(model_cls, model_interface):
return model_cls
class wrapper(model_interface, model_cls):
pass
name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
wrapper.__name__ = name
wrapper.__qualname__ = name
return wrapper
@staticmethod
@DeveloperAPI
def get_model(input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=None,
seq_lens=None):
"""Returns a suitable model conforming to given input and output specs.
Args:
input_dict (dict): Dict of input tensors to the model, including
the observation under the "obs" key.
obs_space (Space): Observation space of the target gym env.
action_space (Space): Action space of the target gym env.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
state_in (list): Optional RNN state in tensors.
seq_lens (Tensor): Optional RNN sequence length tensor.
Returns:
model (models.Model): Neural network model.
"""
assert isinstance(input_dict, dict)
options = options or MODEL_DEFAULTS
model = ModelCatalog._get_model(input_dict, obs_space, action_space,
num_outputs, options, state_in,
seq_lens)
if options.get("use_lstm"):
copy = dict(input_dict)
copy["obs"] = model.last_layer
feature_space = gym.spaces.Box(
-1, 1, shape=(model.last_layer.shape[1], ))
model = LSTM(copy, feature_space, action_space, num_outputs,
options, state_in, seq_lens)
logger.debug(
"Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
model, input_dict, obs_space, action_space, state_in, seq_lens,
model.outputs, model.state_out))
model._validate_output_shape()
return model
@staticmethod
def _get_model(input_dict, obs_space, action_space, num_outputs, options,
state_in, seq_lens):
if options.get("custom_model"):
model = options["custom_model"]
logger.debug("Using custom model {}".format(model))
return _global_registry.get(RLLIB_MODEL, model)(
input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=state_in,
seq_lens=seq_lens)
obs_rank = len(input_dict["obs"].shape) - 1
if obs_rank > 1:
return VisionNetwork(input_dict, obs_space, action_space,
num_outputs, options)
return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@staticmethod
@DeveloperAPI
def get_torch_model(obs_space,
num_outputs,
options=None,
default_model_cls=None):
raise DeprecationWarning("Please use get_model_v2() instead.")
def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
model_config, name):
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
PyTorchFCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
PyTorchVisionNet)
model_config = model_config or MODEL_DEFAULTS
if model_config.get("use_lstm"):
raise NotImplementedError(
"LSTM auto-wrapping not implemented for torch")
if isinstance(obs_space, gym.spaces.Discrete):
obs_rank = 1
else:
obs_rank = len(obs_space.shape)
if obs_rank > 1:
return PyTorchVisionNet(obs_space, action_space, num_outputs,
model_config, name)
return PyTorchFCNet(obs_space, action_space, num_outputs, model_config,
name)
@staticmethod
@DeveloperAPI
def get_preprocessor(env, options=None):
@ -480,3 +361,108 @@ class ModelCatalog(object):
model_class (type): Python class of the model.
"""
_global_registry.register(RLLIB_MODEL, model_name, model_class)
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
assert issubclass(model_cls, TFModelV2)
if not model_interface or issubclass(model_cls, model_interface):
return model_cls
class wrapper(model_interface, model_cls):
pass
name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
wrapper.__name__ = name
wrapper.__qualname__ = name
return wrapper
@staticmethod
def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
model_config, name):
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
PyTorchFCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
PyTorchVisionNet)
model_config = model_config or MODEL_DEFAULTS
if model_config.get("use_lstm"):
raise NotImplementedError(
"LSTM auto-wrapping not implemented for torch")
if isinstance(obs_space, gym.spaces.Discrete):
obs_rank = 1
else:
obs_rank = len(obs_space.shape)
if obs_rank > 1:
return PyTorchVisionNet(obs_space, action_space, num_outputs,
model_config, name)
return PyTorchFCNet(obs_space, action_space, num_outputs, model_config,
name)
@staticmethod
def get_model(input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=None,
seq_lens=None):
"""Deprecated: use get_model_v2() instead."""
assert isinstance(input_dict, dict)
options = options or MODEL_DEFAULTS
model = ModelCatalog._get_model(input_dict, obs_space, action_space,
num_outputs, options, state_in,
seq_lens)
if options.get("use_lstm"):
copy = dict(input_dict)
copy["obs"] = model.last_layer
feature_space = gym.spaces.Box(
-1, 1, shape=(model.last_layer.shape[1], ))
model = LSTM(copy, feature_space, action_space, num_outputs,
options, state_in, seq_lens)
logger.debug(
"Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
model, input_dict, obs_space, action_space, state_in, seq_lens,
model.outputs, model.state_out))
model._validate_output_shape()
return model
@staticmethod
def _get_model(input_dict, obs_space, action_space, num_outputs, options,
state_in, seq_lens):
if options.get("custom_model"):
model = options["custom_model"]
logger.debug("Using custom model {}".format(model))
return _global_registry.get(RLLIB_MODEL, model)(
input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=state_in,
seq_lens=seq_lens)
obs_rank = len(input_dict["obs"].shape) - 1
if obs_rank > 1:
return VisionNetwork(input_dict, obs_space, action_space,
num_outputs, options)
return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@staticmethod
def get_torch_model(obs_space,
num_outputs,
options=None,
default_model_cls=None):
raise DeprecationWarning("Please use get_model_v2() instead.")

View file

@ -3,51 +3,20 @@ from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import logging
import gym
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
# Deprecated: use TFModelV2 instead
class Model(object):
"""Defines an abstract network model for use with RLlib.
This class is deprecated: please use TFModelV2 instead.
Models convert input tensors to a number of output features. These features
can then be interpreted by ActionDistribution classes to determine
e.g. agent action values.
The last layer of the network can also be retrieved if the algorithm
needs to do further post-processing (e.g. Actor and Critic networks in A3C).
Attributes:
input_dict (dict): Dictionary of input tensors, including "obs",
"prev_action", "prev_reward", "is_training".
outputs (Tensor): The output vector of this model, of shape
[BATCH_SIZE, num_outputs].
last_layer (Tensor): The feature layer right before the model output,
of shape [BATCH_SIZE, f].
state_init (list): List of initial recurrent state tensors (if any).
state_in (list): List of input recurrent state tensors (if any).
state_out (list): List of output recurrent state tensors (if any).
seq_lens (Tensor): The tensor input for RNN sequence lengths. This
defaults to a Tensor of [1] * len(batch) in the non-RNN case.
If `options["free_log_std"]` is True, the last half of the
output layer will be free variables that are not dependent on
inputs. This is often used if the output of the network is used
to parametrize a probability distribution. In this case, the
first half of the parameters can be interpreted as a location
parameter (like a mean) and the second half can be interpreted as
a scale parameter (like a standard deviation).
"""
"""This class is deprecated, please use TFModelV2 instead."""
def __init__(self,
input_dict,

View file

@ -145,7 +145,13 @@ class ModelV2(object):
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
restored["obs_flat"] = input_dict["obs"]
outputs, state = self.forward(restored, state or [], seq_lens)
res = self.forward(restored, state or [], seq_lens)
if ((not isinstance(res, list) and not isinstance(res, tuple))
or len(res) != 2):
raise ValueError(
"forward() must return a tuple of (output, state) tensors, "
"got {}".format(res))
outputs, state = res
try:
shape = outputs.shape

View file

@ -3,14 +3,14 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.models.model import Model
from ray.rllib.models.misc import normc_initializer, get_activation_fn
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# TODO(ekl) rewrite this using ModelV2
# Deprecated: see models/tf/fcnet_v2.py as an alternative
class FullyConnectedNetwork(Model):
"""Generic fully connected network."""

View file

@ -0,0 +1,87 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class FullyConnectedNetwork(TFModelV2):
"""Generic fully connected network implemented in ModelV2 API.
TODO(ekl): should make this the default fcnet in the future."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
activation = get_activation_fn(model_config.get("fcnet_activation"))
hiddens = model_config.get("fcnet_hiddens")
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
inputs = tf.keras.layers.Input(
shape=obs_space.shape, name="observations")
last_layer = inputs
i = 1
if no_final_linear:
# the last layer is adjusted to be of size num_outputs
for size in hiddens[:-1]:
last_layer = tf.keras.layers.Dense(
size,
name="fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
layer_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
else:
# the last layer is a linear to size num_outputs
for size in hiddens:
last_layer = tf.keras.layers.Dense(
size,
name="fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
layer_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
if not vf_share_layers:
# build a parallel set of hidden layers for the value net
last_layer = inputs
i = 1
for size in hiddens:
last_layer = tf.keras.layers.Dense(
size,
name="value_fc_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
i += 1
value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
self.base_model = tf.keras.Model(inputs, [layer_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
model_out, self._value_out = self.base_model(input_dict["obs"])
return model_out, state
def value_function(self):
return tf.reshape(self._value_out, [-1])
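A rough usage sketch of the call convention for this model (TF1 graph mode is assumed; the models/tf/fcnet_v2.py import path is taken from the deprecation comment above, and the other names and values are illustrative only):

import gym
import numpy as np
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4, ), dtype=np.float32)
action_space = gym.spaces.Discrete(2)
model = FullyConnectedNetwork(obs_space, action_space, 2, MODEL_DEFAULTS,
                              "fcnet_v2_demo")

obs_ph = tf.placeholder(tf.float32, [None, 4], name="obs")
# __call__ unpacks the input dict and delegates to forward().
logits, state = model({"obs": obs_ph}, [], None)
value = model.value_function()  # shape [BATCH], from the last forward pass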

View file

@ -0,0 +1,79 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.models.model import Model
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# Deprecated: see models/tf/recurrent_tf_modelv2.py as an alternative
class LSTM(Model):
"""Adds a LSTM cell on top of some other model output.
Uses a linear layer at the end for output.
Important: we assume inputs is a padded batch of sequences denoted by
self.seq_lens. See add_time_dimension() for more information.
"""
@override(Model)
def _build_layers_v2(self, input_dict, num_outputs, options):
cell_size = options.get("lstm_cell_size")
if options.get("lstm_use_prev_action_reward"):
action_dim = int(
np.product(
input_dict["prev_actions"].get_shape().as_list()[1:]))
features = tf.concat(
[
input_dict["obs"],
tf.reshape(
tf.cast(input_dict["prev_actions"], tf.float32),
[-1, action_dim]),
tf.reshape(input_dict["prev_rewards"], [-1, 1]),
],
axis=1)
else:
features = input_dict["obs"]
last_layer = add_time_dimension(features, self.seq_lens)
# Setup the LSTM cell
lstm = tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
self.state_init = [
np.zeros(lstm.state_size.c, np.float32),
np.zeros(lstm.state_size.h, np.float32)
]
# Setup LSTM inputs
if self.state_in:
c_in, h_in = self.state_in
else:
c_in = tf.placeholder(
tf.float32, [None, lstm.state_size.c], name="c")
h_in = tf.placeholder(
tf.float32, [None, lstm.state_size.h], name="h")
self.state_in = [c_in, h_in]
# Setup LSTM outputs
state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
lstm_out, lstm_state = tf.nn.dynamic_rnn(
lstm,
last_layer,
initial_state=state_in,
sequence_length=self.seq_lens,
time_major=False,
dtype=tf.float32)
self.state_out = list(lstm_state)
# Compute outputs
last_layer = tf.reshape(lstm_out, [-1, cell_size])
logits = linear(last_layer, num_outputs, "action",
normc_initializer(0.01))
return logits, last_layer
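A rough NumPy illustration of the padding convention assumed above, i.e. what add_time_dimension() does conceptually (a sketch only, not the RLlib implementation; all names here are illustrative):

import numpy as np

def add_time_dimension_np(padded_inputs, seq_lens):
    """Reshape a flat, padded batch of shape [B * T, D] into [B, T, D]."""
    num_seqs = len(seq_lens)
    max_seq_len = padded_inputs.shape[0] // num_seqs
    return padded_inputs.reshape(
        (num_seqs, max_seq_len) + padded_inputs.shape[1:])

# Two episodes of lengths 3 and 2, padded to T=3 and flattened to [6, 4].
flat = np.arange(6 * 4, dtype=np.float32).reshape(6, 4)
print(add_time_dimension_np(flat, seq_lens=[3, 2]).shape)  # (2, 3, 4)
# The LSTM then masks time steps beyond each sequence length via seq_lens.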

View file

@ -7,7 +7,7 @@ import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.tf_ops import scope_vars

View file

@ -2,15 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.models.lstm import add_time_dimension
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@DeveloperAPI
class RecurrentTFModelV2(TFModelV2):
"""Helper class to simplify implementing RNN models with TFModelV2.
@ -19,6 +20,38 @@ class RecurrentTFModelV2(TFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
"""Initialize a TFModelV2.
Here is an example implementation for a subclass
``MyRNNClass(RecurrentTFModelV2)``::
def __init__(self, *args, **kwargs):
super(MyRNNClass, self).__init__(*args, **kwargs)
self.cell_size = 256
# Define input layers
input_layer = tf.keras.layers.Input(
shape=(None, obs_space.shape[0]))
state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ))
state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ))
seq_in = tf.keras.layers.Input(shape=())
# Send to LSTM cell
lstm_out, state_h, state_c = tf.keras.layers.LSTM(
self.cell_size, return_sequences=True, return_state=True,
name="lstm")(
inputs=input_layer,
mask=tf.sequence_mask(seq_in),
initial_state=[state_in_h, state_in_c])
output_layer = tf.keras.layers.Dense(...)(lstm_out)
# Create the RNN model
self.rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[output_layer, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
"""
TFModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
@ -44,8 +77,27 @@ class RecurrentTFModelV2(TFModelV2):
(outputs, new_state): The model output tensor of shape
[B, T, num_outputs] and the list of new state tensors each with
shape [B, size].
Sample implementation for the ``MyRNNClass`` example::
def forward_rnn(self, inputs, state, seq_lens):
model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
return model_out, [h, c]
"""
raise NotImplementedError("You must implement this for a RNN model")
def get_initial_state(self):
"""Get the initial recurrent state values for the model.
Returns:
list of np.array objects, if any
Sample implementation for the ``MyRNNClass`` example::
def get_initial_state(self):
return [
np.zeros(self.cell_size, np.float32),
np.zeros(self.cell_size, np.float32),
]
"""
raise NotImplementedError("You must implement this for a RNN model")

View file

@ -0,0 +1,280 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.policy import TupleActions
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@DeveloperAPI
class TFActionDistribution(ActionDistribution):
"""TF-specific extensions for building action distributions."""
@DeveloperAPI
def __init__(self, inputs):
super(TFActionDistribution, self).__init__(inputs)
self.sample_op = self._build_sample_op()
@DeveloperAPI
def _build_sample_op(self):
"""Implement this instead of sample(), to enable op reuse.
This is needed since the sample op is non-deterministic and is shared
between sample() and sampled_action_prob().
"""
raise NotImplementedError
@DeveloperAPI
def sample(self):
"""Draw a sample from the action distribution."""
return self.sample_op
@DeveloperAPI
def sampled_action_prob(self):
"""Returns the log probability of the sampled action."""
return tf.exp(self.logp(self.sample_op))
class Categorical(TFActionDistribution):
"""Categorical distribution for discrete action spaces."""
@override(ActionDistribution)
def logp(self, x):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=self.inputs, labels=tf.cast(x, tf.int32))
@override(ActionDistribution)
def entropy(self):
a0 = self.inputs - tf.reduce_max(
self.inputs, reduction_indices=[1], keep_dims=True)
ea0 = tf.exp(a0)
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])
@override(ActionDistribution)
def kl(self, other):
a0 = self.inputs - tf.reduce_max(
self.inputs, reduction_indices=[1], keep_dims=True)
a1 = other.inputs - tf.reduce_max(
other.inputs, reduction_indices=[1], keep_dims=True)
ea0 = tf.exp(a0)
ea1 = tf.exp(a1)
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(
p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1])
@override(TFActionDistribution)
def _build_sample_op(self):
return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
class MultiCategorical(TFActionDistribution):
"""Categorical distribution for discrete action spaces."""
def __init__(self, inputs, input_lens):
self.cats = [
Categorical(input_)
for input_ in tf.split(inputs, input_lens, axis=1)
]
self.sample_op = self._build_sample_op()
@override(ActionDistribution)
def logp(self, actions):
# If tensor is provided, unstack it into list
if isinstance(actions, tf.Tensor):
actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
logps = tf.stack(
[cat.logp(act) for cat, act in zip(self.cats, actions)])
return tf.reduce_sum(logps, axis=0)
@override(ActionDistribution)
def multi_entropy(self):
return tf.stack([cat.entropy() for cat in self.cats], axis=1)
@override(ActionDistribution)
def entropy(self):
return tf.reduce_sum(self.multi_entropy(), axis=1)
@override(ActionDistribution)
def multi_kl(self, other):
return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)]
@override(ActionDistribution)
def kl(self, other):
return tf.reduce_sum(self.multi_kl(other), axis=1)
@override(TFActionDistribution)
def _build_sample_op(self):
return tf.stack([cat.sample() for cat in self.cats], axis=1)
class DiagGaussian(TFActionDistribution):
"""Action distribution where each vector element is a gaussian.
The first half of the input vector defines the gaussian means, and the
second half the gaussian standard deviations.
"""
def __init__(self, inputs):
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.log_std = log_std
self.std = tf.exp(log_std)
TFActionDistribution.__init__(self, inputs)
@override(ActionDistribution)
def logp(self, x):
return (-0.5 * tf.reduce_sum(
tf.square((x - self.mean) / self.std), reduction_indices=[1]) -
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
tf.reduce_sum(self.log_std, reduction_indices=[1]))
@override(ActionDistribution)
def kl(self, other):
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(
other.log_std - self.log_std +
(tf.square(self.std) + tf.square(self.mean - other.mean)) /
(2.0 * tf.square(other.std)) - 0.5,
reduction_indices=[1])
@override(ActionDistribution)
def entropy(self):
return tf.reduce_sum(
self.log_std + .5 * np.log(2.0 * np.pi * np.e),
reduction_indices=[1])
@override(TFActionDistribution)
def _build_sample_op(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
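For reference, the standard diagonal-Gaussian identities implemented above, with mean \mu_i, standard deviation \sigma_i = \exp(\text{log\_std}_i), and input dimension d (stated here only to make the code easier to check; not part of this change):

\log p(x) = -\tfrac{1}{2}\sum_i \Big(\tfrac{x_i-\mu_i}{\sigma_i}\Big)^2 - \tfrac{d}{2}\log(2\pi) - \sum_i \log\sigma_i

H = \sum_i \Big(\log\sigma_i + \tfrac{1}{2}\log(2\pi e)\Big)

\mathrm{KL}(p\,\|\,q) = \sum_i \Big(\log\tfrac{\sigma_{q,i}}{\sigma_{p,i}} + \tfrac{\sigma_{p,i}^2 + (\mu_{p,i}-\mu_{q,i})^2}{2\,\sigma_{q,i}^2} - \tfrac{1}{2}\Big)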
class Deterministic(TFActionDistribution):
"""Action distribution that returns the input values directly.
This is similar to DiagGaussian with standard deviation zero.
"""
@override(TFActionDistribution)
def sampled_action_prob(self):
return 1.0
@override(TFActionDistribution)
def _build_sample_op(self):
return self.inputs
class MultiActionDistribution(TFActionDistribution):
"""Action distribution that operates for list of actions.
Args:
inputs (Tensor): A single flat tensor that is split into per-child
inputs according to input_lens.
"""
def __init__(self, inputs, action_space, child_distributions, input_lens):
self.input_lens = input_lens
split_inputs = tf.split(inputs, self.input_lens, axis=1)
child_list = []
for i, distribution in enumerate(child_distributions):
child_list.append(distribution(split_inputs[i]))
self.child_distributions = child_list
@override(ActionDistribution)
def logp(self, x):
split_indices = []
for dist in self.child_distributions:
if isinstance(dist, Categorical):
split_indices.append(1)
else:
split_indices.append(tf.shape(dist.sample())[1])
split_list = tf.split(x, split_indices, axis=1)
for i, distribution in enumerate(self.child_distributions):
# Remove extra categorical dimension
if isinstance(distribution, Categorical):
split_list[i] = tf.cast(
tf.squeeze(split_list[i], axis=-1), tf.int32)
log_list = np.asarray([
distribution.logp(split_x) for distribution, split_x in zip(
self.child_distributions, split_list)
])
return np.sum(log_list)
@override(ActionDistribution)
def kl(self, other):
kl_list = np.asarray([
distribution.kl(other_distribution)
for distribution, other_distribution in zip(
self.child_distributions, other.child_distributions)
])
return np.sum(kl_list)
@override(ActionDistribution)
def entropy(self):
entropy_list = np.array(
[s.entropy() for s in self.child_distributions])
return np.sum(entropy_list)
@override(ActionDistribution)
def sample(self):
return TupleActions([s.sample() for s in self.child_distributions])
@override(TFActionDistribution)
def sampled_action_prob(self):
p = self.child_distributions[0].sampled_action_prob()
for c in self.child_distributions[1:]:
p *= c.sampled_action_prob()
return p
class Dirichlet(TFActionDistribution):
"""Dirichlet distribution for continuous actions that are between
[0,1] and sum to 1.
e.g. actions that represent resource allocation."""
def __init__(self, inputs):
"""Input is a tensor of logits. The exponential of logits is used to
parametrize the Dirichlet distribution as all parameters need to be
positive. An arbitrary small epsilon is added to the concentration
parameters to be zero due to numerical error.
See issue #4440 for more details.
"""
self.epsilon = 1e-7
concentration = tf.exp(inputs) + self.epsilon
self.dist = tf.distributions.Dirichlet(
concentration=concentration,
validate_args=True,
allow_nan_stats=False,
)
TFActionDistribution.__init__(self, concentration)
@override(ActionDistribution)
def logp(self, x):
# The support of the Dirichlet distribution is the positive reals. x
# should already be an array of positive numbers, but we clip it to
# avoid zeros caused by numerical error.
x = tf.maximum(x, self.epsilon)
x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
return self.dist.log_prob(x)
@override(ActionDistribution)
def entropy(self):
return self.dist.entropy()
@override(ActionDistribution)
def kl(self, other):
return self.dist.kl_divergence(other.dist)
@override(TFActionDistribution)
def _build_sample_op(self):
return self.dist.sample()
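A small usage sketch for these distributions (TF1 graph mode is assumed; the models/tf/tf_action_dist.py import path is an assumption, since file paths are not shown in this excerpt, and the 5-way logits are illustrative):

import numpy as np
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

logits = tf.placeholder(tf.float32, [None, 5])  # e.g. model output, 5 actions
dist = Categorical(logits)
action = dist.sample()    # int tensor of shape [BATCH]
logp = dist.logp(action)  # log-likelihood of the sampled actions

with tf.Session() as sess:
    a, lp = sess.run(
        [action, logp],
        feed_dict={logits: np.random.randn(2, 5).astype(np.float32)})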

View file

@ -18,6 +18,22 @@ class TFModelV2(ModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
"""Initialize a TFModelV2.
Here is an example implementation for a subclass
``MyModelClass(TFModelV2)``::
def __init__(self, *args, **kwargs):
super(MyModelClass, self).__init__(*args, **kwargs)
input_layer = tf.keras.layers.Input(...)
hidden_layer = tf.keras.layers.Dense(...)(input_layer)
output_layer = tf.keras.layers.Dense(...)(hidden_layer)
value_layer = tf.keras.layers.Dense(...)(hidden_layer)
self.base_model = tf.keras.Model(
input_layer, [output_layer, value_layer])
self.register_variables(self.base_model.variables)
"""
ModelV2.__init__(
self,
obs_space,
@ -28,6 +44,52 @@ class TFModelV2(ModelV2):
framework="tf")
self.var_list = []
def forward(self, input_dict, state, seq_lens):
"""Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by
__call__ before being passed to forward(). To access the flattened
observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution,
each call to forward() will eagerly evaluate the model. In symbolic
execution, each call to forward creates a computation graph that
operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Arguments:
input_dict (dict): dictionary of input tensors, including "obs",
"obs_flat", "prev_action", "prev_reward", "is_training"
state (list): list of state tensors with sizes matching those
returned by get_initial_state + the batch dimension
seq_lens (Tensor): 1d tensor holding input sequence lengths
Returns:
(outputs, state): The model output tensor of size
[BATCH, num_outputs], and the list of new state tensors (if any)
Sample implementation for the ``MyModelClass`` example::
def forward(self, input_dict, state, seq_lens):
model_out, self._value_out = self.base_model(input_dict["obs"])
return model_out, state
"""
raise NotImplementedError
def value_function(self):
"""Return the value function estimate for the most recent forward pass.
Returns:
value estimate tensor of shape [BATCH].
Sample implementation for the ``MyModelClass`` example::
def value_function(self):
return self._value_out
"""
raise NotImplementedError
def update_ops(self):
"""Return the list of update ops for this model.

View file

@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.models.model import Model
from ray.rllib.models.misc import get_activation_fn, flatten
from ray.rllib.models.tf.misc import get_activation_fn, flatten
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf

View file

@ -14,13 +14,14 @@ from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
class FullyConnectedNetwork(TorchModelV2):
class FullyConnectedNetwork(TorchModelV2, nn.Module):
"""Generic fully connected network."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
hiddens = model_config.get("fcnet_hiddens")
activation = _get_activation_fn(model_config.get("fcnet_activation"))

View file

@ -9,14 +9,32 @@ from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class TorchModelV2(ModelV2, nn.Module):
class TorchModelV2(ModelV2):
"""Torch version of ModelV2.
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
inherit from nn.Module and implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
"""Initialize a TorchModelV2.
Here is an example implementation for a subclass
``MyModelClass(TorchModelV2, nn.Module)``::
def __init__(self, *args, **kwargs):
TorchModelV2.__init__(self, *args, **kwargs)
nn.Module.__init__(self)
self._hidden_layers = nn.Sequential(...)
self._logits = ...
self._value_branch = ...
"""
if not isinstance(self, nn.Module):
raise ValueError(
"Subclasses of TorchModelV2 must also inherit from "
"nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")
ModelV2.__init__(
self,
obs_space,
@ -25,4 +43,50 @@ class TorchModelV2(ModelV2, nn.Module):
model_config,
name,
framework="torch")
nn.Module.__init__(self)
def forward(self, input_dict, state, seq_lens):
"""Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by
__call__ before being passed to forward(). To access the flattened
observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution,
each call to forward() will eagerly evaluate the model. In symbolic
execution, each call to forward creates a computation graph that
operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Arguments:
input_dict (dict): dictionary of input tensors, including "obs",
"obs_flat", "prev_action", "prev_reward", "is_training"
state (list): list of state tensors with sizes matching those
returned by get_initial_state + the batch dimension
seq_lens (Tensor): 1d tensor holding input sequence lengths
Returns:
(outputs, state): The model output tensor of size
[BATCH, num_outputs], and the list of new state tensors (if any)
Sample implementation for the ``MyModelClass`` example::
def forward(self, input_dict, state, seq_lens):
features = self._hidden_layers(input_dict["obs"])
self._value_out = self._value_branch(features)
return self._logits(features), state
"""
raise NotImplementedError
def value_function(self):
"""Return the value function estimate for the most recent forward pass.
Returns:
value estimate tensor of shape [BATCH].
Sample implementation for the ``MyModelClass`` example::
def value_function(self):
return self._value_out
"""
raise NotImplementedError
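A minimal, self-contained sketch of the pattern described in this docstring (the class name and layer sizes are illustrative, not part of this diff):

import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class MyTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self._hidden_layers = nn.Sequential(
            nn.Linear(obs_space.shape[0], 64), nn.Tanh())
        self._logits = nn.Linear(64, num_outputs)
        self._value_branch = nn.Linear(64, 1)
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        features = self._hidden_layers(input_dict["obs"].float())
        self._value_out = self._value_branch(features).squeeze(1)
        return self._logits(features), state

    def value_function(self):
        return self._value_out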

View file

@ -7,17 +7,18 @@ import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, valid_padding, \
SlimConv2d, SlimFC
from ray.rllib.models.visionnet import _get_filter_config
from ray.rllib.models.tf.visionnet_v1 import _get_filter_config
from ray.rllib.utils.annotations import override
class VisionNetwork(TorchModelV2):
class VisionNetwork(TorchModelV2, nn.Module):
"""Generic vision network."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(VisionNetwork, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
filters = model_config.get("conv_filters")
if not filters:

View file

@ -58,6 +58,20 @@ class AggregationWorkerBase(object):
def __init__(self, initial_weights_obj_id, remote_workers,
max_sample_requests_in_flight_per_worker, replay_proportion,
replay_buffer_num_slots, train_batch_size, sample_batch_size):
"""Initialize an aggregator.
Arguments:
initial_weights_obj_id (ObjectID): initial worker weights
remote_workers (list): set of remote workers assigned to this aggregator
max_sample_requests_in_flight_per_worker (int): max queue size per
worker
replay_proportion (float): ratio of replay to sampled outputs
replay_buffer_num_slots (int): max number of sample batches to
store in the replay buffer
train_batch_size (int): size of batches to learn on
sample_batch_size (int): size of batches to sample from workers
"""
self.broadcasted_weights = initial_weights_obj_id
self.remote_workers = remote_workers
self.sample_batch_size = sample_batch_size

View file

@ -27,6 +27,19 @@ class LearnerThread(threading.Thread):
def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
learner_queue_size, learner_queue_timeout):
"""Initialize the learner thread.
Arguments:
local_worker (RolloutWorker): process local rollout worker holding
policies this thread will call learn_on_batch() on
minibatch_buffer_size (int): max number of train batches to store
in the minibatching buffer
num_sgd_iter (int): number of passes to learn on per train batch
learner_queue_size (int): max size of queue of inbound
train batches to this thread
learner_queue_timeout (int): raise an exception if the queue has
been empty for this long in seconds
"""
threading.Thread.__init__(self)
self.learner_queue_size = WindowStat("size", 50)
self.local_worker = local_worker

View file

@ -19,7 +19,7 @@ class MinibatchBuffer(object):
size: Max number of data items to buffer.
timeout: Queue timeout
num_passes: Max num times each data item should be emitted.
"""
"""
self.inqueue = inqueue
self.size = size
self.timeout = timeout

View file

@ -42,6 +42,25 @@ class TFMultiGPULearner(LearnerThread):
learner_queue_timeout=300,
num_data_load_threads=16,
_fake_gpus=False):
"""Initialize a multi-gpu learner thread.
Arguments:
local_worker (RolloutWorker): process local rollout worker holding
policies this thread will call learn_on_batch() on
num_gpus (int): number of GPUs to use for data-parallel SGD
lr (float): learning rate
train_batch_size (int): size of batches to learn on
num_data_loader_buffers (int): number of buffers to load data into
in parallel. Each buffer is of size of train_batch_size and
increases GPU memory usage proportionally.
minibatch_buffer_size (int): max number of train batches to store
in the minibatching buffer
num_sgd_iter (int): number of passes to learn on per train batch
learner_queue_size (int): max size of queue of inbound
train batches to this thread
num_data_load_threads (int): number of threads to use to load
data into GPU memory in parallel
"""
LearnerThread.__init__(self, local_worker, minibatch_buffer_size,
num_sgd_iter, learner_queue_size,
learner_queue_timeout)

View file

@ -37,6 +37,22 @@ class TreeAggregator(Aggregator):
train_batch_size=500,
sample_batch_size=50,
broadcast_interval=5):
"""Initialize a tree aggregator.
Arguments:
workers (WorkerSet): set of all workers
num_aggregation_workers (int): number of intermediate actors to
use for data aggregation
max_sample_requests_in_flight_per_worker (int): max queue size per
worker
replay_proportion (float): ratio of replay to sampled outputs
replay_buffer_num_slots (int): max number of sample batches to
store in the replay buffer
train_batch_size (int): size of batches to learn on
sample_batch_size (int): size of batches to sample from workers
broadcast_interval (int): max number of workers to send the
same set of weights to
"""
self.workers = workers
self.num_aggregation_workers = num_aggregation_workers
self.max_sample_requests_in_flight_per_worker = \

View file

@ -19,6 +19,13 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
"""
def __init__(self, workers, grads_per_step=100):
"""Initialize an async gradients optimizer.
Arguments:
grads_per_step (int): The number of gradients to collect and apply
for each call to step(). This number should be sufficiently
high to amortize the overhead of calling step().
"""
PolicyOptimizer.__init__(self, workers)
self.apply_timer = TimerStat()

View file

@ -61,6 +61,27 @@ class AsyncReplayOptimizer(PolicyOptimizer):
max_weight_sync_delay=400,
debug=False,
batch_replay=False):
"""Initialize an async replay optimizer.
Arguments:
workers (WorkerSet): all workers
learning_starts (int): wait until this many steps have been sampled
before starting optimization.
buffer_size (int): max size of the replay buffer
prioritized_replay (bool): whether to enable prioritized replay
prioritized_replay_alpha (float): replay alpha hyperparameter
prioritized_replay_beta (float): replay beta hyperparameter
prioritized_replay_eps (float): replay eps hyperparameter
train_batch_size (int): size of batches to learn on
sample_batch_size (int): size of batches to sample from workers
num_replay_buffer_shards (int): number of actors to use to store
replay samples
max_weight_sync_delay (int): update the weights of a rollout worker
after collecting this number of timesteps from it
debug (bool): return extra debug stats
batch_replay (bool): replay entire sequential batches of
experiences instead of sampling steps individually
"""
PolicyOptimizer.__init__(self, workers)
self.debug = debug

View file

@ -12,8 +12,7 @@ from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers.rollout import collect_samples, \
collect_samples_straggler_mitigation
from ray.rllib.optimizers.rollout import collect_samples
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
@ -50,8 +49,22 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
train_batch_size=1024,
num_gpus=0,
standardize_fields=[],
straggler_mitigation=False,
shuffle_sequences=True):
"""Initialize a synchronous multi-gpu optimizer.
Arguments:
workers (WorkerSet): all workers
sgd_batch_size (int): SGD minibatch size within train batch size
num_sgd_iter (int): number of passes to learn on per train batch
sample_batch_size (int): size of batches to sample from workers
num_envs_per_worker (int): num envs in each rollout worker
train_batch_size (int): size of batches to learn on
num_gpus (int): number of GPUs to use for data-parallel SGD
standardize_fields (list): list of fields in the training batch
to normalize
shuffle_sequences (bool): whether to shuffle the train batch prior
to SGD to break up correlations
"""
PolicyOptimizer.__init__(self, workers)
self.batch_size = sgd_batch_size
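These arguments are usually driven by the trainer config; a hypothetical PPO config sketch (the mapping from config keys to the arguments above is assumed, not taken from this diff, and the values are illustrative):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "train_batch_size": 4000,
        "sgd_minibatch_size": 128,  # assumed to feed sgd_batch_size above
        "num_sgd_iter": 30,
        "shuffle_sequences": True,
        "num_gpus": 0,
    })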
@ -59,7 +72,6 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.num_envs_per_worker = num_envs_per_worker
self.sample_batch_size = sample_batch_size
self.train_batch_size = train_batch_size
self.straggler_mitigation = straggler_mitigation
self.shuffle_sequences = shuffle_sequences
if not num_gpus:
self.devices = ["/cpu:0"]
@ -123,13 +135,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
with self.sample_timer:
if self.workers.remote_workers():
if self.straggler_mitigation:
samples = collect_samples_straggler_mitigation(
self.workers.remote_workers(), self.train_batch_size)
else:
samples = collect_samples(
self.workers.remote_workers(), self.sample_batch_size,
self.num_envs_per_worker, self.train_batch_size)
samples = collect_samples(
self.workers.remote_workers(), self.sample_batch_size,
self.num_envs_per_worker, self.train_batch_size)
if samples.count > self.train_batch_size * 2:
logger.info(
"Collected more training samples than expected "

View file

@ -134,12 +134,3 @@ class PolicyOptimizer(object):
The index will be passed as the second arg to the given function.
"""
return self.workers.foreach_worker_with_index(func)
def foreach_evaluator(self, func):
raise DeprecationWarning(
"foreach_evaluator has been renamed to foreach_worker")
def foreach_evaluator_with_index(self, func):
raise DeprecationWarning(
"foreach_evaluator_with_index has been renamed to "
"foreach_worker_with_index")

View file

@ -38,35 +38,3 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
agent_dict[fut_sample2] = agent
return SampleBatch.concat_samples(trajectories)
def collect_samples_straggler_mitigation(agents, train_batch_size):
"""Collects at least train_batch_size samples.
This is the legacy behavior as of 0.6, and launches extra sample tasks to
potentially improve performance but can result in many wasted samples.
"""
num_timesteps_so_far = 0
trajectories = []
agent_dict = {}
for agent in agents:
fut_sample = agent.sample.remote()
agent_dict[fut_sample] = agent
while num_timesteps_so_far < train_batch_size:
# TODO(pcm): Make wait support arbitrary iterators and remove the
# conversion to list here.
[fut_sample], _ = ray.wait(list(agent_dict))
agent = agent_dict.pop(fut_sample)
# Start task with next trajectory and record it in the dictionary.
fut_sample2 = agent.sample.remote()
agent_dict[fut_sample2] = agent
next_sample = ray_get_and_free(fut_sample)
num_timesteps_so_far += next_sample.count
trajectories.append(next_sample)
logger.info("Discarding {} sample tasks".format(len(agent_dict)))
return SampleBatch.concat_samples(trajectories)

View file

@ -24,6 +24,15 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
learning_starts=1000,
buffer_size=10000,
train_batch_size=32):
"""Initialize a batch replay optimizer.
Arguments:
workers (WorkerSet): set of all workers
learning_starts (int): start learning after this number of
timesteps have been collected
buffer_size (int): max timesteps to keep in the replay buffer
train_batch_size (int): number of timesteps to train on at once
"""
PolicyOptimizer.__init__(self, workers)
self.replay_starts = learning_starts

View file

@ -36,12 +36,30 @@ class SyncReplayOptimizer(PolicyOptimizer):
prioritized_replay=True,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
schedule_max_timesteps=100000,
beta_annealing_fraction=0.2,
final_prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,
train_batch_size=32,
sample_batch_size=4):
"""Initialize an sync replay optimizer.
Arguments:
workers (WorkerSet): all workers
learning_starts (int): wait until this many steps have been sampled
before starting optimization.
buffer_size (int): max size of the replay buffer
prioritized_replay (bool): whether to enable prioritized replay
prioritized_replay_alpha (float): replay alpha hyperparameter
prioritized_replay_beta (float): replay beta hyperparameter
prioritized_replay_eps (float): replay eps hyperparameter
schedule_max_timesteps (int): number of timesteps in the schedule
beta_annealing_fraction (float): fraction of schedule to anneal
beta over
final_prioritized_replay_beta (float): final value of beta
train_batch_size (int): size of batches to learn on
sample_batch_size (int): size of batches to sample from workers
"""
PolicyOptimizer.__init__(self, workers)
self.replay_starts = learning_starts
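A hypothetical DQN config sketch showing where these arguments usually come from (the config keys are assumed to mirror the argument names; values are illustrative only):

import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        "learning_starts": 1000,
        "buffer_size": 50000,
        "prioritized_replay": True,
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "train_batch_size": 32,
        "sample_batch_size": 4,
    })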

View file

@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import numpy as np
import gym
@ -11,6 +12,9 @@ from ray.rllib.utils.annotations import DeveloperAPI
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"
# Used to return tuple actions as a list of batches per tuple element
TupleActions = namedtuple("TupleActions", ["batches"])
@DeveloperAPI
class Policy(object):

View file

@ -1,7 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""LSTM support for RLlib.
"""RNN utils for RLlib.
The main trick here is that we add the time dimension at the last moment.
The non-LSTM layers of the model see their inputs as one flat batch. Before
@ -12,87 +12,17 @@ reshaping is possible.
Note that this padding strategy only works out if we assume zero inputs don't
meaningfully affect the loss function. This happens to be true for all the
current algorithms: https://github.com/ray-project/ray/issues/2992
See the add_time_dimension() and chop_into_sequences() functions below for
more info.
"""
import numpy as np
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.model import Model
from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class LSTM(Model):
"""Adds a LSTM cell on top of some other model output.
Uses a linear layer at the end for output.
Important: we assume inputs is a padded batch of sequences denoted by
self.seq_lens. See add_time_dimension() for more information.
"""
@override(Model)
def _build_layers_v2(self, input_dict, num_outputs, options):
cell_size = options.get("lstm_cell_size")
if options.get("lstm_use_prev_action_reward"):
action_dim = int(
np.product(
input_dict["prev_actions"].get_shape().as_list()[1:]))
features = tf.concat(
[
input_dict["obs"],
tf.reshape(
tf.cast(input_dict["prev_actions"], tf.float32),
[-1, action_dim]),
tf.reshape(input_dict["prev_rewards"], [-1, 1]),
],
axis=1)
else:
features = input_dict["obs"]
last_layer = add_time_dimension(features, self.seq_lens)
# Setup the LSTM cell
lstm = tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
self.state_init = [
np.zeros(lstm.state_size.c, np.float32),
np.zeros(lstm.state_size.h, np.float32)
]
# Setup LSTM inputs
if self.state_in:
c_in, h_in = self.state_in
else:
c_in = tf.placeholder(
tf.float32, [None, lstm.state_size.c], name="c")
h_in = tf.placeholder(
tf.float32, [None, lstm.state_size.h], name="h")
self.state_in = [c_in, h_in]
# Setup LSTM outputs
state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
lstm_out, lstm_state = tf.nn.dynamic_rnn(
lstm,
last_layer,
initial_state=state_in,
sequence_length=self.seq_lens,
time_major=False,
dtype=tf.float32)
self.state_out = list(lstm_state)
# Compute outputs
last_layer = tf.reshape(lstm_out, [-1, cell_size])
logits = linear(last_layer, num_outputs, "action",
normc_initializer(0.01))
return logits, last_layer
@PublicAPI
@DeveloperAPI
def add_time_dimension(padded_inputs, seq_lens):
"""Adds a time dimension to padded inputs.

View file

@ -9,8 +9,8 @@ import os
import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, DeveloperAPI
@ -94,7 +94,7 @@ class TFPolicy(Policy):
prev_reward_input (Tensor): placeholder for previous rewards
seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
models/lstm.py for more information.
policy/rnn_sequencing.py for more information.
max_seq_len (int): max sequence length for LSTM training.
batch_divisibility_req (int): pad all agent experiences batches to
multiples of this value. This only has an effect if not using

View file

@ -9,8 +9,8 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor,
Preprocessor)
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.utils import try_import_tf
tf = try_import_tf()

View file

@ -9,9 +9,10 @@ import unittest
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.policy.rnn_sequencing import chop_into_sequences, \
add_time_dimension
from ray.rllib.models import ModelCatalog
from ray.rllib.models.lstm import add_time_dimension, chop_into_sequences
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.models.model import Model
from ray.tune.registry import register_env
from ray.rllib.utils import try_import_tf

View file

@ -7,6 +7,7 @@ import pickle
from gym import spaces
from gym.envs.registration import EnvSpec
import gym
import torch.nn as nn
import unittest
import ray
@ -133,13 +134,14 @@ class InvalidModel2(Model):
return tf.constant(0), tf.constant(0)
class TorchSpyModel(TorchModelV2):
class TorchSpyModel(TorchModelV2, nn.Module):
capture_index = 0
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(TorchSpyModel, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
self.fc = FullyConnectedNetwork(
obs_space.original_space.spaces["sensors"].spaces["position"],
action_space, num_outputs, model_config, name)

View file

@ -80,19 +80,6 @@ class PPOCollectTest(unittest.TestCase):
self.assertEqual(ppo.optimizer.num_steps_sampled, 1200)
ppo.stop()
# Check legacy mode
ppo = PPOTrainer(
env="CartPole-v0",
config={
"sample_batch_size": 200,
"train_batch_size": 128,
"num_workers": 3,
"straggler_mitigation": True,
})
ppo.train()
self.assertEqual(ppo.optimizer.num_steps_sampled, 200)
ppo.stop()
class SampleBatchTest(unittest.TestCase):
def testConcat(self):