[docs] Refactor (some of) RLlib training API docs using literalinclude (#24436)
Per the [Ray docs contributing guide](https://docs.ray.io/en/master/ray-contribute/docs.html), code chunks should live in `.py` files and be pulled into the docs via `literalinclude` rather than being placed directly in `.rst` files. This PR takes a small step toward that for the RLlib docs, specifically for the training API doc page. Note that I had to make some changes to the code itself so that it would run, namely adding missing numpy imports and changing `model.from_batch(...)` to `model(...)` in a couple of places.

Co-authored-by: Max Pumperla <max.pumperla@googlemail.com>
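For readers unfamiliar with the pattern, the idea is roughly this (a minimal sketch; the file and marker names below are made up for illustration and are not the ones added in this commit): a snippet in a `doc_code/*.py` file is fenced by start/end marker comments, the `.rst` page pulls in exactly that span with a `.. literalinclude::` directive using `:start-after:`/`:end-before:`, and the `py_test_run_all_subdirectory` rule added to `doc/BUILD` below runs the whole `.py` file as a test.

```python
# doc_code/example_snippet.py -- hypothetical file, for illustration only
# flake8: noqa

# __my_example_start__
import numpy as np

# Only the lines between the two marker comments are rendered on the docs page.
squares = np.arange(5) ** 2
print(squares)  # [ 0  1  4  9 16]
# __my_example_end__
```

The matching directive on the `.rst` side would then be `.. literalinclude:: doc_code/example_snippet.py` with `:language: python`, `:start-after: __my_example_start__`, and `:end-before: __my_example_end__`, exactly as done for the real training examples in the diff below.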
parent 99d25d4d4e
commit 8d6548a74a

3 changed files with 155 additions and 139 deletions
doc/BUILD (13 changes: 12 additions, 1 deletion)
@@ -154,6 +154,18 @@ py_test_run_all_subdirectory(
     tags = ["exclusive", "team:ml"],
 )
 
+# --------------------------------------------------------------------
+# Test all doc/source/rllib/doc_code code included in rst/md files.
+# --------------------------------------------------------------------
+
+py_test_run_all_subdirectory(
+    size = "medium",
+    include = ["source/rllib/doc_code/*.py"],
+    exclude = [],
+    extra_srcs = [],
+    tags = ["exclusive", "team:ml"],
+)
+
 # --------------------------------------------------------------------
 # Test all doc/source/ray-air/doc_code code included in rst/md files.
 # --------------------------------------------------------------------

@@ -166,7 +178,6 @@ py_test_run_all_subdirectory(
     tags = ["exclusive", "team:ml"],
 )
 
-
 # --------------------------------------------------------------------
 # Test all doc/source/ray-overview/doc_test snippets, used on ray.io
 # --------------------------------------------------------------------
doc/source/rllib/doc_code/training.py (new file, 131 additions)
@@ -0,0 +1,131 @@
# flake8: noqa

# __preprocessing_observations_start__
import gym

env = gym.make("Pong-v0")

# RLlib uses preprocessors to implement transforms such as one-hot encoding
# and flattening of tuple and dict observations.
from ray.rllib.models.preprocessors import get_preprocessor

prep = get_preprocessor(env.observation_space)(env.observation_space)
# <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>

# Observations should be preprocessed prior to feeding into a model
env.reset().shape
# (210, 160, 3)
prep.transform(env.reset()).shape
# (84, 84, 3)
# __preprocessing_observations_end__

# __query_action_dist_start__
# Get a reference to the policy
import numpy as np
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
policy = trainer.get_policy()
# <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>

# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block.
logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])

# Compute action distribution given logits
policy.dist_class
# <class_object 'ray.rllib.models.tf.tf_action_dist.Categorical'>
dist = policy.dist_class(logits, policy.model)
# <ray.rllib.models.tf.tf_action_dist.Categorical object at 0x7fd02301d710>

# Query the distribution for samples, sample logps
dist.sample()
# <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=..>
dist.logp([1])
# <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>

# Get the estimated values for the most recent forward pass
policy.model.value_function()
# <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>

policy.model.base_model.summary()
"""
Model: "model"
_____________________________________________________________________
Layer (type)               Output Shape   Param #   Connected to
=====================================================================
observations (InputLayer)  [(None, 4)]    0
_____________________________________________________________________
fc_1 (Dense)               (None, 256)    1280      observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense)         (None, 256)    1280      observations[0][0]
_____________________________________________________________________
fc_2 (Dense)               (None, 256)    65792     fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense)         (None, 256)    65792     fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense)             (None, 2)      514       fc_2[0][0]
_____________________________________________________________________
value_out (Dense)          (None, 1)      257       fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
_____________________________________________________________________
"""
# __query_action_dist_end__


# __get_q_values_dqn_start__
# Get a reference to the model through the policy
import numpy as np
from ray.rllib.agents.dqn import DQNTrainer

trainer = DQNTrainer(env="CartPole-v0", config={"framework": "tf2"})
model = trainer.get_policy().model
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

# List of all model variables
model.variables()

# Run a forward pass to get base model output. Note that complex observations
# must be preprocessed. An example of preprocessing is examples/saving_experiences.py
model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)

# Access the base Keras models (all default models have a base)
model.base_model.summary()
"""
Model: "model"
_______________________________________________________________________
Layer (type)               Output Shape   Param #   Connected to
=======================================================================
observations (InputLayer)  [(None, 4)]    0
_______________________________________________________________________
fc_1 (Dense)               (None, 256)    1280      observations[0][0]
_______________________________________________________________________
fc_out (Dense)             (None, 256)    65792     fc_1[0][0]
_______________________________________________________________________
value_out (Dense)          (None, 1)      257       fc_1[0][0]
=======================================================================
Total params: 67,329
Trainable params: 67,329
Non-trainable params: 0
______________________________________________________________________________
"""

# Access the Q value model (specific to DQN)
print(model.get_q_value_distributions(model_out)[0])
# tf.Tensor([[ 0.13023682 -0.36805138]], shape=(1, 2), dtype=float32)
# ^ exact numbers may differ due to randomness

model.q_value_head.summary()

# Access the state value model (specific to DQN)
print(model.get_state_value(model_out))
# tf.Tensor([[0.09381643]], shape=(1, 1), dtype=float32)
# ^ exact number may differ due to randomness

model.state_value_head.summary()
# __get_q_values_dqn_end__
@@ -935,151 +935,25 @@ First, install the dependencies:
 
 Then for the code:
 
-.. code-block:: python
-
-    >>> import gym
-    >>> env = gym.make("Pong-v0")
-
-    # RLlib uses preprocessors to implement transforms such as one-hot encoding
-    # and flattening of tuple and dict observations.
-    >>> from ray.rllib.models.preprocessors import get_preprocessor
-    >>> prep = get_preprocessor(env.observation_space)(env.observation_space)
-    <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>
-
-    # Observations should be preprocessed prior to feeding into a model
-    >>> env.reset().shape
-    (210, 160, 3)
-    >>> prep.transform(env.reset()).shape
-    (84, 84, 3)
+.. literalinclude:: doc_code/training.py
+    :language: python
+    :start-after: __preprocessing_observations_start__
+    :end-before: __preprocessing_observations_end__
 
 **Example: Querying a policy's action distribution**
 
-.. code-block:: python
-
-    # Get a reference to the policy
-    >>> from ray.rllib.agents.ppo import PPOTrainer
-    >>> trainer = PPOTrainer(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
-    >>> policy = trainer.get_policy()
-    <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>
-
-    # Run a forward pass to get model output logits. Note that complex observations
-    # must be preprocessed as in the above code block.
-    >>> logits, _ = policy.model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
-    (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])
-
-    # Compute action distribution given logits
-    >>> policy.dist_class
-    <class_object 'ray.rllib.models.tf.tf_action_dist.Categorical'>
-    >>> dist = policy.dist_class(logits, policy.model)
-    <ray.rllib.models.tf.tf_action_dist.Categorical object at 0x7fd02301d710>
-
-    # Query the distribution for samples, sample logps
-    >>> dist.sample()
-    <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=..>
-    >>> dist.logp([1])
-    <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>
-
-    # Get the estimated values for the most recent forward pass
-    >>> policy.model.value_function()
-    <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>
-
-    >>> policy.model.base_model.summary()
-    Model: "model"
-    _____________________________________________________________________
-    Layer (type)               Output Shape   Param #   Connected to
-    =====================================================================
-    observations (InputLayer)  [(None, 4)]    0
-    _____________________________________________________________________
-    fc_1 (Dense)               (None, 256)    1280      observations[0][0]
-    _____________________________________________________________________
-    fc_value_1 (Dense)         (None, 256)    1280      observations[0][0]
-    _____________________________________________________________________
-    fc_2 (Dense)               (None, 256)    65792     fc_1[0][0]
-    _____________________________________________________________________
-    fc_value_2 (Dense)         (None, 256)    65792     fc_value_1[0][0]
-    _____________________________________________________________________
-    fc_out (Dense)             (None, 2)      514       fc_2[0][0]
-    _____________________________________________________________________
-    value_out (Dense)          (None, 1)      257       fc_value_2[0][0]
-    =====================================================================
-    Total params: 134,915
-    Trainable params: 134,915
-    Non-trainable params: 0
-    _____________________________________________________________________
+.. literalinclude:: doc_code/training.py
+    :language: python
+    :start-after: __query_action_dist_start__
+    :end-before: __query_action_dist_end__
 
 **Example: Getting Q values from a DQN model**
 
-.. code-block:: python
+.. literalinclude:: doc_code/training.py
+    :language: python
+    :start-after: __get_q_values_dqn_start__
+    :end-before: __get_q_values_dqn_end__
 
-    # Get a reference to the model through the policy
-    >>> from ray.rllib.agents.dqn import DQNTrainer
-    >>> trainer = DQNTrainer(env="CartPole-v0", config={"framework": "tf2"})
-    >>> model = trainer.get_policy().model
-    <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>
-
-    # List of all model variables
-    >>> model.variables()
-    [<tf.Variable 'default_policy/fc_1/kernel:0' shape=(4, 256) dtype=float32>, ...]
-
-    # Run a forward pass to get base model output. Note that complex observations
-    # must be preprocessed. An example of preprocessing is examples/saving_experiences.py
-    >>> model_out = model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
-    (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)
-
-    # Access the base Keras models (all default models have a base)
-    >>> model.base_model.summary()
-    Model: "model"
-    _______________________________________________________________________
-    Layer (type)               Output Shape   Param #   Connected to
-    =======================================================================
-    observations (InputLayer)  [(None, 4)]    0
-    _______________________________________________________________________
-    fc_1 (Dense)               (None, 256)    1280      observations[0][0]
-    _______________________________________________________________________
-    fc_out (Dense)             (None, 256)    65792     fc_1[0][0]
-    _______________________________________________________________________
-    value_out (Dense)          (None, 1)      257       fc_1[0][0]
-    =======================================================================
-    Total params: 67,329
-    Trainable params: 67,329
-    Non-trainable params: 0
-    ______________________________________________________________________________
-
-    # Access the Q value model (specific to DQN)
-    >>> model.get_q_value_distributions(model_out)
-    [<tf.Tensor: id=891, shape=(1, 2)>, <tf.Tensor: id=896, shape=(1, 2, 1)>]
-
-    >>> model.q_value_head.summary()
-    Model: "model_1"
-    _________________________________________________________________
-    Layer (type)              Output Shape               Param #
-    =================================================================
-    model_out (InputLayer)    [(None, 256)]              0
-    _________________________________________________________________
-    lambda (Lambda)           [(None, 2), (None, 2, 1),  66306
-    =================================================================
-    Total params: 66,306
-    Trainable params: 66,306
-    Non-trainable params: 0
-    _________________________________________________________________
-
-    # Access the state value model (specific to DQN)
-    >>> model.get_state_value(model_out)
-    <tf.Tensor: id=913, shape=(1, 1), dtype=float32>
-
-    >>> model.state_value_head.summary()
-    Model: "model_2"
-    _________________________________________________________________
-    Layer (type)              Output Shape               Param #
-    =================================================================
-    model_out (InputLayer)    [(None, 256)]              0
-    _________________________________________________________________
-    lambda_1 (Lambda)         (None, 1)                  66049
-    =================================================================
-    Total params: 66,049
-    Trainable params: 66,049
-    Non-trainable params: 0
-    _________________________________________________________________
-
 This is especially useful when used with `custom model classes <rllib-models.html>`__.
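As a loose sketch of what that closing note refers to (assuming the same Ray 1.x / TF2 `TFModelV2` API used in the examples above; this class and its names are illustrative and not part of the commit), a custom model can keep a handle on its Keras base model so the same introspection calls shown in the doc_code snippets work on it as well:

```python
import tensorflow as tf
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class MyKerasModel(TFModelV2):
    """Illustrative custom model whose internals can be inspected like the built-ins."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="observations")
        hidden = tf.keras.layers.Dense(256, activation="tanh", name="fc_1")(inputs)
        logits = tf.keras.layers.Dense(num_outputs, name="fc_out")(hidden)
        value = tf.keras.layers.Dense(1, name="value_out")(hidden)
        # Keeping a handle on the underlying Keras model is what makes calls like
        # `policy.model.base_model.summary()` in the examples above possible.
        self.base_model = tf.keras.Model(inputs, [logits, value])

    def forward(self, input_dict, state, seq_lens):
        logits, self._value_out = self.base_model(input_dict["obs"])
        return logits, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])


# Hypothetical registration/usage, mirroring the PPOTrainer example above:
ModelCatalog.register_custom_model("my_keras_model", MyKerasModel)
# PPOTrainer(env="CartPole-v0",
#            config={"framework": "tf2", "model": {"custom_model": "my_keras_model"}})
```

Registered this way, `trainer.get_policy().model.base_model.summary()` and `value_function()` can be queried exactly as in the doc_code snippets above.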