[RLlib] Curiosity documentation. (#11066)

Sven Mika 2020-09-29 09:39:22 +02:00 committed by GitHub
parent 0d36e4c025
commit f91c455527
4 changed files with 115 additions and 60 deletions


@@ -45,17 +45,24 @@ Algorithm Frameworks Discrete Actions Continuous Acti
`Shared Critic Methods`_      Depends on bootstrapped algorithm
============================= =======================================================================================
Exploration-based plug-ins (can be combined with any algo)
============================= ========== ======================= ================== =========== =====================
Algorithm                     Frameworks Discrete Actions        Continuous Actions Multi-Agent Model Support
============================= ========== ======================= ================== =========== =====================
`Curiosity`_                  torch      **Yes** `+parametric`_  **Yes**            **Yes**     `+RNN`_
============================= ========== ======================= ================== =========== =====================
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
.. _`+RNN`: rllib-models.html#recurrent-models
.. _`+Transformer`: rllib-models.html#attention-networks
.. _`A2C, A3C`: rllib-algorithms.html#a3c
.. _`APEX-DQN`: rllib-algorithms.html#apex
.. _`APEX-DDPG`: rllib-algorithms.html#apex
.. _`Rainbow`: rllib-algorithms.html#dqn
.. _`TD3`: rllib-algorithms.html#ddpg
.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions
.. _`+LSTM auto-wrapping`: rllib-models.html#built-in-models
High-throughput architectures
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -647,8 +654,8 @@ Tuned examples: `SimpleContextualBandit <https://github.com/ray-project/ray/blob
Linear Thompson Sampling (contrib/LinTS)
----------------------------------------
|pytorch|
`[paper] <http://proceedings.mlr.press/v28/agrawal13.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/bandits/agents/lin_ts.py>`__
Like LinUCB, LinTS also assumes a linear dependency between the expected
reward of an action and its context and uses online ridge regression to
estimate the Q values of actions given the context. It assumes a Gaussian
@@ -741,8 +748,78 @@ Fully Independent Learning
Tuned examples: `waterworld <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_independent_learning.py>`__, `multiagent-cartpole <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_cartpole.py>`__
Shared Critic Methods
---------------------
`[instructions] <https://docs.ray.io/en/master/rllib-env.html#implementing-a-centralized-critic>`__ Shared critic methods are when all agents use a single parameter shared critic network (in some cases with access to more of the observation space than agents can see). Note that many specialized multi-agent algorithms such as MADDPG are mostly shared critic forms of their single-agent algorithm (DDPG in the case of MADDPG).
Tuned examples: `TwoStepGame <https://github.com/ray-project/ray/blob/master/rllib/examples/centralized_critic_2.py>`__
Exploration-based plug-ins (can be combined with any algo)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _Curiosity:
Curiosity (ICM: Intrinsic Curiosity Module)
-------------------------------------------
|pytorch|
`[paper] <https://arxiv.org/pdf/1705.05363.pdf>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/curiosity.py>`__
Tuned examples:
`Pyramids (Unity3D) <https://github.com/ray-project/ray/blob/master/rllib/examples/unity3d_env_local.py>`__ (use ``--env Pyramids`` command line option)
`Test case with MiniGrid example <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/tests/test_curiosity.py#L184>`__ (UnitTest case: ``test_curiosity_on_partially_observable_domain``)
**Activating Curiosity**
The curiosity plugin can be easily activated by specifying it as the Exploration class to use
in the main Trainer config. Most of its parameters usually do not have to be specified,
as the module uses the values from the paper by default. For example:
.. code-block:: python

    from ray.rllib.agents import ppo

    config = ppo.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0
    config["exploration_config"] = {
        "type": "Curiosity",  # <- Use the Curiosity module for exploring.
        "eta": 1.0,  # Weight for intrinsic rewards before being added to extrinsic ones.
        "lr": 0.001,  # Learning rate of the curiosity (ICM) module.
        "feature_dim": 288,  # Dimensionality of the generated feature vectors.
        # Setup of the feature net (used to encode observations into feature (latent) vectors).
        "feature_net_config": {
            "fcnet_hiddens": [],
            "fcnet_activation": "relu",
        },
        "inverse_net_hiddens": [256],  # Hidden layers of the "inverse" model.
        "inverse_net_activation": "relu",  # Activation of the "inverse" model.
        "forward_net_hiddens": [256],  # Hidden layers of the "forward" model.
        "forward_net_activation": "relu",  # Activation of the "forward" model.
        "beta": 0.2,  # Weight for the "forward" loss (beta) over the "inverse" loss (1.0 - beta).
        # Specify, which exploration sub-type to use (usually, the algo's "default"
        # exploration, e.g. EpsilonGreedy for DQN, StochasticSampling for PG/SAC).
        "sub_exploration": {
            "type": "StochasticSampling",
        }
    }
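With the above config in place, the Trainer can be built and trained as usual. The following is only a minimal sketch (the environment name ``FrozenLake-v0`` and the number of training iterations are illustrative choices, not part of the original example); note that Curiosity currently requires the torch framework:

.. code-block:: python

    import ray

    ray.init()
    config["framework"] = "torch"  # Curiosity is torch-only (see table above).
    # Build a PPO Trainer that uses the Curiosity exploration config from above.
    trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v0")
    for _ in range(10):
        results = trainer.train()
        print(results["episode_reward_mean"])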
**Functionality**
RLlib's Curiosity is based on `"ICM" (intrinsic curiosity module), described in this paper <https://arxiv.org/pdf/1705.05363.pdf>`__.
It allows agents to learn in sparse-reward or even reward-free environments by
calculating so-called "intrinsic rewards", based purely on the information content arriving via the observation channel.
Sparse-reward environments are environments in which almost all reward signals are 0.0, such as the `MiniGrid environments <https://github.com/maximecb/gym-minigrid>`__.
In such environments, agents have to navigate (and change the underlying state of the environment) over long periods of time without receiving much (or any) feedback.
For example, the task could be to find a key in some room, pick it up, find a door of a matching color, and eventually unlock this door with the key to reach a goal state,
all the while not seeing any rewards.
Such problems are practically impossible to solve with standard RL exploration methods like epsilon-greedy or stochastic sampling.
The Curiosity module - when configured as the Exploration class to use via the Trainer's config (see above on how to do this) - automatically adds three simple models to the Policy's ``self.model``:
a) a latent-space learning ("feature") model, which takes an environment observation and outputs a latent vector representing this observation,
b) a "forward" model, which predicts the next latent vector, given the current latent vector and the action to take next, and
c) a so-called "inverse" net, which is only used to train the "feature" net and tries to predict the action taken between two latent vectors (obs and next obs).
All of these extra models are trained inside the ``postprocess_trajectory()`` call.
Using the (ever-changing) "forward" model, the Curiosity module calculates an artificial (intrinsic) reward signal, weights it via the ``eta`` parameter, and adds it to the environment's (extrinsic) reward.
Intrinsic rewards for each env step are calculated as the Euclidean distance between the latent-space encoding of the next observation (produced by the "feature" model) and the **predicted** latent-space encoding of the next observation (produced by the "forward" model).
This drives the agent to explore areas of the environment that the "forward" model does not yet "understand" (i.e., predicts poorly), while exploration of such areas tapers off once the agent has visited them
often: the "forward" model eventually gets better at predicting the next latent vectors, which in turn diminishes the intrinsic rewards (the Euclidean distance between predicted and actual vectors decreases).
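As a conceptual illustration of the intrinsic reward described above, consider the following sketch (``feature_net`` and ``forward_net`` stand for the "feature" and "forward" models and are hypothetical callables here; this is not the actual code in ``curiosity.py``):

.. code-block:: python

    import numpy as np

    def intrinsic_reward(obs, next_obs, action_one_hot, feature_net, forward_net, eta=1.0):
        """Sketch: intrinsic reward = weighted prediction error of the "forward" model."""
        # Encode current and next observation into latent ("feature") vectors.
        phi = feature_net(obs)
        phi_next = feature_net(next_obs)
        # Predict the next latent vector from the current latent vector and the action.
        phi_next_pred = forward_net(np.concatenate([phi, action_one_hot]))
        # Euclidean distance between actual and predicted next latent vector,
        # weighted by eta; this is added to the env's extrinsic reward.
        return eta * np.linalg.norm(phi_next - phi_next_pred)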


@@ -140,6 +140,11 @@ Algorithms
- |pytorch| :ref:`Linear Upper Confidence Bound (contrib/LinUCB) <linucb>`
- |pytorch| :ref:`Linear Thompson Sampling (contrib/LinTS) <lints>`
* Exploration-based plug-ins (can be combined with any algo)
- |pytorch| :ref:`Curiosity (ICM: Intrinsic Curiosity Module) <curiosity>`
Offline Datasets
----------------
* `Working with Offline Datasets <rllib-offline.html>`__


@@ -568,12 +568,11 @@ actions from distributions (stochastically or deterministically).
The setup can be done using built-in Exploration classes
(see `this package <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/>`__),
which are specified (and further configured) inside ``Trainer.config["exploration_config"]``.
Besides using one of the available classes, one can sub-class any of
these built-ins, add custom behavior to it, and use that new class in
the config instead.
Every policy has an Exploration object, which is created from the Trainer's
``config[“exploration_config”]`` dict, which specifies the class to use via the
special “type” key, as well as constructor arguments via all other keys,
e.g.:
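For instance, the following (purely illustrative) excerpt would configure ``EpsilonGreedy`` exploration with a linearly decaying epsilon; the parameter values shown are examples, not RLlib's defaults:

.. code-block:: python

    config = {
        "explore": True,
        "exploration_config": {
            "type": "EpsilonGreedy",     # Exploration class to use (special "type" key).
            "initial_epsilon": 1.0,      # Constructor argument.
            "final_epsilon": 0.02,       # Constructor argument.
            "epsilon_timesteps": 10000,  # Constructor argument.
        },
    }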
@@ -589,7 +588,7 @@ e.g.:
# ...
The following table lists all built-in Exploration sub-classes and the agents
that currently use these by default:
.. View table below at: https://docs.google.com/drawings/d/1dEMhosbu7HVgHEwGBuMlEDyPiwjqp_g6bZ0DzCMaoUM/edit?usp=sharing
.. image:: images/rllib-exploration-api-table.svg
@@ -598,53 +597,20 @@ An Exploration class implements the ``get_exploration_action`` method,
in which the exact exploratory behavior is defined.
It takes the model's output, the action distribution class, the model itself,
a timestep (the global env-sampling steps already taken),
and an ``explore`` switch and outputs a tuple of a) action and
b) log-likelihood:
.. code-block:: python

    def get_exploration_action(self,
                               distribution_inputs,
                               action_dist_class,
                               model=None,
                               explore=True,
                               timestep=None):
        """Returns a (possibly) exploratory action and its log-likelihood.

        Given the Model's logits outputs and action distribution, returns an
        exploratory action.

        Args:
            distribution_inputs (any): The output coming from the model,
                ready for parameterizing a distribution
                (e.g. q-values or PG-logits).
            action_dist_class (class): The action distribution class
                to use.
            model (ModelV2): The Model object.
            explore (bool): True: "Normal" exploration behavior.
                False: Suppress all exploratory behavior and return
                a deterministic action.
            timestep (int): The current sampling time step. If None, the
                component should try to use an internal counter, which it
                then increments by 1. If provided, will set the internal
                counter to the given value.

        Returns:
            Tuple:
            - The chosen exploration action or a tf-op to fetch the exploration
              action from the graph.
            - The log-likelihood of the exploration action.
        """
        pass
.. literalinclude:: ../../rllib/utils/exploration/exploration.py
   :language: python
   :start-after: __sphinx_doc_begin_get_exploration_action__
   :end-before: __sphinx_doc_end_get_exploration_action__
On the highest level, the ``Trainer.compute_action`` and ``Policy.compute_action(s)``
methods have a boolean ``explore`` switch, which is passed into
``Exploration.get_exploration_action``. If ``explore=None``, the value of
``Trainer.config[“explore”]`` is used, which thus serves as a main switch for
exploratory behavior, allowing e.g. turning off any exploration easily for
evaluation purposes (see :ref:`CustomEvaluation`).
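For example, to compute a deterministic (greedy) action from a trained Trainer regardless of the config setting, the ``explore`` flag can be passed per call (a minimal sketch; ``trainer`` and ``obs`` are assumed to already exist):

.. code-block:: python

    # Deterministic action, no matter what config["explore"] is set to.
    action = trainer.compute_action(obs, explore=False)

    # Fall back to the value of config["explore"] (explore defaults to None).
    action = trainer.compute_action(obs)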
The following are example excerpts from different Trainers' configs
(see rllib/agents/trainer.py) to set up different exploration behaviors:
@@ -688,13 +654,14 @@ The following are example excerpts from different Trainers' configs
"temperature": 1.0, "temperature": 1.0,
}, },
# c) PPO: see rllib/agents/ppo/ppo.py # c) All policy-gradient algos and SAC: see rllib/agents/trainer.py
# Behavior: The algo samples stochastically by default from the # Behavior: The algo samples stochastically from the
# model-parameterized distribution. This is the global Trainer default # model-parameterized distribution. This is the global Trainer default
# setting defined in trainer.py and used by all PG-type algos. # setting defined in trainer.py and used by all PG-type algos (plus SAC).
"explore": True, "explore": True,
"exploration_config": { "exploration_config": {
"type": "StochasticSampling", "type": "StochasticSampling",
"random_timesteps": 0, # timesteps at beginning, over which to act uniformly randomly
}, },


@@ -69,6 +69,9 @@ class Exploration:
""" """
pass pass
# yapf: disable
# __sphinx_doc_begin_get_exploration_action__
@DeveloperAPI @DeveloperAPI
def get_exploration_action(self, def get_exploration_action(self,
*, *,
@@ -98,6 +101,9 @@ class Exploration:
""" """
pass pass
# __sphinx_doc_end_get_exploration_action__
# yapf: enable
@DeveloperAPI @DeveloperAPI
def on_episode_start(self, def on_episode_start(self,
policy, policy,