From 1c073e92e4f23c7b61e16ad3d3b77c6c69ca35cc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 1 Jun 2019 16:13:21 +0800 Subject: [PATCH] [rllib] Fix documentation on custom policies (#4910) * wip * add docs * lint * todo sections * fix doc --- ci/jenkins_tests/run_rllib_tests.sh | 6 +++ doc/source/rllib-concepts.rst | 33 +++++++++++-- doc/source/rllib.rst | 2 + python/ray/rllib/agents/trainer.py | 2 + python/ray/rllib/examples/custom_tf_policy.py | 47 +++++++++++++++++++ .../ray/rllib/examples/custom_torch_policy.py | 45 ++++++++++++++++++ 6 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 python/ray/rllib/examples/custom_tf_policy.py create mode 100644 python/ray/rllib/examples/custom_torch_policy.py diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 13acff28d..78fbf6a3a 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -389,6 +389,12 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2 +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2 + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2 + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 06e890832..8556e419a 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -6,7 +6,7 @@ This page describes the internal concepts used to implement algorithms in RLlib. Policies -------- -Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. +Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `policy definition `__. Most interaction with deep learning frameworks is isolated to the `Policy interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates. You can also write your own from scratch. Here is an example: @@ -148,7 +148,7 @@ We can create a `Trainer <#trainers>`__ and try running this policy on a toy env tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2}) -If you run the above snippet, you'll probably notice that CartPole doesn't learn so well: +If you run the above snippet `(runnable file here) `__, you'll probably notice that CartPole doesn't learn so well: .. 
code-block:: bash @@ -208,7 +208,7 @@ In the above section you saw how to compose a simple policy gradient algorithm w Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer``, and ``after_optimizer_step=update_kl``. -The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer (the default), or a multi-GPU optimizer that implements minibatch SGD: +The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer, or a multi-GPU optimizer that implements minibatch SGD (the default): .. code-block:: python @@ -349,7 +349,27 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor Building Policies in PyTorch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Building on the TF examples above, let's look at how the `A3C torch policy `__ is defined: +Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Here's a simple example of a trivial torch policy `(runnable file here) `__: + +.. code-block:: python + + from ray.rllib.policy.sample_batch import SampleBatch + from ray.rllib.policy.torch_policy_template import build_torch_policy + + def policy_gradient_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) + + # + MyTorchPolicy = build_torch_policy( + name="MyTorchPolicy", + loss_fn=policy_gradient_loss) + +Now, building on the TF examples above, let's look at how the `A3C torch policy `__ is defined: .. code-block:: python @@ -423,6 +443,11 @@ You can find the full policy definition in `a3c_torch_policy.py `__ + - `Extending Existing Policies `__ + * `Policy Evaluation `__ * `Policy Optimization `__ * `Trainers `__ diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 4294affb1..fb20f56ba 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -100,6 +100,8 @@ COMMON_CONFIG = { "clip_actions": True, # Whether to use rllib or deepmind preprocessors by default "preprocessor_pref": "deepmind", + # The default learning rate + "lr": 0.0001, # === Evaluation === # Evaluate with every `evaluation_interval` training iterations. 
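The ``lr`` entry added above is only a default in ``COMMON_CONFIG``; the config dict passed to Tune is merged over these defaults, so it can be overridden per experiment. A minimal sketch, assuming the ``MyTrainer`` object built with ``build_trainer`` in the example files below (the learning-rate value here is purely illustrative):

.. code-block:: python

    from ray import tune

    # Override the default learning rate ("lr": 0.0001) by passing a
    # different value in the experiment config; keys not specified here
    # fall back to the COMMON_CONFIG defaults.
    tune.run(
        MyTrainer,  # built with build_trainer(), as in the examples below
        stop={"training_iteration": 10},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
            "lr": 0.0005,  # illustrative override of the 0.0001 default
        })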
diff --git a/python/ray/rllib/examples/custom_tf_policy.py b/python/ray/rllib/examples/custom_tf_policy.py new file mode 100644 index 000000000..0442dff83 --- /dev/null +++ b/python/ray/rllib/examples/custom_tf_policy.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--iters", type=int, default=200) + + +def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + + +# +MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss, +) + +# +MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTFPolicy, +) + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + tune.run( + MyTrainer, + stop={"training_iteration": args.iters}, + config={ + "env": "CartPole-v0", + "num_workers": 2, + }) diff --git a/python/ray/rllib/examples/custom_torch_policy.py b/python/ray/rllib/examples/custom_torch_policy.py new file mode 100644 index 000000000..7ab2786cf --- /dev/null +++ b/python/ray/rllib/examples/custom_torch_policy.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy + +parser = argparse.ArgumentParser() +parser.add_argument("--iters", type=int, default=200) + + +def policy_gradient_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) + + +# +MyTorchPolicy = build_torch_policy( + name="MyTorchPolicy", loss_fn=policy_gradient_loss) + +# +MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTorchPolicy, +) + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + tune.run( + MyTrainer, + stop={"training_iteration": args.iters}, + config={ + "env": "CartPole-v0", + "num_workers": 2, + })
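Both example files optimize against raw per-timestep rewards, which is the same limitation the concepts doc points out when it notes that CartPole "doesn't learn so well". As a rough sketch of one possible extension (shown for the TF variant only, and assuming the ``postprocess_fn`` argument of ``build_tf_policy`` together with RLlib's ``compute_advantages`` helper from ``ray.rllib.evaluation.postprocessing``, used with ``use_gae=False`` so no value function is needed), the same template can be given a trajectory postprocessor so the loss sees discounted returns instead:

.. code-block:: python

    from ray.rllib.evaluation.postprocessing import compute_advantages
    from ray.rllib.evaluation.postprocessing import Postprocessing
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.tf_policy_template import build_tf_policy
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()


    def postprocess_returns(policy, sample_batch,
                            other_agent_batches=None, episode=None):
        # Replace raw rewards with discounted returns (no GAE, no value fn).
        return compute_advantages(
            sample_batch, 0.0, policy.config["gamma"], use_gae=False)


    def policy_gradient_loss(policy, batch_tensors):
        # Weight log-probs by the postprocessed returns rather than the
        # raw rewards used in custom_tf_policy.py.
        actions = batch_tensors[SampleBatch.ACTIONS]
        advantages = batch_tensors[Postprocessing.ADVANTAGES]
        return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages)


    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=policy_gradient_loss,
        postprocess_fn=postprocess_returns)

Passing this policy to the same ``build_trainer`` call as in ``custom_tf_policy.py`` gives a trainer that behaves like the vanilla policy gradient walked through in ``rllib-concepts.rst``.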