mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[rllib] Fix documentation on custom policies (#4910)
* wip * add docs * lint * todo sections * fix doc
This commit is contained in:
parent
0066d7cf2a
commit
1c073e92e4
6 changed files with 131 additions and 4 deletions
|
@ -389,6 +389,12 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
|||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2
|
||||
|
||||
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ This page describes the internal concepts used to implement algorithms in RLlib.
|
|||
Policies
|
||||
--------
|
||||
|
||||
Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy.py>`__.
|
||||
Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `policy definition <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy.py>`__.
|
||||
|
||||
Most interaction with deep learning frameworks is isolated to the `Policy interface <https://github.com/ray-project/ray/blob/master/python/ray/rllib/policy/policy.py>`__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates. You can also write your own from scratch. Here is an example:
|
||||
|
||||
|
@ -148,7 +148,7 @@ We can create a `Trainer <#trainers>`__ and try running this policy on a toy env
|
|||
tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})
|
||||
|
||||
|
||||
If you run the above snippet, you'll probably notice that CartPole doesn't learn so well:
|
||||
If you run the above snippet `(runnable file here) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_tf_policy.py>`__, you'll probably notice that CartPole doesn't learn so well:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
@ -208,7 +208,7 @@ In the above section you saw how to compose a simple policy gradient algorithm w
|
|||
|
||||
Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer``, and ``after_optimizer_step=update_kl``.
|
||||
|
||||
The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer (the default), or a multi-GPU optimizer that implements minibatch SGD:
|
||||
The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer, or a multi-GPU optimizer that implements minibatch SGD (the default):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -349,7 +349,27 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor
|
|||
Building Policies in PyTorch
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Building on the TF examples above, let's look at how the `A3C torch policy <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c_torch_policy.py>`__ is defined:
|
||||
Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_torch_policy.py>`__:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
|
||||
def policy_gradient_loss(policy, batch_tensors):
|
||||
logits, _, values, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
}, [])
|
||||
action_dist = policy.dist_class(logits)
|
||||
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)
|
||||
|
||||
# <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
|
||||
MyTorchPolicy = build_torch_policy(
|
||||
name="MyTorchPolicy",
|
||||
loss_fn=policy_gradient_loss)
|
||||
|
||||
Now, building on the TF examples above, let's look at how the `A3C torch policy <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c_torch_policy.py>`__ is defined:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -423,6 +443,11 @@ You can find the full policy definition in `a3c_torch_policy.py <https://github.
|
|||
|
||||
In summary, the main differences between the PyTorch and TensorFlow policy builder functions is that the TF loss and stats functions are built symbolically when the policy is initialized, whereas for PyTorch these functions are called imperatively each time they are used.
|
||||
|
||||
Extending Existing Policies
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
(todo)
|
||||
|
||||
Policy Evaluation
|
||||
-----------------
|
||||
|
||||
|
|
|
@ -103,6 +103,8 @@ Concepts and Building Custom Algorithms
|
|||
|
||||
- `Building Policies in PyTorch <rllib-concepts.html#building-policies-in-pytorch>`__
|
||||
|
||||
- `Extending Existing Policies <rllib-concepts.html#extending-existing-policies>`__
|
||||
|
||||
* `Policy Evaluation <rllib-concepts.html#policy-evaluation>`__
|
||||
* `Policy Optimization <rllib-concepts.html#policy-optimization>`__
|
||||
* `Trainers <rllib-concepts.html#trainers>`__
|
||||
|
|
|
@ -100,6 +100,8 @@ COMMON_CONFIG = {
|
|||
"clip_actions": True,
|
||||
# Whether to use rllib or deepmind preprocessors by default
|
||||
"preprocessor_pref": "deepmind",
|
||||
# The default learning rate
|
||||
"lr": 0.0001,
|
||||
|
||||
# === Evaluation ===
|
||||
# Evaluate with every `evaluation_interval` training iterations.
|
||||
|
|
47
python/ray/rllib/examples/custom_tf_policy.py
Normal file
47
python/ray/rllib/examples/custom_tf_policy.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--iters", type=int, default=200)
|
||||
|
||||
|
||||
def policy_gradient_loss(policy, batch_tensors):
|
||||
actions = batch_tensors[SampleBatch.ACTIONS]
|
||||
rewards = batch_tensors[SampleBatch.REWARDS]
|
||||
return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)
|
||||
|
||||
|
||||
# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
|
||||
MyTFPolicy = build_tf_policy(
|
||||
name="MyTFPolicy",
|
||||
loss_fn=policy_gradient_loss,
|
||||
)
|
||||
|
||||
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
|
||||
MyTrainer = build_trainer(
|
||||
name="MyCustomTrainer",
|
||||
default_policy=MyTFPolicy,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
args = parser.parse_args()
|
||||
tune.run(
|
||||
MyTrainer,
|
||||
stop={"training_iteration": args.iters},
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 2,
|
||||
})
|
45
python/ray/rllib/examples/custom_torch_policy.py
Normal file
45
python/ray/rllib/examples/custom_torch_policy.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--iters", type=int, default=200)
|
||||
|
||||
|
||||
def policy_gradient_loss(policy, batch_tensors):
|
||||
logits, _, values, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
}, [])
|
||||
action_dist = policy.dist_class(logits)
|
||||
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)
|
||||
|
||||
|
||||
# <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
|
||||
MyTorchPolicy = build_torch_policy(
|
||||
name="MyTorchPolicy", loss_fn=policy_gradient_loss)
|
||||
|
||||
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
|
||||
MyTrainer = build_trainer(
|
||||
name="MyCustomTrainer",
|
||||
default_policy=MyTorchPolicy,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
args = parser.parse_args()
|
||||
tune.run(
|
||||
MyTrainer,
|
||||
stop={"training_iteration": args.iters},
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 2,
|
||||
})
|
Loading…
Add table
Reference in a new issue