RLlib Training APIs
===================

Getting Started
---------------

At a high level, RLlib provides a ``Trainer`` class which
holds a policy for environment interaction. Through the trainer interface, the policy can
be trained, checkpointed, or used to compute an action. In multi-agent training, the trainer manages the querying and optimization of multiple policies at once.

.. image:: rllib-api.svg

You can train a simple DQN trainer with the following command:

.. code-block:: bash

    rllib train --run DQN --env CartPole-v0  # --eager [--trace] for eager execution

By default, the results will be logged to a subdirectory of ``~/ray_results``.
This subdirectory will contain a file ``params.json`` which contains the
hyperparameters, a file ``result.json`` which contains a training summary
for each training iteration, and a TensorBoard file that can be used to visualize
the training process with TensorBoard by running:

.. code-block:: bash

    tensorboard --logdir=~/ray_results
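
Because ``result.json`` is plain text, you can also inspect progress without TensorBoard. The following is a minimal sketch that assumes the file stores one JSON object per line, each carrying ``training_iteration`` and ``episode_reward_mean`` keys; the helper name ``load_rewards`` is hypothetical.

.. code-block:: python

    import json


    def load_rewards(result_json_path):
        """Return (training_iteration, episode_reward_mean) pairs from result.json.

        Assumes one JSON object per line (JSON Lines format).
        """
        rewards = []
        with open(result_json_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # Skip blank lines.
                result = json.loads(line)
                rewards.append((result["training_iteration"],
                                result.get("episode_reward_mean")))
        return rewards

Pass it the path of a trial's ``result.json`` under ``~/ray_results``. If training is still running, the last line may be incomplete, so you may want to guard ``json.loads`` accordingly.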

The ``rllib train`` command (same as the ``train.py`` script in the repo) has a number of options you can show by running:

.. code-block:: bash

    rllib train --help
    # or, equivalently:
    python ray/rllib/train.py --help

The most important options are for choosing the environment
with ``--env`` (any OpenAI gym environment including ones registered by the user
can be used) and for choosing the algorithm with ``--run``
(available options include ``SAC``, ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).

Evaluating Trained Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~

In order to save checkpoints from which to evaluate policies,
set ``--checkpoint-freq`` (number of training iterations between checkpoints)
when running ``rllib train``.

An example of evaluating a previously trained DQN policy is as follows:

.. code-block:: bash

    rllib rollout \
        ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
        --run DQN --env CartPole-v0 --steps 10000

The ``rollout.py`` helper script reconstructs a DQN policy from the checkpoint
located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1``
and renders its behavior in the environment specified by ``--env``.
(Type ``rllib rollout --help`` to see the available evaluation options.)

For more advanced evaluation functionality, refer to `Customized Evaluation During Training <#customized-evaluation-during-training>`__.
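
When ``--checkpoint-freq`` is set, a trial directory accumulates many checkpoints. The sketch below picks the newest one, assuming the ``<trial_dir>/checkpoint_<N>/checkpoint-<N>`` layout seen in the example path; ``latest_checkpoint`` is a hypothetical helper, and other layouts are not handled.

.. code-block:: python

    import re
    from pathlib import Path

    # Matches directory names such as "checkpoint_10".
    _CKPT_DIR_RE = re.compile(r"checkpoint_(\d+)$")


    def latest_checkpoint(trial_dir):
        """Return the newest checkpoint file path in trial_dir, or None."""
        best_n, best_path = -1, None
        for sub in Path(trial_dir).iterdir():
            m = _CKPT_DIR_RE.match(sub.name)
            if not (sub.is_dir() and m):
                continue
            n = int(m.group(1))  # Compare numerically, so 10 sorts after 9.
            candidate = sub / f"checkpoint-{n}"
            if candidate.exists() and n > best_n:
                best_n, best_path = n, candidate
        return str(best_path) if best_path is not None else None

The returned path can then be passed as the checkpoint argument to ``rllib rollout``.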

Configuration
-------------

Specifying Parameters
~~~~~~~~~~~~~~~~~~~~~

Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__. See the
`algorithms documentation <rllib-algorithms.html>`__ for more information.

In the example below, we train A2C by specifying 8 workers through the config flag.

.. code-block:: bash

    rllib train --env=PongDeterministic-v4 --run=A2C --config '{"num_workers": 8}'
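
The ``--config`` value is parsed as JSON, so Python-style literals such as ``True`` or single-quoted strings will not work there. As a minimal sketch using only the standard library, the command can be generated from a Python dict; the helper name ``rllib_train_command`` is hypothetical.

.. code-block:: python

    import json
    import shlex


    def rllib_train_command(env, run, config):
        """Build an ``rllib train`` command line with ``config`` passed as JSON.

        json.dumps emits valid JSON (double quotes, true/false), and
        shlex.quote keeps it a single shell argument.
        """
        return " ".join([
            "rllib", "train",
            f"--env={shlex.quote(env)}",
            f"--run={shlex.quote(run)}",
            "--config", shlex.quote(json.dumps(config)),
        ])


    print(rllib_train_command("PongDeterministic-v4", "A2C", {"num_workers": 8}))

For the A2C example this reproduces the command shown above.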

Specifying Resources
~~~~~~~~~~~~~~~~~~~~

You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most algorithms. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting ``num_gpus: 0.2``.

For synchronous algorithms like PPO and A2C, the driver and workers can make use of the same GPU. To share ``gpu_count`` GPUs this way:

.. code-block:: python

    # Example values: share gpu_count = 1 GPU between the driver and 4 workers.
    gpu_count = 1
    num_workers = 4
    num_gpus = 0.0001  # Driver GPU
    num_gpus_per_worker = (gpu_count - num_gpus) / num_workers

.. Original image: https://docs.google.com/drawings/d/14QINFvx3grVyJyjAnjggOCEVN-Iq6pYVJ3jA2S6j8z0/edit?usp=sharing

.. image:: rllib-config.svg

Scaling Guide
~~~~~~~~~~~~~

Here are some rules of thumb for scaling training with RLlib.

1. If the environment is slow and cannot be replicated (e.g., since it requires interaction with physical systems), then you should use a sample-efficient off-policy algorithm such as :ref:`DQN <dqn>` or :ref:`SAC <sac>`. These algorithms default to ``num_workers: 0`` for single-process operation. Also consider batch RL training with the `offline data <rllib-offline.html>`__ API.

2. If the environment is fast and the model is small (most models for RL are), use time-efficient algorithms such as :ref:`PPO <ppo>`, :ref:`IMPALA <impala>`, or :ref:`APEX <apex>`. These can be scaled by increasing ``num_workers`` to add rollout workers. It may also make sense to enable `vectorization <rllib-env.html#vectorized>`__ for inference. If the learner becomes a bottleneck, multiple GPUs can be used for learning by setting ``num_gpus > 1``.

3. If the model is compute intensive (e.g., a large deep residual network) and inference is the bottleneck, consider allocating GPUs to workers by setting ``num_gpus_per_worker: 1``. If you only have a single GPU, consider ``num_workers: 0`` to use the learner GPU for inference. For efficient use of GPU time, use a small number of GPU workers and a large number of `envs per worker <rllib-env.html#vectorized>`__.

4. Finally, if both model and environment are compute intensive, then enable `remote worker envs <rllib-env.html#vectorized>`__ with `async batching <rllib-env.html#vectorized>`__ by setting ``remote_worker_envs: True`` and optionally ``remote_env_batch_wait_ms``. This batches inference on GPUs in the rollout workers while letting envs run asynchronously in separate actors, similar to the `SEED <https://ai.googleblog.com/2020/03/massively-scaling-reinforcement.html>`__ architecture. The number of workers and number of envs per worker should be tuned to maximize GPU utilization. If your env requires GPUs to function, or if multi-node SGD is needed, then also consider :ref:`DD-PPO <ddppo>`.
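
The four rules can be summarized as a decision procedure. The sketch below is a hypothetical helper (``suggest_setup`` is not part of RLlib) that returns an algorithm name and config overrides using the keys named in this guide plus ``num_envs_per_worker``; concrete values such as 8 workers or 16 envs per worker are placeholders to tune for your hardware, not RLlib defaults.

.. code-block:: python

    def suggest_setup(env_replicable, env_heavy, model_heavy, num_gpus):
        """Hypothetical mapping of the scaling rules of thumb to (algorithm, config).

        Rule 4 as written assumes at least two GPUs (learner plus workers).
        """
        if not env_replicable:
            # Rule 1: env cannot be replicated -> sample-efficient off-policy
            # algorithm running in a single process.
            return "DQN", {"num_workers": 0}
        if not model_heavy:
            # Rule 2: small model -> scale out with rollout workers; any GPUs
            # go to the learner.
            return "PPO", {"num_workers": 8, "num_gpus": num_gpus}
        if not env_heavy:
            # Rule 3: model inference is the bottleneck.
            if num_gpus <= 1:
                # Single GPU: sample in the learner process to share its GPU.
                return "PPO", {"num_workers": 0, "num_gpus": num_gpus,
                               "num_envs_per_worker": 16}
            # Several GPUs: one for the learner, one per (few) rollout workers.
            return "PPO", {"num_workers": num_gpus - 1, "num_gpus": 1,
                           "num_gpus_per_worker": 1, "num_envs_per_worker": 16}
        # Rule 4: model and env both expensive -> envs run as separate actors
        # and GPU inference is batched asynchronously in the rollout workers.
        return "PPO", {"num_workers": max(num_gpus - 1, 1), "num_gpus": 1,
                       "num_gpus_per_worker": 1, "num_envs_per_worker": 16,
                       "remote_worker_envs": True,
                       "remote_env_batch_wait_ms": 10}

The returned dict would be merged into the trainer config; treat it as a starting point for tuning rather than a recommendation.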

Common Parameters
~~~~~~~~~~~~~~~~~

The following is a list of the common algorithm hyperparameters:

.. literalinclude:: ../../rllib/agents/trainer.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__

Tuned Examples
~~~~~~~~~~~~~~

Some good hyperparameters and settings are available in
`the repository <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples>`__
(some of them are tuned to run on GPUs). If you find better settings or tune
an algorithm on a different domain, consider submitting a Pull Request!

You can run these with the ``rllib train`` command as follows:

.. code-block:: bash

    rllib train -f /path/to/tuned/example.yaml

Basic Python API
----------------

The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use `custom environments, preprocessors, or models <rllib-models.html>`__ with RLlib.

Here is an example of the basic usage (for a more complete example, see `custom_env.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__):

.. code-block:: python

    import ray
    import ray.rllib.agents.ppo as ppo
    from ray.tune.logger import pretty_print

    ray.init()
    config = ppo.DEFAULT_CONFIG.copy()
    config["num_gpus"] = 0
    config["num_workers"] = 1
    config["eager"] = False
    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")

    # Can optionally call trainer.restore(path) to load a checkpoint.

    for i in range(1000):
        # Perform one iteration of training the policy with PPO
        result = trainer.train()
        print(pretty_print(result))

        if i % 100 == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)

    # Also, in case you have trained a model outside of ray/RLlib and have created
    # an h5-file with weight values in it, e.g.
    # my_keras_model_trained_outside_rllib.save_weights("my_weights.h5")
    # (see: https://keras.io/models/about-keras-models/)
    # ... you can load the h5-weights into your Trainer's Policy's ModelV2
    # (tf or torch) by doing:
    trainer.import_model("my_weights.h5")
    # NOTE: In order for this to work, your (custom) model needs to implement
    # the `import_from_h5` method.
    # See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
    # for detailed examples for tf- and torch trainers/models.

.. note::

    It's recommended that you run RLlib trainers with :ref:`Tune <tune-index>`, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.

All RLlib trainers are compatible with the :ref:`Tune API <tune-60-seconds>`. This enables them to be easily used in experiments with :ref:`Tune <tune-index>`. For example, the following code performs a simple hyperparam sweep of PPO:

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "PPO",
        stop={"episode_reward_mean": 200},
        config={
            "env": "CartPole-v0",
            "num_gpus": 0,
            "num_workers": 1,
            "lr": tune.grid_search([0.01, 0.001, 0.0001]),
            "eager": False,
        },
    )

Tune will schedule the trials to run in parallel on your Ray cluster:

::

    == Status ==
    Using FIFO scheduling algorithm.
    Resources requested: 4/4 CPUs, 0/0 GPUs
    Result logdir: ~/ray_results/my_experiment
    PENDING trials:
     - PPO_CartPole-v0_2_lr=0.0001:  PENDING
    RUNNING trials:
     - PPO_CartPole-v0_0_lr=0.01:    RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
     - PPO_CartPole-v0_1_lr=0.001:   RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew
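
Conceptually, ``tune.grid_search`` marks a value to be swept, and Tune creates one trial per combination of swept values, which is why the sweep over three learning rates yields the three ``PPO_CartPole-v0_*`` trials listed above. The sketch below models only that expansion step. It assumes the ``{"grid_search": [...]}`` dict form that ``tune.grid_search`` returns; ``expand_grid`` is a hypothetical helper, not Tune's implementation.

.. code-block:: python

    import itertools


    def expand_grid(config):
        """Expand grid-search entries into one config dict per combination."""
        keys = [k for k, v in config.items()
                if isinstance(v, dict) and "grid_search" in v]
        value_lists = [config[k]["grid_search"] for k in keys]
        trials = []
        # With no swept keys, product() yields a single empty combination,
        # so an unswept config still produces exactly one trial.
        for combo in itertools.product(*value_lists):
            trial = dict(config)
            trial.update(zip(keys, combo))
            trials.append(trial)
        return trials

With ``num_workers: 1``, each trial asks for two CPUs (the trainer itself plus one rollout worker, assuming default per-process CPU settings), which is consistent with only two of the three trials running at once on the 4-CPU machine in the status output.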

Computing Actions
~~~~~~~~~~~~~~~~~

The simplest way to programmatically compute actions from a trained agent is to use ``trainer.compute_action()``.
This method preprocesses and filters the observation before passing it to the agent policy.
For more advanced usage, you can access the ``workers`` and policies held by the trainer
directly as ``compute_action()`` does:

.. code-block:: python

    class Trainer(Trainable):

        @PublicAPI
        def compute_action(self,
                           observation,
                           state=None,
                           prev_action=None,
                           prev_reward=None,
                           info=None,
                           policy_id=DEFAULT_POLICY_ID,
                           full_fetch=False):
            """Computes an action for the specified policy.

            Note that you can also access the policy object through
            self.get_policy(policy_id) and call compute_actions() on it directly.

            Arguments:
                observation (obj): observation from the environment.
                state (list): RNN hidden state, if any. If state is not None,
                    then all of compute_single_action(...) is returned
                    (computed action, rnn state, logits dictionary).
                    Otherwise compute_single_action(...)[0] is
                    returned (computed action).
                prev_action (obj): previous action value, if any
                prev_reward (int): previous reward, if any
                info (dict): info object, if any
                policy_id (str): policy to query (only applies to multi-agent).
                full_fetch (bool): whether to return extra action fetch results.
                    This is always set to true if RNN state is specified.

            Returns:
                Just the computed action if full_fetch=False, or the full output
                of policy.compute_actions() otherwise.
            """

            if state is None:
                state = []
            preprocessed = self.workers.local_worker().preprocessors[
                policy_id].transform(observation)
            filtered_obs = self.workers.local_worker().filters[policy_id](
                preprocessed, update=False)
            if state:
                return self.get_policy(policy_id).compute_single_action(
                    filtered_obs,
                    state,
                    prev_action,
                    prev_reward,
                    info,
                    clip_actions=self.config["clip_actions"])
            res = self.get_policy(policy_id).compute_single_action(
                filtered_obs,
                state,
                prev_action,
                prev_reward,
                info,
                clip_actions=self.config["clip_actions"])
            if full_fetch:
                return res
            else:
                return res[0]  # backwards compatibility
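
The filtering step above calls the filter with ``update=False``, so inference reuses the statistics gathered during training without changing them. With ``observation_filter: "MeanStdFilter"``, those statistics are a running mean and standard deviation. The following scalar, pure-Python sketch makes the effect of the ``update`` flag visible; ``RunningMeanStdFilter`` is a hypothetical name and this is not RLlib's filter implementation.

.. code-block:: python

    import math


    class RunningMeanStdFilter:
        """Normalizes scalars by a running mean and (population) std."""

        def __init__(self):
            self.count = 0
            self.mean = 0.0
            self.m2 = 0.0  # Sum of squared deviations (Welford's algorithm).

        def _push(self, x):
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

        def __call__(self, x, update=True):
            # update=True: fold x into the statistics (training-time sampling).
            # update=False: normalize with frozen statistics (inference).
            if update:
                self._push(x)
            std = math.sqrt(self.m2 / self.count) if self.count > 1 else 1.0
            return (x - self.mean) / (std + 1e-8)

Calling the filter with ``update=False`` leaves ``count`` and ``mean`` untouched, so repeated action queries do not shift the normalization seen by the policy.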

Accessing Policy State
~~~~~~~~~~~~~~~~~~~~~~

It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib, trainer state is replicated across multiple *rollout workers* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.workers.foreach_worker()`` or ``trainer.workers.foreach_worker_with_index()``. These functions take a lambda function that is applied with the worker as an argument. You can also return values from these functions and those will be returned as a list.

You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.workers.local_worker()``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.workers.local_worker().policy_map["default_policy"].get_weights()``:

.. code-block:: python

    # Get weights of the default local policy
    trainer.get_policy().get_weights()

    # Same as above
    trainer.workers.local_worker().policy_map["default_policy"].get_weights()

    # Get list of weights of each worker, including remote replicas
    trainer.workers.foreach_worker(lambda ev: ev.get_policy().get_weights())

    # Same as above
    trainer.workers.foreach_worker_with_index(lambda ev, i: ev.get_policy().get_weights())
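
To make the local versus remote distinction concrete, the sketch below uses hypothetical single-process stub classes (``StubWorkerSet``, ``StubWorker``, ``StubPolicy``) in place of ``trainer.workers``, whose replicas would really be Ray actors. It shows that changing only the local policy's weights leaves the replicas stale until the new weights are pushed to every worker with ``foreach_worker``.

.. code-block:: python

    class StubPolicy:
        """Hypothetical stand-in for a policy holding a weights dict."""

        def __init__(self, weights):
            self.weights = dict(weights)

        def get_weights(self):
            return dict(self.weights)

        def set_weights(self, weights):
            self.weights = dict(weights)


    class StubWorker:
        def __init__(self, weights):
            self._policy = StubPolicy(weights)

        def get_policy(self):
            return self._policy


    class StubWorkerSet:
        """A local worker plus copies standing in for remote replicas."""

        def __init__(self, weights, num_workers):
            self._local = StubWorker(weights)
            self._remotes = [StubWorker(weights) for _ in range(num_workers)]

        def local_worker(self):
            return self._local

        def foreach_worker(self, fn):
            # Apply fn to every worker (local first) and collect the results.
            return [fn(w) for w in [self._local] + self._remotes]


    workers = StubWorkerSet({"w": 0.0}, num_workers=2)
    # Updating only the local copy leaves the two replicas at w=0.0 ...
    workers.local_worker().get_policy().set_weights({"w": 1.0})
    # ... until the new weights are broadcast to every worker explicitly.
    new_weights = workers.local_worker().get_policy().get_weights()
    workers.foreach_worker(lambda w: w.get_policy().set_weights(new_weights))

With a real trainer, the analogous broadcast would be ``trainer.workers.foreach_worker(lambda w: w.get_policy().set_weights(weights))``, assuming ``weights`` was obtained from ``trainer.get_policy().get_weights()``.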

Accessing Model State
~~~~~~~~~~~~~~~~~~~~~

Similar to accessing policy state, you may want to get a reference to the underlying neural network model being trained. For example, you may want to pre-train it separately, or otherwise update its weights outside of RLlib. This can be done by accessing the ``model`` of the policy:

**Example: Preprocessing observations for feeding into a model**

.. code-block:: python

    >>> import gym
    >>> env = gym.make("Pong-v0")

    # RLlib uses preprocessors to implement transforms such as one-hot encoding
    # and flattening of tuple and dict observations.
    >>> from ray.rllib.models.preprocessors import get_preprocessor
    >>> prep = get_preprocessor(env.observation_space)(env.observation_space)
    <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>

    # Observations should be preprocessed prior to feeding into a model
    >>> env.reset().shape
    (210, 160, 3)
    >>> prep.transform(env.reset()).shape
    (84, 84, 3)
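
As a rough illustration of the one-hot encoding and tuple flattening that preprocessors perform, here is a pure-Python sketch. The ``space_spec`` format and both helper names are hypothetical; RLlib derives the layout from the gym observation space and works on NumPy arrays.

.. code-block:: python

    def one_hot(index, n):
        """Encode a Discrete(n) value as a length-n list of 0.0/1.0."""
        vec = [0.0] * n
        vec[index] = 1.0
        return vec


    def flatten_tuple_obs(obs, space_spec):
        """Flatten a tuple observation into one flat list of floats.

        space_spec lists one (kind, size) pair per tuple component:
        ("discrete", n) is one-hot encoded, ("box", size) is copied as-is.
        """
        flat = []
        for value, (kind, size) in zip(obs, space_spec):
            if kind == "discrete":
                flat.extend(one_hot(value, size))
            elif kind == "box":
                if len(value) != size:
                    raise ValueError(f"expected {size} values, got {len(value)}")
                flat.extend(float(v) for v in value)
            else:
                raise ValueError(f"unsupported component kind: {kind}")
        return flat

For example, the observation ``(1, [0.5, -0.5])`` from a tuple of a 3-way discrete space and a 2-element box flattens to ``[0.0, 1.0, 0.0, 0.5, -0.5]``.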

**Example: Querying a policy's action distribution**

.. code-block:: python

    # Get a reference to the policy
    >>> import numpy as np
    >>> from ray.rllib.agents.ppo import PPOTrainer
    >>> trainer = PPOTrainer(env="CartPole-v0", config={"eager": True, "num_workers": 0})
    >>> policy = trainer.get_policy()
    <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>

    # Run a forward pass to get model output logits. Note that complex observations
    # must be preprocessed as in the above code block.
    >>> logits, _ = policy.model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
    (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])

    # Compute action distribution given logits
    >>> policy.dist_class
    <class_object 'ray.rllib.models.tf.tf_action_dist.Categorical'>
    >>> dist = policy.dist_class(logits, policy.model)
    <ray.rllib.models.tf.tf_action_dist.Categorical object at 0x7fd02301d710>

    # Query the distribution for samples, sample logps
    >>> dist.sample()
    <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=...>
    >>> dist.logp([1])
    <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>

    # Get the estimated values for the most recent forward pass
    >>> policy.model.value_function()
    <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>

    >>> policy.model.base_model.summary()
    Model: "model"
    _____________________________________________________________________
    Layer (type)               Output Shape  Param #  Connected to
    =====================================================================
    observations (InputLayer)  [(None, 4)]   0
    _____________________________________________________________________
    fc_1 (Dense)               (None, 256)   1280     observations[0][0]
    _____________________________________________________________________
    fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
    _____________________________________________________________________
    fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
    _____________________________________________________________________
    fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
    _____________________________________________________________________
    fc_out (Dense)             (None, 2)     514      fc_2[0][0]
    _____________________________________________________________________
    value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
    =====================================================================
    Total params: 134,915
    Trainable params: 134,915
    Non-trainable params: 0
    _____________________________________________________________________
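
For a discrete action space, the categorical distribution turns the logits into probabilities with a softmax; sampling draws an action index in proportion to those probabilities, and ``logp`` returns the log of an action's probability. The following pure-Python sketch shows that computation for a single row of logits. ``SimpleCategorical`` is a hypothetical name and not the ``tf_action_dist.Categorical`` class used above.

.. code-block:: python

    import math
    import random


    def softmax(logits):
        """Convert a list of logits into probabilities (numerically stable)."""
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]


    class SimpleCategorical:
        """Categorical distribution over one row of logits (illustration only)."""

        def __init__(self, logits):
            self.probs = softmax(logits)

        def sample(self, rng=random):
            # Draw an action index weighted by its probability.
            return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]

        def logp(self, action):
            return math.log(self.probs[action])

        def deterministic_sample(self):
            # Most likely action, e.g. for evaluation without exploration.
            return max(range(len(self.probs)), key=self.probs.__getitem__)

Subtracting the maximum logit before exponentiating does not change the probabilities but avoids overflow for large logits.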

**Example: Getting Q values from a DQN model**

.. code-block:: python

    # Get a reference to the model through the policy
    >>> import numpy as np
    >>> from ray.rllib.agents.dqn import DQNTrainer
    >>> trainer = DQNTrainer(env="CartPole-v0", config={"eager": True})
    >>> model = trainer.get_policy().model
    <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

    # List of all model variables
    >>> model.variables()
    [<tf.Variable 'default_policy/fc_1/kernel:0' shape=(4, 256) dtype=float32>, ...]

    # Run a forward pass to get base model output. Note that complex observations
    # must be preprocessed. An example of preprocessing is examples/saving_experiences.py
    >>> model_out, _ = model.from_batch({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
    (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...>, [])

    # Access the base Keras models (all default models have a base)
    >>> model.base_model.summary()
    Model: "model"
    _______________________________________________________________________
    Layer (type)               Output Shape  Param #  Connected to
    =======================================================================
    observations (InputLayer)  [(None, 4)]   0
    _______________________________________________________________________
    fc_1 (Dense)               (None, 256)   1280     observations[0][0]
    _______________________________________________________________________
    fc_out (Dense)             (None, 256)   65792    fc_1[0][0]
    _______________________________________________________________________
    value_out (Dense)          (None, 1)     257      fc_1[0][0]
    =======================================================================
    Total params: 67,329
    Trainable params: 67,329
    Non-trainable params: 0
    _______________________________________________________________________

    # Access the Q value model (specific to DQN)
    >>> model.get_q_value_distributions(model_out)
    [<tf.Tensor: id=891, shape=(1, 2)>, <tf.Tensor: id=896, shape=(1, 2, 1)>]
    >>> model.q_value_head.summary()
    Model: "model_1"
    _________________________________________________________________
    Layer (type)               Output Shape                Param #
    =================================================================
    model_out (InputLayer)     [(None, 256)]               0
    _________________________________________________________________
    lambda (Lambda)            [(None, 2), (None, 2, 1),   66306
    =================================================================
    Total params: 66,306
    Trainable params: 66,306
    Non-trainable params: 0
    _________________________________________________________________

    # Access the state value model (specific to DQN)
    >>> model.get_state_value(model_out)
    <tf.Tensor: id=913, shape=(1, 1), dtype=float32>
    >>> model.state_value_head.summary()
    Model: "model_2"
    _________________________________________________________________
    Layer (type)               Output Shape                Param #
    =================================================================
    model_out (InputLayer)     [(None, 256)]               0
    _________________________________________________________________
    lambda_1 (Lambda)          (None, 1)                   66049
    =================================================================
    Total params: 66,049
    Trainable params: 66,049
    Non-trainable params: 0
    _________________________________________________________________
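
The ``Param #`` column in the two base-model summaries follows the usual Dense-layer count of ``in_dim * out_dim`` weights plus ``out_dim`` biases. As a sanity check, the arithmetic below reproduces both totals from the layer shapes shown (4-dimensional CartPole observations, 256-unit hidden layers, 2 actions); ``dense_params`` is a hypothetical helper name.

.. code-block:: python

    def dense_params(in_dim, out_dim, bias=True):
        """Parameters of a fully connected layer: weights plus optional biases."""
        return in_dim * out_dim + (out_dim if bias else 0)


    # PPO base model: separate policy and value branches.
    ppo_layers = [
        dense_params(4, 256),    # fc_1
        dense_params(4, 256),    # fc_value_1
        dense_params(256, 256),  # fc_2
        dense_params(256, 256),  # fc_value_2
        dense_params(256, 2),    # fc_out (one logit per CartPole action)
        dense_params(256, 1),    # value_out
    ]

    # DQN base model: one shared hidden layer feeding fc_out and value_out.
    dqn_layers = [
        dense_params(4, 256),    # fc_1
        dense_params(256, 256),  # fc_out
        dense_params(256, 1),    # value_out
    ]

    print(sum(ppo_layers), sum(dqn_layers))  # 134915 67329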

This is especially useful when used with `custom model classes <rllib-models.html>`__.

Advanced Python APIs
--------------------

Custom Training Workflows
~~~~~~~~~~~~~~~~~~~~~~~~~

In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per training iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports :ref:`custom trainable functions <trainable-docs>` that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.

For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__.

Global Coordination
~~~~~~~~~~~~~~~~~~~

Sometimes, it is necessary to coordinate between pieces of code that live in different processes managed by RLlib. For example, it can be useful to maintain a global average of a certain variable, or centrally control a hyperparameter used by policies. Ray provides a general way to achieve this through *detached actors* (learn more about Ray actors `here <actors.html>`__). These actors are assigned a global name and handles to them can be retrieved using these names. As an example, consider maintaining a shared global counter that is incremented by environments and read periodically from your driver program:

.. code-block:: python

    import ray

    @ray.remote
    class Counter:
        def __init__(self):
            self.count = 0

        def inc(self, n):
            self.count += n

        def get(self):
            return self.count

    # on the driver
    counter = Counter.options(name="global_counter").remote()
    print(ray.get(counter.get.remote()))  # get the latest count

    # in your envs
    counter = ray.get_actor("global_counter")
    counter.inc.remote(1)  # async call to increment the global count

Ray actors provide high levels of performance, so in more complex cases they can be used to implement communication patterns such as parameter servers and allreduce.

Callbacks and Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can provide callbacks to be called at points during policy evaluation. These callbacks have access to state for the current `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__. Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``, and ``on_train_result`` are also places where custom postprocessing can be applied to intermediate data or results.

User-defined state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__.

.. autoclass:: ray.rllib.agents.callbacks.DefaultCallbacks
    :members:
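
Each finished episode contributes its ``episode.custom_metrics`` dict, and the values are then summarized across episodes. The sketch below approximates that aggregation in plain Python; the helper name and the ``_mean``/``_min``/``_max`` key suffixes are assumptions for illustration, so check your actual training results for the exact keys.

.. code-block:: python

    from collections import defaultdict


    def summarize_custom_metrics(episodes_custom_metrics):
        """Aggregate per-episode custom metrics into mean/min/max entries.

        episodes_custom_metrics holds one dict per finished episode, i.e. what
        a callback wrote into episode.custom_metrics.
        """
        values = defaultdict(list)
        for metrics in episodes_custom_metrics:
            for key, value in metrics.items():
                values[key].append(value)
        summary = {}
        for key, vals in values.items():
            summary[f"{key}_mean"] = sum(vals) / len(vals)
            summary[f"{key}_min"] = min(vals)
            summary[f"{key}_max"] = max(vals)
        return summary

For example, two episodes reporting ``{"pole_angle": 0.1}`` and ``{"pole_angle": 0.3}`` (``pole_angle`` being a hypothetical metric name) would yield ``pole_angle_mean`` of about 0.2, ``pole_angle_min`` of 0.1, and ``pole_angle_max`` of 0.3 under this sketch.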

Visualizing Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~

Custom metrics can be accessed and visualized like any other training result:

.. image:: custom_metric.png

Customizing Exploration Behavior
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RLlib offers a unified top-level API to configure and customize an agent's
exploration behavior, including the decisions (how and whether) to sample
actions from distributions (stochastically or deterministically).
Exploration is set up using built-in Exploration classes
(see `this package <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/>`__),
which are specified (and further configured) inside ``Trainer.config["exploration_config"]``.
Besides using built-in classes, one can sub-class any of
these built-ins, add custom behavior to it, and use that new class in
the config instead.

Every policy holds an instance of one of the Exploration (sub-)classes.
This Exploration object is created from the Trainer's
``config["exploration_config"]`` dict, which specifies the class to use via the
special "type" key, as well as constructor arguments via all other keys,
e.g.:

.. code-block:: python

    # in Trainer.config:
    "exploration_config": {
        "type": "StochasticSampling",  # <- Special `type` key provides class information
        "[c'tor arg]" : "[value]",  # <- Add any needed constructor args here.
        # etc
    }
    # ...

The following table lists all built-in Exploration sub-classes and the agents
that currently use them by default:

.. View table below at: https://docs.google.com/drawings/d/1dEMhosbu7HVgHEwGBuMlEDyPiwjqp_g6bZ0DzCMaoUM/edit?usp=sharing

.. image:: images/rllib-exploration-api-table.svg

An Exploration class implements the ``get_exploration_action`` method,
in which the exact exploratory behavior is defined.
It takes the model's output, the action distribution class, the model itself,
a timestep (the global env-sampling steps already taken),
and an ``explore`` switch and outputs a tuple of 1) action and
2) log-likelihood:

.. code-block:: python

    def get_exploration_action(self,
                               distribution_inputs,
                               action_dist_class,
                               model=None,
                               explore=True,
                               timestep=None):
        """Returns a (possibly) exploratory action and its log-likelihood.

        Given the Model's logits outputs and action distribution, returns an
        exploratory action.

        Args:
            distribution_inputs (any): The output coming from the model,
                ready for parameterizing a distribution
                (e.g. q-values or PG-logits).
            action_dist_class (class): The action distribution class
                to use.
            model (ModelV2): The Model object.
            explore (bool): True: "Normal" exploration behavior.
                False: Suppress all exploratory behavior and return
                    a deterministic action.
            timestep (int): The current sampling time step. If None, the
                component should try to use an internal counter, which it
                then increments by 1. If provided, will set the internal
                counter to the given value.

        Returns:
            Tuple:
            - The chosen exploration action or a tf-op to fetch the exploration
                action from the graph.
            - The log-likelihood of the exploration action.
        """
        pass
On the highest level, the `` Trainer.compute_action `` and `` Policy.compute_action(s) ``
methods have a boolean `` explore `` switch, which is passed into
`` Exploration.get_exploration_action `` . If `` None `` , the value of
`` Trainer.config[“explore”] `` is used.
Hence `` config[“explore”] `` describes the default behavior of the policy and
e.g. allows switching off any exploration easily for evaluation purposes
(see :ref: `CustomEvaluation` ).
The following are example excerpts from different Trainers' configs
(see rllib/agents/trainer.py) to setup different exploration behaviors:
.. code-block:: python

    # All of the following configs go into Trainer.config.

    # 1) Switching *off* exploration by default.
    # Behavior: Calling `compute_action(s)` without explicitly setting its `explore`
    # param will result in no exploration.
    # However, explicitly calling `compute_action(s)` with `explore=True` will
    # still(!) result in exploration (per-call overrides default).
    "explore": False,

    # 2) Switching *on* exploration by default.
    # Behavior: Calling `compute_action(s)` without explicitly setting its
    # explore param will result in exploration.
    # However, explicitly calling `compute_action(s)` with `explore=False`
    # will result in no(!) exploration (per-call overrides default).
    "explore": True,

    # 3) Example exploration_config usages:
    # a) DQN: see rllib/agents/dqn/dqn.py
    "explore": True,
    "exploration_config": {
        # Exploration sub-class by name or full path to module+class
        # (e.g. "ray.rllib.utils.exploration.epsilon_greedy.EpsilonGreedy")
        "type": "EpsilonGreedy",
        # Parameters for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.
    },

    # b) DQN Soft-Q: In order to switch to Soft-Q exploration, do instead:
    "explore": True,
    "exploration_config": {
        "type": "SoftQ",
        # Parameters for the Exploration class' constructor:
        "temperature": 1.0,
    },

    # c) PPO: see rllib/agents/ppo/ppo.py
    # Behavior: The algo samples stochastically by default from the
    # model-parameterized distribution. This is the global Trainer default
    # setting defined in trainer.py and used by all PG-type algos.
    "explore": True,
    "exploration_config": {
        "type": "StochasticSampling",
    },

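
To see what the ``EpsilonGreedy`` parameters in excerpt a) mean in numbers, here is a sketch of a linear annealing schedule from ``initial_epsilon`` to ``final_epsilon`` over ``epsilon_timesteps``, holding ``final_epsilon`` afterwards. Linear interpolation is an assumption made for illustration and ``epsilon_at`` is a hypothetical helper; check ``ray.rllib.utils.exploration.epsilon_greedy`` for the exact schedule your RLlib version uses.

.. code-block:: python

    def epsilon_at(timestep, initial_epsilon=1.0, final_epsilon=0.02,
                   epsilon_timesteps=10000):
        """Hypothetical linear epsilon schedule (defaults match the DQN excerpt)."""
        if timestep >= epsilon_timesteps:
            return final_epsilon  # annealing finished: stay at the final value
        fraction = timestep / epsilon_timesteps
        return initial_epsilon + fraction * (final_epsilon - initial_epsilon)


    for t in (0, 5000, 10000, 20000):
        print(t, round(epsilon_at(t), 4))
    # 0 1.0
    # 5000 0.51
    # 10000 0.02
    # 20000 0.02
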
.. _CustomEvaluation:
Customized Evaluation During Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RLlib reports online training rewards; however, in some cases you may want to compute
rewards with different settings (e.g., with exploration turned off, or on a specific set
of environment configurations). You can evaluate policies during training by setting
the ``evaluation_interval`` config, and optionally also ``evaluation_num_episodes``,
``evaluation_config``, ``evaluation_num_workers``, and ``custom_eval_function``
(see `trainer.py <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__ for further documentation).

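
As an illustration, a config that evaluates every 5 training iterations on 10 episodes using 2 parallel evaluation workers might look like the sketch below. The key names come from the paragraph above; the values are arbitrary examples, and the exact semantics of each key should be checked against ``trainer.py``.

.. code-block:: python

    # Example values only; the keys are the evaluation options listed above.
    config = {
        "env": "CartPole-v0",
        # Run an evaluation round every 5 training iterations.
        "evaluation_interval": 5,
        # Number of episodes to run per evaluation round.
        "evaluation_num_episodes": 10,
        # Number of parallel workers used for evaluation.
        "evaluation_num_workers": 2,
        # Overrides applied only to the evaluation workers (none here).
        "evaluation_config": {},
    }

    # With Ray installed, this could then be passed to Tune, e.g.:
    # from ray import tune
    # tune.run("PPO", config=config, stop={"training_iteration": 20})
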
By default, exploration is left as-is within ``evaluation_config``.
However, you can switch off any exploration behavior for the evaluation workers
via:

.. code-block:: python

    # Switching off exploration behavior for evaluation workers
    # (see rllib/agents/trainer.py)
    "evaluation_config": {
        "explore": False
    }

.. note::

    Policy gradient algorithms are able to find the optimal
    policy, even if this is a stochastic one. Setting ``"explore": False`` above
    will result in the evaluation workers not sampling from this stochastic policy
    (they return deterministic actions instead).

There is an end-to-end example of how to set up custom online evaluation in `custom_eval.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_eval.py>`__. Note that if you only want to evaluate your policy at the end of training, you can set ``evaluation_interval: N``, where ``N`` is the number of training iterations before stopping.

Below are some examples of how the custom evaluation metrics are reported nested under the ``evaluation`` key of normal training results:

.. code-block:: bash

    ------------------------------------------------------------------------
    Sample output for `python custom_eval.py`
    ------------------------------------------------------------------------

    INFO trainer.py:623 -- Evaluating current policy for 10 episodes.
    INFO trainer.py:650 -- Running round 0 of parallel evaluation (2/10 episodes)
    INFO trainer.py:650 -- Running round 1 of parallel evaluation (4/10 episodes)
    INFO trainer.py:650 -- Running round 2 of parallel evaluation (6/10 episodes)
    INFO trainer.py:650 -- Running round 3 of parallel evaluation (8/10 episodes)
    INFO trainer.py:650 -- Running round 4 of parallel evaluation (10/10 episodes)

    Result for PG_SimpleCorridor_2c6b27dc:
      ...
      evaluation:
        custom_metrics: {}
        episode_len_mean: 15.864661654135338
        episode_reward_max: 1.0
        episode_reward_mean: 0.49624060150375937
        episode_reward_min: 0.0
        episodes_this_iter: 133

.. code-block:: bash

    ------------------------------------------------------------------------
    Sample output for `python custom_eval.py --custom-eval`
    ------------------------------------------------------------------------

    INFO trainer.py:631 -- Running custom eval function <function ...>
    Update corridor length to 4
    Update corridor length to 7
    Custom evaluation round 1
    Custom evaluation round 2
    Custom evaluation round 3
    Custom evaluation round 4

    Result for PG_SimpleCorridor_0de4e686:
      ...
      evaluation:
        custom_metrics: {}
        episode_len_mean: 9.15695067264574
        episode_reward_max: 1.0
        episode_reward_mean: 0.9596412556053812
        episode_reward_min: 0.0
        episodes_this_iter: 223
        foo: 1

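
When you consume these results programmatically (e.g., from the dict returned by ``trainer.train()``), the evaluation metrics sit under ``result["evaluation"]``. The sketch below uses a hand-built dict whose values are copied from the second sample output rather than a real training run. ``evaluation_summary`` is a hypothetical helper; as a precaution it treats the ``evaluation`` key as optional, since it should only appear on iterations where evaluation ran.

.. code-block:: python

    def evaluation_summary(result):
        """Hypothetical helper: pull a few evaluation metrics out of a result dict."""
        evaluation = result.get("evaluation")
        if evaluation is None:
            return None  # no evaluation ran in this training iteration
        return {
            "reward_mean": evaluation["episode_reward_mean"],
            "len_mean": evaluation["episode_len_mean"],
            # Extra keys returned by a custom eval function (e.g. "foo") sit alongside.
            "foo": evaluation.get("foo"),
        }


    # Hand-built example mirroring the `--custom-eval` sample output above.
    result = {
        "evaluation": {
            "custom_metrics": {},
            "episode_len_mean": 9.15695067264574,
            "episode_reward_mean": 0.9596412556053812,
            "foo": 1,
        },
    }
    print(evaluation_summary(result))
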
Rewriting Trajectories
~~~~~~~~~~~~~~~~~~~~~~
Note that in the ``on_postprocess_traj`` callback you have full access to the trajectory batch (``post_batch``) and other training state. This can be used to rewrite the trajectory, which has a number of uses including:

* Backdating rewards to previous time steps (e.g., based on values in ``info``).
* Adding model-based curiosity bonuses to rewards (you can train the model with a `custom model supervised loss <rllib-models.html#supervised-model-losses>`__).

To access the policy / model (``policy.model``) in the callbacks, note that ``info['pre_batch']`` returns a tuple where the first element is a policy and the second one is the batch itself. You can also access all the rollout worker state using the following call:

.. code-block:: python

    from ray.rllib.evaluation.rollout_worker import get_global_worker

    # You can use this from any callback to get a reference to the
    # RolloutWorker running in the process, which in turn has references to
    # all the policies, etc: see rollout_worker.py for more info.
    rollout_worker = get_global_worker()

Policy losses are defined over the ``post_batch`` data, so you can mutate that in the callbacks to change what data the policy loss function sees.
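
For instance, backdating rewards amounts to editing the reward column of the batch before the loss sees it. The sketch below is a pure-Python illustration under stated assumptions: the ``"backdated_reward"`` info key and the ``backdate_rewards`` helper are hypothetical, and your environment would have to emit such info entries itself. Inside ``on_postprocess_traj`` you would apply it to the reward and info columns of ``info["post_batch"]`` and write the result back; check the ``SampleBatch`` API of your RLlib version for the exact column names and how to assign them.

.. code-block:: python

    def backdate_rewards(rewards, infos):
        """Hypothetical helper: move credit to earlier time steps.

        An info dict of the (assumed) form {"backdated_reward": (steps_back, value)}
        at step t adds `value` to the reward at step t - steps_back.
        """
        new_rewards = [float(r) for r in rewards]  # copy; leave the input untouched
        for t, info in enumerate(infos):
            if "backdated_reward" in info:
                steps_back, value = info["backdated_reward"]
                target = t - steps_back
                if 0 <= target < len(new_rewards):
                    new_rewards[target] += value
        return new_rewards


    rewards = [0.0, 0.0, 0.0, 1.0]
    infos = [{}, {}, {}, {"backdated_reward": (3, 5.0)}]
    print(backdate_rewards(rewards, infos))  # -> [5.0, 0.0, 0.0, 1.0]
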
Curriculum Learning
~~~~~~~~~~~~~~~~~~~
Let's look at two ways to use the above APIs to implement `curriculum learning <https://bair.berkeley.edu/blog/2017/12/20/reverse-curriculum/>`__. In curriculum learning, the agent's task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time.
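
For concreteness, here is one hypothetical shape such an environment could take. ``PhasedCorridorEnv`` and its per-phase corridor lengths are illustrations, not part of RLlib: the class follows the Gym ``reset()``/``step()`` convention but, to stay dependency-free, does not subclass ``gym.Env`` or define ``observation_space``/``action_space``, which a real RLlib environment would need.

.. code-block:: python

    class PhasedCorridorEnv:
        """Hypothetical corridor env whose length grows with the curriculum phase."""

        # Corridor length per phase (example values).
        LENGTHS = {0: 3, 1: 6, 2: 10}

        def __init__(self, config=None):
            self.phase = 0
            self.length = self.LENGTHS[0]
            self.pos = 0

        def set_phase(self, phase):
            # Called from the training loop or a callback to change difficulty.
            self.phase = phase
            self.length = self.LENGTHS[phase]

        def reset(self):
            self.pos = 0
            return self.pos  # observation: current position

        def step(self, action):
            # Action 1 moves right, anything else moves left (never below 0).
            self.pos = self.pos + 1 if action == 1 else max(0, self.pos - 1)
            done = self.pos >= self.length
            reward = 1.0 if done else 0.0
            return self.pos, reward, done, {}


    env = PhasedCorridorEnv()
    env.set_phase(1)
    obs = env.reset()
    steps, done = 0, False
    while not done:
        obs, reward, done, info = env.step(1)
        steps += 1
    print(steps)  # -> 6 steps to reach the end of the phase-1 corridor
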
Approach 1: Use the Trainer API and update the environment between calls to ``train()``. This example shows the trainer being run inside a Tune function:
.. code-block:: python

    import ray
    from ray import tune
    from ray.rllib.agents.ppo import PPOTrainer

    def train(config, reporter):
        # YourEnv is a placeholder for an env class with a set_phase() method.
        trainer = PPOTrainer(config=config, env=YourEnv)
        while True:
            result = trainer.train()
            reporter(**result)
            if result["episode_reward_mean"] > 200:
                phase = 2
            elif result["episode_reward_mean"] > 100:
                phase = 1
            else:
                phase = 0
            # Push the new phase to every env on every rollout worker.
            trainer.workers.foreach_worker(
                lambda ev: ev.foreach_env(
                    lambda env: env.set_phase(phase)))

    ray.init()
    tune.run(
        train,
        config={
            "num_gpus": 0,
            "num_workers": 2,
        },
        resources_per_trial={
            "cpu": 1,
            "gpu": lambda spec: spec.config.num_gpus,
            "extra_cpu": lambda spec: spec.config.num_workers,
        },
    )

Approach 2: Use the callbacks API to update the environment on new training results:
.. code-block:: python

    import ray
    from ray import tune

    def on_train_result(info):
        result = info["result"]
        if result["episode_reward_mean"] > 200:
            phase = 2
        elif result["episode_reward_mean"] > 100:
            phase = 1
        else:
            phase = 0
        trainer = info["trainer"]
        trainer.workers.foreach_worker(
            lambda ev: ev.foreach_env(
                lambda env: env.set_phase(phase)))

    ray.init()
    tune.run(
        "PPO",
        config={
            "env": YourEnv,
            "callbacks": {
                "on_train_result": on_train_result,
            },
        },
    )

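
Both approaches share the same threshold logic, which you may prefer to factor into a small pure function so it can be unit-tested without starting Ray. ``phase_for_reward`` is a hypothetical name; the thresholds are the ones used in the two examples.

.. code-block:: python

    def phase_for_reward(episode_reward_mean):
        """Map the mean episode reward to a curriculum phase (thresholds as above)."""
        if episode_reward_mean > 200:
            return 2
        elif episode_reward_mean > 100:
            return 1
        return 0


    # Inside either approach: phase = phase_for_reward(result["episode_reward_mean"])
    print([phase_for_reward(r) for r in (50, 150, 250)])  # -> [0, 1, 2]
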
Debugging
---------
Gym Monitor
~~~~~~~~~~~
The ``"monitor": true`` config can be used to save Gym episode videos to the result dir. For example:

.. code-block:: bash

    rllib train --env=PongDeterministic-v4 \
        --run=A2C --config '{"num_workers": 2, "monitor": true}'

    # videos will be saved in the ~/ray_results/<experiment> dir, for example
    openaigym.video.0.31401.video000000.meta.json
    openaigym.video.0.31401.video000000.mp4
    openaigym.video.0.31403.video000000.meta.json
    openaigym.video.0.31403.video000000.mp4

Eager Mode
~~~~~~~~~~
Policies built with ``build_tf_policy`` (as most of the reference algorithms are)
can be run in eager mode by setting the
``"eager": True`` / ``"eager_tracing": True`` config options or using
``rllib train --eager [--trace]``.
This will tell RLlib to execute the model forward pass, action distribution,
loss, and stats functions in eager mode.

Eager mode makes debugging much easier, since you can now use line-by-line
debugging with breakpoints or Python ``print()`` to inspect
intermediate tensor values.
However, eager mode can be slower than graph mode unless tracing is enabled.
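
From Python, the same options go into the config dict. The sketch below assumes a TF-based algorithm such as PPO and uses only the two config keys named above; the other values are examples. Since traced code no longer runs line by line, a reasonable workflow may be to debug with tracing off and enable it once things work.

.. code-block:: python

    # Config keys from the text above; the other values are examples.
    debug_config = {
        "env": "CartPole-v0",
        "eager": True,           # run forward pass, loss, etc. eagerly
        "eager_tracing": False,  # plain eager: easiest to step through
    }

    # Once things work, tracing can recover much of graph mode's speed.
    fast_config = dict(debug_config, eager_tracing=True)

    # With Ray installed, e.g.:
    # from ray import tune
    # tune.run("PPO", config=debug_config)
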
Using PyTorch
~~~~~~~~~~~~~
Trainers that have an implemented TorchPolicy allow you to run
``rllib train`` with the ``--torch`` command-line flag.
Algorithms that do not have a torch version yet will complain with an error in
this case.
Episode Traces
~~~~~~~~~~~~~~
You can use the `data output API <rllib-offline.html>`__ to save episode traces for debugging. For example, the following command will run PPO while saving episode traces to ``/tmp/debug``.

.. code-block:: bash

    rllib train --run=PPO --env=CartPole-v0 \
        --config='{"output": "/tmp/debug", "output_compress_columns": []}'

    # episode traces will be saved in /tmp/debug, for example
    output-2019-02-23_12-02-03_worker-2_0.json
    output-2019-02-23_12-02-04_worker-1_0.json

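
To inspect these traces, you can read the files back with the standard library. This sketch assumes each line of an output file holds one JSON-encoded batch (a dict mapping column names to values); with ``"output_compress_columns": []`` the columns should be plain JSON rather than compressed strings. ``iter_batches`` is a hypothetical helper, and the column names you see depend on your RLlib version.

.. code-block:: python

    import glob
    import json
    import os


    def iter_batches(output_dir="/tmp/debug"):
        """Hypothetical reader: yield one dict per line of each output-*.json file."""
        for path in sorted(glob.glob(os.path.join(output_dir, "output-*.json"))):
            with open(path) as f:
                for line in f:
                    line = line.strip()
                    if line:  # skip blank lines
                        yield json.loads(line)


    # Example: list which columns each recorded batch contains.
    # for batch in iter_batches():
    #     print(sorted(batch.keys()))
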
Log Verbosity
~~~~~~~~~~~~~
You can control the trainer log level via the ``"log_level"`` flag. Valid values are ``"DEBUG"``, ``"INFO"``, ``"WARN"`` (default), and ``"ERROR"``. This can be used to increase or decrease the verbosity of internal logging. You can also use the ``-v`` and ``-vv`` flags. For example, the following two commands are roughly equivalent:

.. code-block:: bash

    rllib train --env=PongDeterministic-v4 \
        --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}'

    rllib train --env=PongDeterministic-v4 \
        --run=A2C --config '{"num_workers": 2}' -vv

The default log level is ``WARN``. We strongly recommend using at least ``INFO`` level logging for development.
Stack Traces
~~~~~~~~~~~~
You can use the ``ray stack`` command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues.
External Application API
------------------------
In some cases (e.g., when interacting with an externally hosted simulator or production environment) it makes more sense to interact with RLlib as if it were an independently running service, rather than having RLlib host the simulations itself. This is possible via RLlib's external applications interface `(full documentation) <rllib-env.html#external-agents-and-applications>`__.

.. autoclass:: ray.rllib.env.policy_client.PolicyClient
    :members:

.. autoclass:: ray.rllib.env.policy_server_input.PolicyServerInput
    :members: