mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Curiosity documentation. (#11066)
This commit is contained in:
parent
0d36e4c025
commit
f91c455527
4 changed files with 115 additions and 60 deletions
|
@ -45,17 +45,24 @@ Algorithm Frameworks Discrete Actions Continuous Acti
|
|||
`Shared Critic Methods`_ Depends on bootstrapped algorithm
|
||||
============================= =======================================================================================
|
||||
|
||||
Exploration-based plug-ins (can be combined with any algo)
|
||||
|
||||
============================= ========== ======================= ================== =========== =====================
|
||||
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
|
||||
============================= ========== ======================= ================== =========== =====================
|
||||
`Curiosity`_ torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
============================= ========== ======================= ================== =========== =====================
|
||||
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
.. _`+RNN`: rllib-models.html#recurrent-models
|
||||
.. _`+Transformer`: rllib-models.html#attention-networks
|
||||
.. _`A2C, A3C`: rllib-algorithms.html#a3c
|
||||
.. _`APEX-DQN`: rllib-algorithms.html#apex
|
||||
.. _`APEX-DDPG`: rllib-algorithms.html#apex
|
||||
.. _`Rainbow`: rllib-algorithms.html#dqn
|
||||
.. _`TD3`: rllib-algorithms.html#ddpg
|
||||
.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions
|
||||
.. _`+LSTM auto-wrapping`: rllib-models.html#built-in-models
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
.. _`Rainbow`: rllib-algorithms.html#dqn
|
||||
.. _`+RNN`: rllib-models.html#recurrent-models
|
||||
.. _`TD3`: rllib-algorithms.html#ddpg
|
||||
.. _`+Transformer`: rllib-models.html#attention-networks
|
||||
|
||||
High-throughput architectures
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -647,8 +654,8 @@ Tuned examples: `SimpleContextualBandit <https://github.com/ray-project/ray/blob
|
|||
Linear Thompson Sampling (contrib/LinTS)
|
||||
----------------------------------------
|
||||
|pytorch|
|
||||
`[paper] <http://proceedings.mlr.press/v28/agrawal13.pdf>`__ `[implementation]
|
||||
<https://github.com/ray-project/ray/blob/master/rllib/contrib/bandits/agents/lin_ts.py>`__
|
||||
`[paper] <http://proceedings.mlr.press/v28/agrawal13.pdf>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/bandits/agents/lin_ts.py>`__
|
||||
Like LinUCB, LinTS also assumes a linear dependency between the expected
|
||||
reward of an action and its context and uses online ridge regression to
|
||||
estimate the Q values of actions given the context. It assumes a Gaussian
|
||||
|
@ -741,8 +748,78 @@ Fully Independent Learning
|
|||
Tuned examples: `waterworld <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_independent_learning.py>`__, `multiagent-cartpole <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_cartpole.py>`__
|
||||
|
||||
Shared Critic Methods
|
||||
--------------------------
|
||||
---------------------
|
||||
|
||||
`[instructions] <https://docs.ray.io/en/master/rllib-env.html#implementing-a-centralized-critic>`__ Shared critic methods are when all agents use a single parameter shared critic network (in some cases with access to more of the observation space than agents can see). Note that many specialized multi-agent algorithms such as MADDPG are mostly shared critic forms of their single-agent algorithm (DDPG in the case of MADDPG).
|
||||
|
||||
Tuned examples: `TwoStepGame <https://github.com/ray-project/ray/blob/master/rllib/examples/centralized_critic_2.py>`__
|
||||
|
||||
|
||||
Exploration-based plug-ins (can be combined with any algo)
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. _Curiosity:
|
||||
|
||||
Curiosity (ICM: Intrinsic Curiosity Module)
|
||||
-------------------------------------------
|
||||
|
||||
|pytorch|
|
||||
`[paper] <https://arxiv.org/pdf/1705.05363.pdf>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/curiosity.py>`__
|
||||
|
||||
Tuned examples:
|
||||
`Pyramids (Unity3D) <https://github.com/ray-project/ray/blob/master/rllib/examples/unity3d_env_local.py>`__ (use ``--env Pyramids`` command line option)
|
||||
`Test case with MiniGrid example <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/tests/test_curiosity.py#L184>`__ (UnitTest case: ``test_curiosity_on_partially_observable_domain``)
|
||||
|
||||
**Activating Curiosity**
|
||||
The curiosity plugin can be easily activated by specifying it as the Exploration class to-be-used
|
||||
in the main Trainer config. Most of its parameters usually do not have to be specified
|
||||
as the module uses the values from the paper by default. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 0
|
||||
config["exploration_config"] = {
|
||||
"type": "Curiosity", # <- Use the Curiosity module for exploring.
|
||||
"eta": 1.0, # Weight for intrinsic rewards before being added to extrinsic ones.
|
||||
"lr": 0.001, # Learning rate of the curiosity (ICM) module.
|
||||
"feature_dim": 288, # Dimensionality of the generated feature vectors.
|
||||
# Setup of the feature net (used to encode observations into feature (latent) vectors).
|
||||
"feature_net_config": {
|
||||
"fcnet_hiddens": [],
|
||||
"fcnet_activation": "relu",
|
||||
},
|
||||
"inverse_net_hiddens": [256], # Hidden layers of the "inverse" model.
|
||||
"inverse_net_activation": "relu", # Activation of the "inverse" model.
|
||||
"forward_net_hiddens": [256], # Hidden layers of the "forward" model.
|
||||
"forward_net_activation": "relu", # Activation of the "forward" model.
|
||||
"beta": 0.2, # Weight for the "forward" loss (beta) over the "inverse" loss (1.0 - beta).
|
||||
# Specify, which exploration sub-type to use (usually, the algo's "default"
|
||||
# exploration, e.g. EpsilonGreedy for DQN, StochasticSampling for PG/SAC).
|
||||
"sub_exploration": {
|
||||
"type": "StochasticSampling",
|
||||
}
|
||||
}
|
||||
|
||||
**Functionality**
|
||||
RLlib's Curiosity is based on `"ICM" (intrinsic curiosity module) described in this paper here <https://https://arxiv.org/pdf/1705.05363.pdf>`__.
|
||||
It allows agents to learn in sparse-reward- or even no-reward environments by
|
||||
calculating so-called "intrinsic rewards", purely based on the information content that is incoming via the observation channel.
|
||||
Sparse-reward environments are envs where almost all reward signals are 0.0, such as these `[MiniGrid env examples here] <https://github.com/maximecb/gym-minigrid>`__.
|
||||
In such environments, agents have to navigate (and change the underlying state of the environment) over long periods of time, without receiving much (or any) feedback.
|
||||
For example, the task could be to find a key in some room, pick it up, find a matching door (matching the color of the key), and eventually unlock this door with the key to reach a goal state,
|
||||
all the while not seeing any rewards.
|
||||
Such problems are impossible to solve with standard RL exploration methods like epsilon-greedy or stochastic sampling.
|
||||
The Curiosity module - when configured as the Exploration class to use via the Trainer's config (see above on how to do this) - automatically adds three simple models to the Policy's ``self.model``:
|
||||
a) a latent space learning ("feature") model, taking an environment observation and outputting a latent vector, which represents this observation and
|
||||
b) a "forward" model, predicting the next latent vector, given the current observation vector and an action to take next.
|
||||
c) a so-called "inverse" net, only used to train the "feature" net. The inverse net tries to predict the action taken between two latent vectors (obs and next obs).
|
||||
|
||||
All the above extra Models are trained inside the ``postprocess_trajectory()`` call.
|
||||
|
||||
Using the (ever changing) "forward" model, our Curiosity module calculates an artificial (intrinsic) reward signal, weights it via the ``eta`` parameter, and then adds it to the environment's (extrinsic) reward.
|
||||
Intrinsic rewards for each env-step are calculated by taking the euclidian distance between the latent-space encoded next observation ("feature" model) and the **predicted** latent-space encoding for the next observation
|
||||
("forward" model).
|
||||
This allows the agent to explore areas of the environment, where the "forward" model still performs poorly (are not "understood" yet), whereas exploration to these areas will taper down after the agent has visited them
|
||||
often: The "forward" model will eventually get better at predicting these next latent vectors, which in turn will diminish the intrinsic rewards (decrease the euclidian distance between predicted and actual vectors).
|
||||
|
|
|
@ -140,6 +140,11 @@ Algorithms
|
|||
- |pytorch| :ref:`Linear Upper Confidence Bound (contrib/LinUCB) <linucb>`
|
||||
- |pytorch| :ref:`Linear Thompson Sampling (contrib/LinTS) <lints>`
|
||||
|
||||
* Exploration-based plug-ins (can be combined with any algo)
|
||||
|
||||
- |pytorch| :ref:`Curiosity (ICM: Intrinsic Curiosity Module) <curiosity>`
|
||||
|
||||
|
||||
Offline Datasets
|
||||
----------------
|
||||
* `Working with Offline Datasets <rllib-offline.html>`__
|
||||
|
|
|
@ -235,11 +235,11 @@ It also simplifies saving the trained agent. For example:
|
|||
checkpoints = analysis.get_trial_checkpoints_paths(
|
||||
trial=analysis.get_best_trial("episode_reward_mean"),
|
||||
metric="episode_reward_mean")
|
||||
|
||||
|
||||
Loading and restoring a trained agent from a checkpoint is simple:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
agent = ppo.PPOTrainer(config=config, env=env_class)
|
||||
agent.restore(checkpoint_path)
|
||||
|
||||
|
@ -568,12 +568,11 @@ actions from distributions (stochastically or deterministically).
|
|||
The setup can be done via using built-in Exploration classes
|
||||
(see `this package <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/>`__),
|
||||
which are specified (and further configured) inside ``Trainer.config["exploration_config"]``.
|
||||
Besides using built-in classes, one can sub-class any of
|
||||
Besides using one of the available classes, one can sub-class any of
|
||||
these built-ins, add custom behavior to it, and use that new class in
|
||||
the config instead.
|
||||
|
||||
Every policy has-an instantiation of one of the Exploration (sub-)classes.
|
||||
This Exploration object is created from the Trainer’s
|
||||
Every policy has-an Exploration object, which is created from the Trainer’s
|
||||
``config[“exploration_config”]`` dict, which specifies the class to use via the
|
||||
special “type” key, as well as constructor arguments via all other keys,
|
||||
e.g.:
|
||||
|
@ -589,7 +588,7 @@ e.g.:
|
|||
# ...
|
||||
|
||||
The following table lists all built-in Exploration sub-classes and the agents
|
||||
that currently used these by default:
|
||||
that currently use these by default:
|
||||
|
||||
.. View table below at: https://docs.google.com/drawings/d/1dEMhosbu7HVgHEwGBuMlEDyPiwjqp_g6bZ0DzCMaoUM/edit?usp=sharing
|
||||
.. image:: images/rllib-exploration-api-table.svg
|
||||
|
@ -598,53 +597,20 @@ An Exploration class implements the ``get_exploration_action`` method,
|
|||
in which the exact exploratory behavior is defined.
|
||||
It takes the model’s output, the action distribution class, the model itself,
|
||||
a timestep (the global env-sampling steps already taken),
|
||||
and an ``explore`` switch and outputs a tuple of 1) action and
|
||||
2) log-likelihood:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def get_exploration_action(self,
|
||||
distribution_inputs,
|
||||
action_dist_class,
|
||||
model=None,
|
||||
explore=True,
|
||||
timestep=None):
|
||||
"""Returns a (possibly) exploratory action and its log-likelihood.
|
||||
|
||||
Given the Model's logits outputs and action distribution, returns an
|
||||
exploratory action.
|
||||
|
||||
Args:
|
||||
distribution_inputs (any): The output coming from the model,
|
||||
ready for parameterizing a distribution
|
||||
(e.g. q-values or PG-logits).
|
||||
action_dist_class (class): The action distribution class
|
||||
to use.
|
||||
model (ModelV2): The Model object.
|
||||
explore (bool): True: "Normal" exploration behavior.
|
||||
False: Suppress all exploratory behavior and return
|
||||
a deterministic action.
|
||||
timestep (int): The current sampling time step. If None, the
|
||||
component should try to use an internal counter, which it
|
||||
then increments by 1. If provided, will set the internal
|
||||
counter to the given value.
|
||||
|
||||
Returns:
|
||||
Tuple:
|
||||
- The chosen exploration action or a tf-op to fetch the exploration
|
||||
action from the graph.
|
||||
- The log-likelihood of the exploration action.
|
||||
"""
|
||||
pass
|
||||
and an ``explore`` switch and outputs a tuple of a) action and
|
||||
b) log-likelihood:
|
||||
|
||||
.. literalinclude:: ../../rllib/utils/exploration/exploration.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin_get_exploration_action__
|
||||
:end-before: __sphinx_doc_end_get_exploration_action__
|
||||
|
||||
On the highest level, the ``Trainer.compute_action`` and ``Policy.compute_action(s)``
|
||||
methods have a boolean ``explore`` switch, which is passed into
|
||||
``Exploration.get_exploration_action``. If ``None``, the value of
|
||||
``Trainer.config[“explore”]`` is used.
|
||||
Hence ``config[“explore”]`` describes the default behavior of the policy and
|
||||
e.g. allows switching off any exploration easily for evaluation purposes
|
||||
(see :ref:`CustomEvaluation`).
|
||||
``Exploration.get_exploration_action``. If ``explore=None``, the value of
|
||||
``Trainer.config[“explore”]`` is used, which thus serves as a main switch for
|
||||
exploratory behavior, allowing e.g. turning off any exploration easily for
|
||||
evaluation purposes (see :ref:`CustomEvaluation`).
|
||||
|
||||
The following are example excerpts from different Trainers' configs
|
||||
(see rllib/agents/trainer.py) to setup different exploration behaviors:
|
||||
|
@ -688,13 +654,14 @@ The following are example excerpts from different Trainers' configs
|
|||
"temperature": 1.0,
|
||||
},
|
||||
|
||||
# c) PPO: see rllib/agents/ppo/ppo.py
|
||||
# Behavior: The algo samples stochastically by default from the
|
||||
# c) All policy-gradient algos and SAC: see rllib/agents/trainer.py
|
||||
# Behavior: The algo samples stochastically from the
|
||||
# model-parameterized distribution. This is the global Trainer default
|
||||
# setting defined in trainer.py and used by all PG-type algos.
|
||||
# setting defined in trainer.py and used by all PG-type algos (plus SAC).
|
||||
"explore": True,
|
||||
"exploration_config": {
|
||||
"type": "StochasticSampling",
|
||||
"random_timesteps": 0, # timesteps at beginning, over which to act uniformly randomly
|
||||
},
|
||||
|
||||
|
||||
|
|
|
@ -69,6 +69,9 @@ class Exploration:
|
|||
"""
|
||||
pass
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin_get_exploration_action__
|
||||
|
||||
@DeveloperAPI
|
||||
def get_exploration_action(self,
|
||||
*,
|
||||
|
@ -98,6 +101,9 @@ class Exploration:
|
|||
"""
|
||||
pass
|
||||
|
||||
# __sphinx_doc_end_get_exploration_action__
|
||||
# yapf: enable
|
||||
|
||||
@DeveloperAPI
|
||||
def on_episode_start(self,
|
||||
policy,
|
||||
|
|
Loading…
Add table
Reference in a new issue