[RLlib] Move (A/DD)?PPO and IMPALA algos to algorithms dir and rename policy and trainer classes. (#25346)
This commit is contained in:
parent f781622f86
commit e4ceae19ef
110 changed files with 649 additions and 508 deletions
|
@ -129,7 +129,7 @@
|
|||
# Test all tests in the `agents` (soon to be "trainers") dir:
|
||||
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
||||
--build_tests_only
|
||||
--test_tag_filters=trainers_dir_generic,-multi_gpu
|
||||
--test_tag_filters=algorithms_dir_generic,-multi_gpu
|
||||
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
||||
rllib/...
|
||||
|
||||
|
@ -141,7 +141,7 @@
|
|||
# Test all tests in the `agents` (soon to be "trainers") dir:
|
||||
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
||||
--build_tests_only
|
||||
--test_tag_filters=trainers_dir,-trainers_dir_generic,-multi_gpu
|
||||
--test_tag_filters=algorithms_dir,-algorithms_dir_generic,-multi_gpu
|
||||
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
||||
rllib/...
|
||||
|
||||
|
@ -154,7 +154,7 @@
|
|||
# "learning_tests|quick_train|examples|tests_dir".
|
||||
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
||||
--build_tests_only
|
||||
--test_tag_filters=-learning_tests,-quick_train,-memory_leak_tests,-examples,-tests_dir,-trainers_dir,-documentation,-multi_gpu
|
||||
--test_tag_filters=-learning_tests,-quick_train,-memory_leak_tests,-examples,-tests_dir,-algorithms_dir,-documentation,-multi_gpu
|
||||
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
||||
rllib/...
|
||||
|
||||
|
|
|
@ -179,7 +179,7 @@ It offers high scalability and unified APIs for a
|
|||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
|
||||
# Define your problem using Python and OpenAI's gym API:
|
||||
|
@ -229,7 +229,7 @@ It offers high scalability and unified APIs for a
|
|||
|
||||
|
||||
# Create an RLlib Trainer instance.
|
||||
trainer = PPOTrainer(
|
||||
trainer = PPO(
|
||||
config={
|
||||
# Env class to use (here: our gym.Env sub-class from above).
|
||||
"env": SimpleCorridor,
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from ray import tune
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
tune.run(
|
||||
PPOTrainer,
|
||||
PPO,
|
||||
stop={"episode_len_mean": 20},
|
||||
config={"env": "CartPole-v0", "framework": "torch", "log_level": "INFO"},
|
||||
)
|
||||
|
|
|
@ -25,14 +25,14 @@ Trainers also implement the :ref:`Tune Trainable API <tune-60-seconds>` for easy
|
|||
|
||||
You have three ways to interact with a trainer. You can use the basic Python API or the command line to train it, or you
|
||||
can use Ray Tune to tune hyperparameters of your reinforcement learning algorithm.
|
||||
The following example shows three equivalent ways of interacting with the ``PPOTrainer``,
|
||||
The following example shows three equivalent ways of interacting with the ``PPO`` Trainer,
|
||||
which implements the proximal policy optimization algorithm in RLlib.
|
||||
|
||||
.. tabbed:: Basic RLlib Trainer
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = PPOTrainer(env="CartPole-v0", config={"train_batch_size": 4000})
|
||||
trainer = PPO(env="CartPole-v0", config={"train_batch_size": 4000})
|
||||
while True:
|
||||
print(trainer.train())
|
||||
|
||||
|
@ -47,7 +47,7 @@ which implements the proximal policy optimization algorithm in RLlib.
|
|||
.. code-block:: python
|
||||
|
||||
from ray import tune
|
||||
tune.run(PPOTrainer, config={"env": "CartPole-v0", "train_batch_size": 4000})
|
||||
tune.run(PPO, config={"env": "CartPole-v0", "train_batch_size": 4000})
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -22,9 +22,9 @@ prep.transform(env.reset()).shape
|
|||
# __query_action_dist_start__
|
||||
# Get a reference to the policy
|
||||
import numpy as np
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
trainer = PPOTrainer(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
|
||||
trainer = PPO(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
|
||||
policy = trainer.get_policy()
|
||||
# <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>
|
||||
|
||||
|
|
|
@ -23,4 +23,4 @@ framework-agnostic policy),
|
|||
* :py:meth:`~ray.rllib.policy.policy.Policy.postprocess_trajectory`
|
||||
* :py:meth:`~ray.rllib.policy.policy.Policy.loss`
|
||||
|
||||
`See here for an example on how to override TorchPolicy <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo_torch_policy.py>`_.
|
||||
`See here for an example on how to override TorchPolicy <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo_torch_policy.py>`_.
|
||||
|
|
|
@ -130,7 +130,7 @@ Importance Weighted Actor-Learner Architecture (IMPALA)
|
|||
-------------------------------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/1802.01561>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/impala/impala.py>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/impala/impala.py>`__
|
||||
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models-tensorflow>`__. Multiple learner GPUs and experience replay are also supported.
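For orientation, here is a minimal sketch (illustrative, not part of this change set) of launching IMPALA through Ray Tune; only common config keys and the registered algorithm string are used, neither of which is affected by the directory move:

.. code-block:: python

    # Hedged sketch: run IMPALA via Ray Tune. "IMPALA" is the registered
    # algorithm name, so no import from the moved module is needed here.
    import ray
    from ray import tune

    ray.init()
    tune.run(
        "IMPALA",
        stop={"timesteps_total": 100_000},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,   # actor processes sampling asynchronously
            "num_gpus": 0,      # GPUs for the central learner (0 = CPU only)
            "framework": "torch",
        },
    )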
|
||||
|
||||
.. figure:: images/impala-arch.svg
|
||||
|
@ -168,7 +168,7 @@ SpaceInvaders 843 ~300
|
|||
|
||||
**IMPALA-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/impala/impala.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/impala/impala.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -179,7 +179,7 @@ Asynchronous Proximal Policy Optimization (APPO)
|
|||
------------------------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/1707.06347>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo.py>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/appo/appo.py>`__
|
||||
We include an asynchronous variant of Proximal Policy Optimization (PPO) based on the IMPALA architecture. This is similar to IMPALA but using a surrogate policy loss with clipping. Compared to synchronous PPO, APPO is more efficient in wall-clock time due to its use of asynchronous sampling. Using a clipped loss also allows for multiple SGD passes, and therefore the potential for better sample efficiency compared to IMPALA. V-trace can also be enabled to correct for off-policy samples.
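As an illustration only (not taken from this diff), a small sketch of constructing an APPO Trainer from its new module location; ``vtrace`` is the APPO-specific switch referred to above, and the other keys are common parameters:

.. code-block:: python

    # Hedged sketch: build an APPO Trainer from the new algorithms/ location
    # (this import path mirrors the one used elsewhere in this change set).
    from ray.rllib.algorithms.appo.appo import APPO

    trainer = APPO(
        env="CartPole-v0",
        config={
            "num_workers": 2,
            "vtrace": True,       # correct for off-policy samples (see text)
            "framework": "torch",
        },
    )
    for _ in range(3):
        print(trainer.train())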
|
||||
|
||||
.. tip::
|
||||
|
@ -190,11 +190,11 @@ We include an asynchronous variant of Proximal Policy Optimization (PPO) based o
|
|||
|
||||
APPO architecture (same as IMPALA)
|
||||
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/pong-appo.yaml>`__
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/appo/pong-appo.yaml>`__
|
||||
|
||||
**APPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/ppo/appo.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/appo/appo.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -205,7 +205,7 @@ Decentralized Distributed Proximal Policy Optimization (DD-PPO)
|
|||
---------------------------------------------------------------
|
||||
|pytorch|
|
||||
`[paper] <https://arxiv.org/abs/1911.00357>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ddppo.py>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py>`__
|
||||
Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centrally in the trainer process. Instead, gradients are computed remotely on each rollout worker and all-reduced at each mini-batch using `torch distributed <https://pytorch.org/docs/stable/distributed.html>`__. This allows each worker's GPU to be used both for sampling and for training.
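A rough sketch follows (assumptions: the standard common config keys of this release and the registered ``"DDPPO"`` string); note that GPU resources, if any, are requested per rollout worker rather than for the trainer process:

.. code-block:: python

    # Hedged sketch: run DD-PPO via Ray Tune. Gradients are computed on the
    # rollout workers, so GPUs (if used) belong to the workers.
    from ray import tune

    tune.run(
        "DDPPO",
        stop={"episode_reward_mean": 150},
        config={
            "env": "CartPole-v0",
            "framework": "torch",        # DD-PPO is PyTorch-only
            "num_workers": 2,
            "num_gpus_per_worker": 0,    # set to 1 when each worker has a GPU
        },
    )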
|
||||
|
||||
.. tip::
|
||||
|
@ -216,11 +216,11 @@ Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centralized
|
|||
|
||||
DD-PPO architecture (both sampling and learning are done on worker GPUs)
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/cartpole-ddppo.yaml>`__, `BreakoutNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/atari-ddppo.yaml>`__
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddppo/cartpole-ddppo.yaml>`__, `BreakoutNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddppo/atari-ddppo.yaml>`__
|
||||
|
||||
**DDPPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/ppo/ddppo.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/ddppo/ddppo.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -396,7 +396,7 @@ Proximal Policy Optimization (PPO)
|
|||
----------------------------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/1707.06347>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__
|
||||
PPO's clipped objective supports multiple SGD passes over the same batch of experiences. RLlib's multi-GPU optimizer pins that data in GPU memory to avoid unnecessary transfers from host memory, substantially improving performance over a naive implementation. PPO scales out using multiple workers for experience collection, and also to multiple GPUs for SGD.
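For concreteness, a minimal sketch (not part of this diff; all keys are standard PPO/common parameters) exercising the knobs described here: parallel sample collection, several SGD passes per batch, and optional multi-GPU SGD:

.. code-block:: python

    # Hedged sketch of a PPO setup using the renamed class.
    from ray.rllib.algorithms.ppo import PPO

    trainer = PPO(
        env="CartPole-v0",
        config={
            "num_workers": 4,           # parallel experience collection
            "num_gpus": 0,              # >0 uses the multi-GPU SGD optimizer
            "train_batch_size": 4000,   # samples gathered per iteration
            "sgd_minibatch_size": 128,  # minibatch size for each SGD pass
            "num_sgd_iter": 30,         # SGD passes over the same batch
        },
    )
    print(trainer.train())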
|
||||
|
||||
.. tip::
|
||||
|
@ -445,7 +445,7 @@ HalfCheetah 9664 ~7700
|
|||
|
||||
**PPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../../rllib/agents/ppo/ppo.py
|
||||
.. literalinclude:: ../../../rllib/algorithms/ppo/ppo.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
|
|
@ -210,11 +210,11 @@ You might be wondering how RLlib makes the advantages placeholder automatically
|
|||
|
||||
In the above section you saw how to compose a simple policy gradient algorithm with RLlib.
|
||||
In this example, we'll dive into how PPO is defined within RLlib and how you can modify it.
|
||||
First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py>`__:
|
||||
First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class PPOTrainer(Trainer):
|
||||
class PPO(Trainer):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
|
@ -280,7 +280,7 @@ Suppose we want to customize PPO to use an asynchronous-gradient optimization st
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
from ray.rllib.execution.rollout_ops import AsyncGradients
|
||||
from ray.rllib.execution.train_ops import ApplyGradients
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
|
@ -307,7 +307,7 @@ Now let's look at each PPO policy definition:
|
|||
|
||||
PPOTFPolicy = build_tf_policy(
|
||||
name="PPOTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
|
||||
get_default_config=lambda: ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(),
|
||||
loss_fn=ppo_surrogate_loss,
|
||||
stats_fn=kl_and_loss_stats,
|
||||
extra_action_out_fn=vf_preds_and_logits_fetches,
|
||||
|
@ -562,8 +562,8 @@ You can use the ``with_updates`` method on Trainers and Policy objects built wit
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTFPolicy
|
||||
|
||||
CustomPolicy = PPOTFPolicy.with_updates(
|
||||
name="MyCustomPPOTFPolicy",
|
||||
|
|
|
@ -33,7 +33,7 @@ You can pass either a string name or a Python class to specify an environment. B
|
|||
return <obs>, <reward: float>, <done: bool>, <info: dict>
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPOTrainer(env=MyEnv, config={
|
||||
trainer = ppo.PPO(env=MyEnv, config={
|
||||
"env_config": {}, # config to pass to env class
|
||||
})
|
||||
|
||||
|
@ -50,7 +50,7 @@ You can also register a custom env creator function with a string name. This fun
|
|||
return MyEnv(...) # return an env instance
|
||||
|
||||
register_env("my_env", env_creator)
|
||||
trainer = ppo.PPOTrainer(env="my_env")
|
||||
trainer = ppo.PPO(env="my_env")
|
||||
|
||||
For a full runnable code example using the custom environment API, see `custom_env.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__.
|
||||
|
||||
|
|
|
@ -215,7 +215,7 @@ Once implemented, your TF model can then be registered and used in place of a bu
|
|||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
import ray.rllib.algorithms.ppo as ppo
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
|
||||
|
@ -227,7 +227,7 @@ Once implemented, your TF model can then be registered and used in place of a bu
|
|||
ModelCatalog.register_custom_model("my_tf_model", MyModelClass)
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
trainer = ppo.PPO(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_model": "my_tf_model",
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
|
@ -282,7 +282,7 @@ Once implemented, your PyTorch model can then be registered and used in place of
|
|||
ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
trainer = ppo.PPO(env="CartPole-v0", config={
|
||||
"framework": "torch",
|
||||
"model": {
|
||||
"custom_model": "my_torch_model",
|
||||
|
@ -488,7 +488,7 @@ Similar to custom models and preprocessors, you can also specify a custom action
|
|||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
import ray.rllib.algorithms.ppo as ppo
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import Preprocessor
|
||||
|
||||
|
@ -508,7 +508,7 @@ Similar to custom models and preprocessors, you can also specify a custom action
|
|||
ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
trainer = ppo.PPO(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_action_dist": "my_dist",
|
||||
},
|
||||
|
|
|
@ -238,7 +238,7 @@ You can configure experience input for an agent using the following options:
|
|||
objects, which have the advantage of being type safe, allowing users to set different config settings within
|
||||
meaningful sub-categories (e.g. ``my_config.offline_data(input_=[xyz])``), and offer the ability to
|
||||
construct a Trainer instance from these config objects (via their ``.build()`` method).
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.agents.ppo.ppo.PPOTrainer`,
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
but we are currently rolling this out across all of RLlib.
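A minimal sketch of that config-object style (assuming the ``PPOConfig`` methods shown elsewhere in this diff, i.e. ``framework()``, ``rollouts()``, ``offline_data()``, and ``build()``):

.. code-block:: python

    # Hedged sketch: build a PPO Trainer from a typed config object.
    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .framework("torch")
        .rollouts(num_rollout_workers=1)
        # "sampler" is the default input; point this at your offline files
        # (e.g. input_=["/tmp/experiences/*.json"]) to read logged data.
        .offline_data(input_="sampler")
    )
    trainer = config.build(env="CartPole-v0")
    print(trainer.train())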
|
||||
|
||||
|
||||
|
@ -335,7 +335,7 @@ You can configure experience output for an agent using the following options:
|
|||
objects, which have the advantage of being type safe, allowing users to set different config settings within
|
||||
meaningful sub-categories (e.g. ``my_config.offline_data(input_=[xyz])``), and offer the ability to
|
||||
construct a Trainer instance from these config objects (via their ``.build()`` method).
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.agents.ppo.ppo.PPOTrainer`,
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
but we are currently rolling this out across all of RLlib.
|
||||
|
||||
.. code-block:: python
|
||||
|
|
|
@ -164,7 +164,7 @@ Common Parameters
|
|||
objects, which have the advantage of being type safe, allowing users to set different config settings within
|
||||
meaningful sub-categories (e.g. ``my_config.training(lr=0.0003)``), and offer the ability to
|
||||
construct a Trainer instance from these config objects (via their ``build()`` method).
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.agents.ppo.ppo.PPOTrainer`,
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
but we are currently rolling this out across all of RLlib.
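To make the config-object pattern concrete, a short, illustrative sketch based on the ``PPOConfig`` usage appearing in this diff and the ``training(lr=...)`` sub-category mentioned above:

.. code-block:: python

    # Hedged sketch: typed config with a training() sub-category, then build().
    import ray.rllib.algorithms.ppo as ppo

    config = (
        ppo.PPOConfig()
        .framework("torch")
        .rollouts(num_rollout_workers=2)
        .training(lr=0.0003, train_batch_size=4000)
    )
    trainer = config.build(env="CartPole-v0")
    print(trainer.train())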
|
||||
|
||||
The following is a list of the common algorithm hyper-parameters:
|
||||
|
@ -705,14 +705,14 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
|||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
import ray.rllib.algorithms.ppo as ppo
|
||||
from ray.tune.logger import pretty_print
|
||||
|
||||
ray.init()
|
||||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
config["num_gpus"] = 0
|
||||
config["num_workers"] = 1
|
||||
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
|
||||
trainer = ppo.PPO(config=config, env="CartPole-v0")
|
||||
|
||||
# Can optionally call trainer.restore(path) to load a checkpoint.
|
||||
|
||||
|
@ -783,7 +783,7 @@ It also simplifies saving the trained agent. For example:
|
|||
# tune.run() allows setting a custom log directory (other than ``~/ray-results``)
|
||||
# and automatically saving the trained agent
|
||||
analysis = ray.tune.run(
|
||||
ppo.PPOTrainer,
|
||||
ppo.PPO,
|
||||
config=config,
|
||||
local_dir=log_dir,
|
||||
stop=stop_criteria,
|
||||
|
@ -807,7 +807,7 @@ Loading and restoring a trained agent from a checkpoint is simple:
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
agent = ppo.PPOTrainer(config=config, env=env_class)
|
||||
agent = ppo.PPO(config=config, env=env_class)
|
||||
agent.restore(checkpoint_path)
|
||||
|
||||
|
||||
|
@ -1340,10 +1340,10 @@ customizations to your training loop.
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
def train(config, reporter):
|
||||
trainer = PPOTrainer(config=config, env=YourEnv)
|
||||
trainer = PPO(config=config, env=YourEnv)
|
||||
while True:
|
||||
result = trainer.train()
|
||||
reporter(**result)
|
||||
|
|
|
@ -28,22 +28,25 @@ We will train and checkpoint a simple PPO model with the `CartPole-v0` environme
|
|||
In this tutorial we simply write to local disk, but in production you might want to consider using a cloud
|
||||
storage solution like S3 or a shared file system.
|
||||
|
||||
Let's get started by defining a `PPOTrainer` instance, training it for one iteration and then creating a checkpoint:
|
||||
Let's get started by defining a `PPO` instance, training it for one iteration and then creating a checkpoint:
|
||||
|
||||
```{code-cell} python3
|
||||
:tags: [remove-output]
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
import ray.rllib.algorithms.ppo as ppo
|
||||
from ray import serve
|
||||
|
||||
def train_ppo_model():
|
||||
trainer = ppo.PPOTrainer(
|
||||
config={"framework": "torch", "num_workers": 0},
|
||||
env="CartPole-v0",
|
||||
)
|
||||
# Train for one iteration
|
||||
# Configure our PPO algorithm.
|
||||
config = ppo.PPOConfig()\
|
||||
.framework("torch")\
|
||||
.rollouts(num_rollout_workers=0)
|
||||
# Create a `PPO` Trainer instance from the config.
|
||||
trainer = config.build(env="CartPole-v0")
|
||||
# Train for one iteration.
|
||||
trainer.train()
|
||||
# Save state of the trained Trainer in a checkpoint.
|
||||
trainer.save("/tmp/rllib_checkpoint")
|
||||
return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"
|
||||
|
||||
|
@ -54,7 +57,7 @@ checkpoint_path = train_ppo_model()
|
|||
You create deployments with Ray Serve by using the `@serve.deployment` decorator on a class that implements two methods:
|
||||
|
||||
- The `__init__` call creates the deployment instance and loads your data once.
|
||||
In the below example we restore our `PPOTrainer` from the checkpoint we just created.
|
||||
In the below example we restore our `PPO` Trainer from the checkpoint we just created.
|
||||
- The `__call__` method will be invoked on every request.
|
||||
For each incoming request, this method has access to a `request` object,
|
||||
which is a [Starlette Request](https://www.starlette.io/requests/).
|
||||
|
@ -72,13 +75,10 @@ from starlette.requests import Request
|
|||
@serve.deployment(route_prefix="/cartpole-ppo")
|
||||
class ServePPOModel:
|
||||
def __init__(self, checkpoint_path) -> None:
|
||||
self.trainer = ppo.PPOTrainer(
|
||||
config={
|
||||
"framework": "torch",
|
||||
"num_workers": 0,
|
||||
},
|
||||
env="CartPole-v0",
|
||||
)
|
||||
config = ppo.PPOConfig()\
|
||||
.framework("torch")\
|
||||
.rollouts(num_rollout_workers=0)
|
||||
self.trainer = config.build(env="CartPole-v0")
|
||||
self.trainer.restore(checkpoint_path)
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
|
|
|
@ -595,7 +595,7 @@ provider:
|
|||
ray_usage_lib._recorded_library_usages.clear()
|
||||
if os.environ.get("RAY_MINIMAL") != "1":
|
||||
from ray import tune # noqa: F401
|
||||
from ray.rllib.agents.ppo import PPOTrainer # noqa: F401
|
||||
from ray.rllib.algorithms.ppo import PPO # noqa: F401
|
||||
from ray import train # noqa: F401
|
||||
|
||||
ray.init(address=cluster.address)
|
||||
|
|
|
@ -351,12 +351,12 @@ class WandbIntegrationTest(unittest.TestCase):
|
|||
"""Test compatibility with RLlib configuration dicts"""
|
||||
# Local import to avoid tune dependency on rllib
|
||||
try:
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
except ImportError:
|
||||
self.skipTest("ray[rllib] not available")
|
||||
return
|
||||
|
||||
class WandbPPOTrainer(_MockWandbTrainableMixin, PPOTrainer):
|
||||
class WandbPPOTrainer(_MockWandbTrainableMixin, PPO):
|
||||
pass
|
||||
|
||||
config = {
|
||||
|
|
|
@ -8,7 +8,7 @@ import time
|
|||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents import DefaultCallbacks
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
|
||||
def fn_trainable(config, checkpoint_dir=None):
|
||||
|
@ -68,7 +68,7 @@ def run_tune(
|
|||
if trainable == "rllib_str":
|
||||
train = "PPO"
|
||||
else:
|
||||
train = PPOTrainer
|
||||
train = PPO
|
||||
|
||||
config = {
|
||||
"env": "CartPole-v1",
|
||||
|
|
186 rllib/BUILD
|
@ -16,7 +16,7 @@
|
|||
# -- "fake_gpus": Tests that run using 2 fake GPUs.
|
||||
|
||||
# - Quick agent compilation/tune-train tests, tagged "quick_train".
|
||||
# NOTE: These should be obsoleted in favor of "trainers_dir" tests as
|
||||
# NOTE: These should be obsoleted in favor of "algorithms_dir" tests as
|
||||
# they cover the same functionality.
|
||||
|
||||
# - Folder-bound tests, tagged with the name of the top-level dir:
|
||||
|
@ -27,7 +27,7 @@
|
|||
# - `policy` directory tests.
|
||||
# - `utils` directory tests.
|
||||
|
||||
# - Trainer ("agents") tests, tagged "trainers_dir".
|
||||
# - Trainer ("agents") tests, tagged "algorithms_dir".
|
||||
|
||||
# - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and
|
||||
# "tests_dir_[A-Z]"
|
||||
|
@ -168,8 +168,8 @@ py_test(
|
|||
tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/ppo/cartpole-appo.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
data = ["tuned_examples/appo/cartpole-appo.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/appo"]
|
||||
)
|
||||
|
||||
# py_test(
|
||||
|
@ -178,8 +178,8 @@ py_test(
|
|||
# tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
|
||||
# size = "large",
|
||||
# srcs = ["tests/run_regression_tests.py"],
|
||||
# data = ["tuned_examples/ppo/cartpole-appo-vtrace.yaml"],
|
||||
# args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
# data = ["tuned_examples/appo/cartpole-appo-vtrace.yaml"],
|
||||
# args = ["--yaml-dir=tuned_examples/appo"]
|
||||
# )
|
||||
|
||||
py_test(
|
||||
|
@ -189,9 +189,9 @@ py_test(
|
|||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = [
|
||||
"tuned_examples/ppo/cartpole-appo-vtrace-separate-losses.yaml"
|
||||
"tuned_examples/appo/cartpole-appo-vtrace-separate-losses.yaml"
|
||||
],
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
args = ["--yaml-dir=tuned_examples/appo"]
|
||||
)
|
||||
|
||||
# py_test(
|
||||
|
@ -200,8 +200,8 @@ py_test(
|
|||
# tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
|
||||
# size = "large",
|
||||
# srcs = ["tests/run_regression_tests.py"],
|
||||
# data = ["tuned_examples/ppo/frozenlake-appo-vtrace.yaml"],
|
||||
# args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
# data = ["tuned_examples/appo/frozenlake-appo-vtrace.yaml"],
|
||||
# args = ["--yaml-dir=tuned_examples/appo"]
|
||||
# )
|
||||
|
||||
py_test(
|
||||
|
@ -210,8 +210,8 @@ py_test(
|
|||
tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/ppo/cartpole-appo-vtrace-fake-gpus.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
data = ["tuned_examples/appo/cartpole-appo-vtrace-fake-gpus.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/appo"]
|
||||
)
|
||||
|
||||
# ARS
|
||||
|
@ -268,8 +268,8 @@ py_test(
|
|||
tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = glob(["tuned_examples/ppo/cartpole-ddppo.yaml"]),
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
data = glob(["tuned_examples/ddppo/cartpole-ddppo.yaml"]),
|
||||
args = ["--yaml-dir=tuned_examples/ddppo"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -278,8 +278,8 @@ py_test(
|
|||
tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = glob(["tuned_examples/ppo/pendulum-ddppo.yaml"]),
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
data = glob(["tuned_examples/ddppo/pendulum-ddppo.yaml"]),
|
||||
args = ["--yaml-dir=tuned_examples/ddppo"]
|
||||
)
|
||||
|
||||
# DQN
|
||||
|
@ -637,13 +637,13 @@ py_test(
|
|||
# Agents (Compilation, Losses, simple agent functionality tests)
|
||||
# rllib/agents/
|
||||
#
|
||||
# Tag: trainers_dir
|
||||
# Tag: algorithms_dir
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
# Generic (all Trainers)
|
||||
py_test(
|
||||
name = "test_callbacks",
|
||||
tags = ["team:ml", "trainers_dir", "trainers_dir_generic"],
|
||||
tags = ["team:ml", "algorithms_dir", "algorithms_dir_generic"],
|
||||
size = "medium",
|
||||
srcs = ["agents/tests/test_callbacks.py"]
|
||||
)
|
||||
|
@ -651,31 +651,31 @@ py_test(
|
|||
py_test(
|
||||
name = "test_memory_leaks_generic",
|
||||
main = "agents/tests/test_memory_leaks.py",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/tests/test_memory_leaks.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_trainer",
|
||||
tags = ["team:ml", "trainers_dir", "trainers_dir_generic"],
|
||||
tags = ["team:ml", "algorithms_dir", "algorithms_dir_generic"],
|
||||
size = "large",
|
||||
srcs = ["agents/tests/test_trainer.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_worker_failures",
|
||||
tags = ["team:ml", "tests_dir", "trainers_dir_generic"],
|
||||
tags = ["team:ml", "tests_dir", "algorithms_dir_generic"],
|
||||
size = "large",
|
||||
srcs = ["agents/tests/test_worker_failures.py"]
|
||||
)
|
||||
|
||||
# Specific Trainers (Algorithms)
|
||||
# Specific Algorithms
|
||||
|
||||
# A2C
|
||||
py_test(
|
||||
name = "test_a2c",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/a2c/tests/test_a2c.py"]
|
||||
)
|
||||
|
@ -683,7 +683,7 @@ py_test(
|
|||
# A3C
|
||||
py_test(
|
||||
name = "test_a3c",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/a3c/tests/test_a3c.py"]
|
||||
)
|
||||
|
@ -691,7 +691,7 @@ py_test(
|
|||
# AlphaStar
|
||||
py_test(
|
||||
name = "test_alpha_star",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/alpha_star/tests/test_alpha_star.py"]
|
||||
)
|
||||
|
@ -699,31 +699,39 @@ py_test(
|
|||
# AlphaZero
|
||||
py_test(
|
||||
name = "test_alpha_zero",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/alpha_zero/tests/test_alpha_zero.py"]
|
||||
)
|
||||
|
||||
# APEXTrainer (DQN)
|
||||
# APEX-DQN
|
||||
py_test(
|
||||
name = "test_apex_dqn",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/dqn/tests/test_apex_dqn.py"]
|
||||
)
|
||||
|
||||
# APEXDDPGTrainer
|
||||
# APEX-DDPG
|
||||
py_test(
|
||||
name = "test_apex_ddpg",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/ddpg/tests/test_apex_ddpg.py"]
|
||||
)
|
||||
|
||||
# APPO
|
||||
py_test(
|
||||
name = "test_appo",
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/appo/tests/test_appo.py"]
|
||||
)
|
||||
|
||||
# ARS
|
||||
py_test(
|
||||
name = "test_ars",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/ars/tests/test_ars.py"]
|
||||
)
|
||||
|
@ -731,31 +739,39 @@ py_test(
|
|||
# Bandits
|
||||
py_test(
|
||||
name = "test_bandits",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/bandit/tests/test_bandits.py"],
|
||||
)
|
||||
|
||||
# CQLTrainer
|
||||
# CQL
|
||||
py_test(
|
||||
name = "test_cql",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/cql/tests/test_cql.py"]
|
||||
)
|
||||
|
||||
# DDPGTrainer
|
||||
# DDPG
|
||||
py_test(
|
||||
name = "test_ddpg",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/ddpg/tests/test_ddpg.py"]
|
||||
)
|
||||
|
||||
# DQNTrainer
|
||||
# DDPPO
|
||||
py_test(
|
||||
name = "test_ddppo",
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/ddppo/tests/test_ddppo.py"]
|
||||
)
|
||||
|
||||
# DQN
|
||||
py_test(
|
||||
name = "test_dqn",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/dqn/tests/test_dqn.py"]
|
||||
)
|
||||
|
@ -763,7 +779,7 @@ py_test(
|
|||
# Dreamer
|
||||
py_test(
|
||||
name = "test_dreamer",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
|
||||
)
|
||||
|
@ -771,153 +787,137 @@ py_test(
|
|||
# ES
|
||||
py_test(
|
||||
name = "test_es",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/es/tests/test_es.py"]
|
||||
)
|
||||
|
||||
# IMPALA
|
||||
# Impala
|
||||
py_test(
|
||||
name = "test_impala",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/impala/tests/test_impala.py"]
|
||||
srcs = ["algorithms/impala/tests/test_impala.py"]
|
||||
)
|
||||
py_test(
|
||||
name = "test_vtrace",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "small",
|
||||
srcs = ["agents/impala/tests/test_vtrace.py"]
|
||||
srcs = ["algorithms/impala/tests/test_vtrace.py"]
|
||||
)
|
||||
|
||||
# MARWILTrainer
|
||||
# MARWIL
|
||||
py_test(
|
||||
name = "test_marwil",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
# Include the json data file.
|
||||
data = ["tests/data/cartpole/large.json"],
|
||||
srcs = ["algorithms/marwil/tests/test_marwil.py"]
|
||||
)
|
||||
|
||||
# BCTrainer (sub-type of MARWIL)
|
||||
# BC (sub-type of MARWIL)
|
||||
py_test(
|
||||
name = "test_bc",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
# Include the json data file.
|
||||
data = ["tests/data/cartpole/large.json"],
|
||||
srcs = ["algorithms/marwil/tests/test_bc.py"]
|
||||
)
|
||||
|
||||
# MADDPGTrainer
|
||||
# MADDPG
|
||||
py_test(
|
||||
name = "test_maddpg",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/maddpg/tests/test_maddpg.py"]
|
||||
)
|
||||
|
||||
# MAMLTrainer
|
||||
# MAML
|
||||
py_test(
|
||||
name = "test_maml",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/maml/tests/test_maml.py"]
|
||||
)
|
||||
|
||||
# MBMPOTrainer
|
||||
# MBMPO
|
||||
py_test(
|
||||
name = "test_mbmpo",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/mbmpo/tests/test_mbmpo.py"]
|
||||
)
|
||||
|
||||
# PGTrainer
|
||||
# PG
|
||||
py_test(
|
||||
name = "test_pg",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/pg/tests/test_pg.py"]
|
||||
)
|
||||
|
||||
# PPOTrainer
|
||||
# PPO
|
||||
py_test(
|
||||
name = "test_ppo",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/ppo/tests/test_ppo.py"]
|
||||
srcs = ["algorithms/ppo/tests/test_ppo.py"]
|
||||
)
|
||||
|
||||
# PPO: DDPPO
|
||||
py_test(
|
||||
name = "test_ddppo",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/ppo/tests/test_ddppo.py"]
|
||||
)
|
||||
|
||||
# PPO: APPO
|
||||
py_test(
|
||||
name = "test_appo",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/ppo/tests/test_appo.py"]
|
||||
)
|
||||
|
||||
# QMixTrainer
|
||||
# QMix
|
||||
py_test(
|
||||
name = "test_qmix",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/qmix/tests/test_qmix.py"]
|
||||
)
|
||||
|
||||
# R2D2Trainer
|
||||
# R2D2
|
||||
py_test(
|
||||
name = "test_r2d2",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["agents/dqn/tests/test_r2d2.py"]
|
||||
)
|
||||
|
||||
# RNNSACTrainer
|
||||
# RNNSAC
|
||||
py_test(
|
||||
name = "test_rnnsac",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/sac/tests/test_rnnsac.py"]
|
||||
)
|
||||
|
||||
# SACTrainer
|
||||
# SAC
|
||||
py_test(
|
||||
name = "test_sac",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/sac/tests/test_sac.py"]
|
||||
)
|
||||
|
||||
# SimpleQTrainer
|
||||
# SimpleQ
|
||||
py_test(
|
||||
name = "test_simple_q",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/dqn/tests/test_simple_q.py"]
|
||||
)
|
||||
|
||||
# SlateQTrainer
|
||||
# SlateQ
|
||||
py_test(
|
||||
name = "test_slateq",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "medium",
|
||||
srcs = ["algorithms/slateq/tests/test_slateq.py"]
|
||||
)
|
||||
|
||||
# TD3Trainer
|
||||
# TD3
|
||||
py_test(
|
||||
name = "test_td3",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
size = "large",
|
||||
srcs = ["algorithms/ddpg/tests/test_td3.py"]
|
||||
)
|
||||
|
@ -928,7 +928,7 @@ py_test(
|
|||
|
||||
py_test(
|
||||
name = "random_agent",
|
||||
tags = ["team:ml", "trainers_dir"],
|
||||
tags = ["team:ml", "algorithms_dir"],
|
||||
main = "contrib/random_agent/random_agent.py",
|
||||
size = "small",
|
||||
srcs = ["contrib/random_agent/random_agent.py"]
|
||||
|
@ -957,8 +957,8 @@ py_test(
|
|||
main = "utils/tests/run_memory_leak_tests.py",
|
||||
size = "large",
|
||||
srcs = ["utils/tests/run_memory_leak_tests.py"],
|
||||
data = ["tuned_examples/ppo/memory-leak-test-appo.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/ppo"]
|
||||
data = ["tuned_examples/appo/memory-leak-test-appo.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/appo"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -123,7 +123,7 @@ Quick First Experiment
|
|||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
|
||||
# Define your problem using Python and OpenAI's gym API:
|
||||
|
@ -176,7 +176,7 @@ Quick First Experiment
|
|||
|
||||
# Create an RLlib Trainer instance to learn how to act in the above
|
||||
# environment.
|
||||
trainer = PPOTrainer(
|
||||
trainer = PPO(
|
||||
config={
|
||||
# Env class to use (here: our gym.Env sub-class from above).
|
||||
"env": ParrotEnv,
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.dqn.dqn import DQNConfig, DQNTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
|
||||
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
|
||||
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer, R2D2_DEFAULT_CONFIG
|
||||
from ray.rllib.agents.dqn.r2d2 import R2D2Config, R2D2Trainer, R2D2_DEFAULT_CONFIG
|
||||
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
|
||||
from ray.rllib.algorithms.dqn.simple_q import (
|
||||
SimpleQConfig,
|
||||
|
@ -11,6 +10,7 @@ from ray.rllib.algorithms.dqn.simple_q import (
|
|||
)
|
||||
from ray.rllib.algorithms.dqn.simple_q_tf_policy import SimpleQTFPolicy
|
||||
from ray.rllib.algorithms.dqn.simple_q_torch_policy import SimpleQTorchPolicy
|
||||
from ray.rllib.agents.dqn.apex import ApexConfig, ApexTrainer, APEX_DEFAULT_CONFIG
|
||||
|
||||
__all__ = [
|
||||
"ApexConfig",
|
||||
|
@ -19,6 +19,7 @@ __all__ = [
|
|||
"DQNTFPolicy",
|
||||
"DQNTorchPolicy",
|
||||
"DQNTrainer",
|
||||
"R2D2Config",
|
||||
"R2D2TorchPolicy",
|
||||
"R2D2Trainer",
|
||||
"SimpleQConfig",
|
||||
|
|
|
@ -1,7 +1,17 @@
|
|||
from ray.rllib.agents.impala.impala import DEFAULT_CONFIG, ImpalaConfig, ImpalaTrainer
|
||||
from ray.rllib.algorithms.impala.impala import (
|
||||
DEFAULT_CONFIG,
|
||||
ImpalaConfig,
|
||||
Impala as ImpalaTrainer,
|
||||
)
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ImpalaConfig",
|
||||
"ImpalaTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
||||
deprecation_warning(
|
||||
"ray.rllib.agents.impala", "ray.rllib.algorithms.impala", error=False
|
||||
)
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
# Proximal Policy Optimization (PPO)
|
||||
|
||||
## Overview
|
||||
|
||||
[PPO](https://arxiv.org/abs/1707.06347) is a model-free on-policy RL algorithm that works well for both discrete and continuous action space environments. PPO utilizes an actor-critic framework, where there are two networks, an actor (policy network) and critic network (value function).
|
||||
|
||||
There are two formulations of PPO, both of which are implemented in RLlib. The first formulation imitates the earlier [TRPO](https://arxiv.org/abs/1502.05477) paper without the complexity of second-order optimization: at every iteration, an old version of the actor network is saved, and the agent optimizes the RL objective while staying close to the old policy, which keeps training from destabilizing. In the second formulation, to mitigate destructively large policy updates (an issue observed with vanilla policy gradient methods), PPO introduces a clipped surrogate objective that bounds large action probability ratios between the current and old policy. Clipping has been shown in the paper to significantly improve training stability and speed.
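For reference (standard PPO notation, not text from this file), the clipped surrogate objective the second formulation refers to is:

```latex
% r_t(theta): probability ratio between the current and the old policy,
% \hat{A}_t: advantage estimate, \epsilon: clip parameter.
L^{CLIP}(\theta)
  = \hat{\mathbb{E}}_t\left[
      \min\left( r_t(\theta)\,\hat{A}_t,\;
                 \mathrm{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t \right)
    \right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
```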
|
||||
|
||||
## Distributed PPO Algorithms
|
||||
|
||||
PPO is a core algorithm in RLlib due to its ability to scale well with the number of nodes. In RLlib, we provide various implementations of distributed PPO, with different underlying execution plans, as shown below.
|
||||
|
||||
Distributed baseline PPO is a synchronous distributed RL algorithm. Data collection nodes, which represent the old policy, gather data synchronously to create a large pool of on-policy data on which the agent performs minibatch gradient descent.
|
||||
|
||||
On the other hand, Asynchronous PPO (APPO) opts to imitate IMPALA as its distributed execution plan. Data collection nodes gather data asynchronously, and the samples are collected in a circular replay buffer. A target network and a doubly importance-sampled surrogate objective are introduced to enforce training stability in the asynchronous data-collection setting.
|
||||
|
||||
Lastly, Decentralized Distributed PPO (DDPPO) removes the assumption that gradient updates must be done on a central node. Instead, gradients are computed remotely on each data collection node and all-reduced at each mini-batch using torch distributed. This allows each worker's GPU to be used both for sampling and for training.
|
||||
|
||||
## Documentation & Implementation:
|
||||
|
||||
1) Proximal Policy Optimization (PPO).
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#ppo)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py)**
|
||||
|
||||
2) [Asynchronous Proximal Policy Optimization (APPO)](https://arxiv.org/abs/1912.00167).
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#appo)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo.py)**
|
||||
|
||||
3) [Decentralized Distributed Proximal Policy Optimization (DDPPO)](https://arxiv.org/abs/1911.00357)
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#decentralized-distributed-proximal-policy-optimization-dd-ppo)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ddppo.py)**
|
|
@ -1,18 +1,23 @@
|
|||
from ray.rllib.agents.ppo.ppo import PPOConfig, PPOTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy, PPOEagerTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
|
||||
from ray.rllib.agents.ppo.appo import APPOConfig, APPOTrainer
|
||||
from ray.rllib.agents.ppo.ddppo import DDPPOConfig, DDPPOTrainer
|
||||
from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO as PPOTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
|
||||
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
|
||||
from ray.rllib.algorithms.appo.appo import APPOConfig, APPO as APPOTrainer
|
||||
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy, APPOTF2Policy
|
||||
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
|
||||
from ray.rllib.algorithms.ddppo.ddppo import DDPPOConfig, DDPPO as DDPPOTrainer
|
||||
|
||||
__all__ = [
|
||||
"APPOConfig",
|
||||
"APPOTF1Policy",
|
||||
"APPOTF2Policy",
|
||||
"APPOTorchPolicy",
|
||||
"APPOTrainer",
|
||||
"DDPPOConfig",
|
||||
"DDPPOTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
"PPOConfig",
|
||||
"PPOStaticGraphTFPolicy",
|
||||
"PPOEagerTFPolicy",
|
||||
"PPOTF1Policy",
|
||||
"PPOTF2Policy",
|
||||
"PPOTorchPolicy",
|
||||
"PPOTrainer",
|
||||
]
|
||||
|
|
|
@ -8,55 +8,49 @@ from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
|||
def _import_a2c():
|
||||
import ray.rllib.algorithms.a2c as a2c
|
||||
|
||||
return a2c.A2C, a2c.A2C_DEFAULT_CONFIG
|
||||
return a2c.A2C, a2c.A2CConfig().to_dict()
|
||||
|
||||
|
||||
def _import_a3c():
|
||||
import ray.rllib.algorithms.a3c as a3c
|
||||
|
||||
return a3c.A3C, a3c.DEFAULT_CONFIG
|
||||
return a3c.A3C, a3c.A3CConfig().to_dict()
|
||||
|
||||
|
||||
def _import_alpha_star():
|
||||
from ray.rllib.algorithms.alpha_star.alpha_star import (
|
||||
AlphaStarTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
import ray.rllib.algorithms.alpha_star as alpha_star
|
||||
|
||||
return AlphaStarTrainer, DEFAULT_CONFIG
|
||||
return alpha_star.AlphaStarTrainer, alpha_star.AlphaStarConfig().to_dict()
|
||||
|
||||
|
||||
def _import_alpha_zero():
|
||||
from ray.rllib.algorithms.alpha_zero.alpha_zero import (
|
||||
AlphaZeroTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
import ray.rllib.algorithms.alpha_zero as alpha_zero
|
||||
|
||||
return AlphaZeroTrainer, DEFAULT_CONFIG
|
||||
return alpha_zero.AlphaZeroTrainer, alpha_zero.AlphaZeroConfig().to_dict()
|
||||
|
||||
|
||||
def _import_apex():
|
||||
from ray.rllib.agents import dqn
|
||||
|
||||
return dqn.ApexTrainer, dqn.apex.APEX_DEFAULT_CONFIG
|
||||
return dqn.ApexTrainer, dqn.apex.ApexConfig().to_dict()
|
||||
|
||||
|
||||
def _import_apex_ddpg():
|
||||
from ray.rllib.algorithms import ddpg
|
||||
|
||||
return ddpg.ApexDDPGTrainer, ddpg.apex.APEX_DDPG_DEFAULT_CONFIG
|
||||
return ddpg.ApexDDPGTrainer, ddpg.apex.ApexDDPGConfig().to_dict()
|
||||
|
||||
|
||||
def _import_appo():
|
||||
from ray.rllib.agents import ppo
|
||||
import ray.rllib.algorithms.appo as appo
|
||||
|
||||
return ppo.APPOTrainer, ppo.appo.DEFAULT_CONFIG
|
||||
return appo.APPO, appo.APPOConfig().to_dict()
|
||||
|
||||
|
||||
def _import_ars():
|
||||
from ray.rllib.algorithms import ars
|
||||
|
||||
return ars.ARSTrainer, ars.DEFAULT_CONFIG
|
||||
return ars.ARSTrainer, ars.ARSConfig().to_dict()
|
||||
|
||||
|
||||
def _import_bandit_lints():
|
||||
|
@ -74,127 +68,127 @@ def _import_bandit_linucb():
|
|||
def _import_bc():
|
||||
from ray.rllib.algorithms import marwil
|
||||
|
||||
return marwil.BCTrainer, marwil.DEFAULT_CONFIG
|
||||
return marwil.BCTrainer, marwil.BCConfig().to_dict()
|
||||
|
||||
|
||||
def _import_cql():
|
||||
from ray.rllib.algorithms import cql
|
||||
|
||||
return cql.CQLTrainer, cql.DEFAULT_CONFIG
|
||||
return cql.CQLTrainer, cql.CQLConfig().to_dict()
|
||||
|
||||
|
||||
def _import_ddpg():
|
||||
from ray.rllib.algorithms import ddpg
|
||||
|
||||
return ddpg.DDPGTrainer, ddpg.DEFAULT_CONFIG
|
||||
return ddpg.DDPGTrainer, ddpg.DDPGConfig().to_dict()
|
||||
|
||||
|
||||
def _import_ddppo():
|
||||
from ray.rllib.agents import ppo
|
||||
import ray.rllib.algorithms.ddppo as ddppo
|
||||
|
||||
return ppo.DDPPOTrainer, ppo.DEFAULT_CONFIG
|
||||
return ddppo.DDPPO, ddppo.DDPPOConfig().to_dict()
|
||||
|
||||
|
||||
def _import_dqn():
|
||||
from ray.rllib.algorithms import dqn
|
||||
|
||||
return dqn.DQNTrainer, dqn.DEFAULT_CONFIG
|
||||
return dqn.DQNTrainer, dqn.DQNConfig().to_dict()
|
||||
|
||||
|
||||
def _import_dreamer():
|
||||
from ray.rllib.algorithms import dreamer
|
||||
|
||||
return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG
|
||||
return dreamer.DREAMERTrainer, dreamer.DREAMERConfig().to_dict()
|
||||
|
||||
|
||||
def _import_es():
|
||||
from ray.rllib.algorithms import es
|
||||
|
||||
return es.ESTrainer, es.DEFAULT_CONFIG
|
||||
return es.ESTrainer, es.ESConfig().to_dict()
|
||||
|
||||
|
||||
def _import_impala():
|
||||
from ray.rllib.agents import impala
|
||||
import ray.rllib.algorithms.impala as impala
|
||||
|
||||
return impala.ImpalaTrainer, impala.DEFAULT_CONFIG
|
||||
return impala.Impala, impala.ImpalaConfig().to_dict()
|
||||
|
||||
|
||||
def _import_maddpg():
|
||||
from ray.rllib.agents import maddpg
|
||||
import ray.rllib.algorithms.maddpg as maddpg
|
||||
|
||||
return maddpg.MADDPGTrainer, maddpg.DEFAULT_CONFIG
|
||||
return maddpg.MADDPGTrainer, maddpg.MADDPGConfig().to_dict()
|
||||
|
||||
|
||||
def _import_maml():
|
||||
from ray.rllib.algorithms import maml
|
||||
|
||||
return maml.MAMLTrainer, maml.DEFAULT_CONFIG
|
||||
return maml.MAMLTrainer, maml.MAMLConfig().to_dict()
|
||||
|
||||
|
||||
def _import_marwil():
|
||||
from ray.rllib.algorithms import marwil
|
||||
|
||||
return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG
|
||||
return marwil.MARWILTrainer, marwil.MARWILConfig().to_dict()
|
||||
|
||||
|
||||
def _import_mbmpo():
|
||||
from ray.rllib.algorithms import mbmpo
|
||||
|
||||
return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG
|
||||
return mbmpo.MBMPOTrainer, mbmpo.MBMPOConfig().to_dict()
|
||||
|
||||
|
||||
def _import_pg():
|
||||
from ray.rllib.algorithms import pg
|
||||
|
||||
return pg.PGTrainer, pg.DEFAULT_CONFIG
|
||||
return pg.PGTrainer, pg.PGConfig().to_dict()
|
||||
|
||||
|
||||
def _import_ppo():
|
||||
from ray.rllib.agents import ppo
|
||||
import ray.rllib.algorithms.ppo as ppo
|
||||
|
||||
return ppo.PPOTrainer, ppo.DEFAULT_CONFIG
|
||||
return ppo.PPO, ppo.PPOConfig().to_dict()
|
||||
|
||||
|
||||
def _import_qmix():
|
||||
from ray.rllib.algorithms import qmix
|
||||
|
||||
return qmix.QMixTrainer, qmix.DEFAULT_CONFIG
|
||||
return qmix.QMixTrainer, qmix.QMixConfig().to_dict()
|
||||
|
||||
|
||||
def _import_r2d2():
|
||||
from ray.rllib.agents import dqn
|
||||
|
||||
return dqn.R2D2Trainer, dqn.R2D2_DEFAULT_CONFIG
|
||||
return dqn.R2D2Trainer, dqn.R2D2Config().to_dict()
|
||||
|
||||
|
||||
def _import_sac():
|
||||
from ray.rllib.algorithms import sac
|
||||
|
||||
return sac.SACTrainer, sac.DEFAULT_CONFIG
|
||||
return sac.SACTrainer, sac.SACConfig().to_dict()
|
||||
|
||||
|
||||
def _import_rnnsac():
|
||||
from ray.rllib.algorithms import sac
|
||||
|
||||
return sac.RNNSACTrainer, sac.RNNSAC_DEFAULT_CONFIG
|
||||
return sac.RNNSACTrainer, sac.RNNSACConfig().to_dict()
|
||||
|
||||
|
||||
def _import_simple_q():
|
||||
from ray.rllib.algorithms import dqn
|
||||
|
||||
return dqn.SimpleQTrainer, dqn.simple_q.DEFAULT_CONFIG
|
||||
return dqn.SimpleQTrainer, dqn.simple_q.SimpleQConfig().to_dict()
|
||||
|
||||
|
||||
def _import_slate_q():
|
||||
from ray.rllib.algorithms import slateq
|
||||
|
||||
return slateq.SlateQTrainer, slateq.DEFAULT_CONFIG
|
||||
return slateq.SlateQTrainer, slateq.SlateQConfig().to_dict()
|
||||
|
||||
|
||||
def _import_td3():
|
||||
from ray.rllib.algorithms import ddpg
|
||||
|
||||
return ddpg.TD3Trainer, ddpg.td3.TD3_DEFAULT_CONFIG
|
||||
return ddpg.TD3Trainer, ddpg.td3.TD3Config().to_dict()
|
||||
|
||||
|
||||
ALGORITHMS = {
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
|
||||
import ray
|
||||
import ray.rllib.algorithms.dqn as dqn
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
import ray.rllib.algorithms.ppo as ppo
|
||||
from ray.rllib.examples.env.memory_leaking_env import MemoryLeakingEnv
|
||||
from ray.rllib.examples.policy.memory_leaking_policy import MemoryLeakingPolicy
|
||||
from ray.rllib.policy.policy import PolicySpec
|
||||
|
@ -30,7 +30,7 @@ class TestMemoryLeaks(unittest.TestCase):
|
|||
config["env_config"] = {
|
||||
"static_samples": True,
|
||||
}
|
||||
trainer = ppo.PPOTrainer(config=config)
|
||||
trainer = ppo.PPO(config=config)
|
||||
results = check_memory_leaks(trainer, to_check={"env"}, repeats=150)
|
||||
assert results["env"]
|
||||
trainer.stop()
|
||||
|
|
|
@ -1431,9 +1431,9 @@ class Trainer(Trainable):
|
|||
If None, the output format will be DL framework specific.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.ppo import PPOTrainer
|
||||
>>> from ray.rllib.algorithms.ppo import PPO
|
||||
>>> # Use a Trainer from RLlib or define your own.
|
||||
>>> trainer = PPOTrainer(...) # doctest: +SKIP
|
||||
>>> trainer = PPO(...) # doctest: +SKIP
|
||||
>>> for _ in range(10): # doctest: +SKIP
|
||||
>>> trainer.train() # doctest: +SKIP
|
||||
>>> trainer.export_policy_model("/tmp/dir") # doctest: +SKIP
|
||||
|
@ -1456,9 +1456,9 @@ class Trainer(Trainable):
|
|||
policy_id: Optional policy id to export.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.ppo import PPOTrainer
|
||||
>>> from ray.rllib.algorithms.ppo import PPO
|
||||
>>> # Use a Trainer from RLlib or define your own.
|
||||
>>> trainer = PPOTrainer(...) # doctest: +SKIP
|
||||
>>> trainer = PPO(...) # doctest: +SKIP
|
||||
>>> for _ in range(10): # doctest: +SKIP
|
||||
>>> trainer.train() # doctest: +SKIP
|
||||
>>> trainer.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
|
||||
|
@ -1478,8 +1478,8 @@ class Trainer(Trainable):
|
|||
policy_id: Optional policy id to import into.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.ppo import PPOTrainer
|
||||
>>> trainer = PPOTrainer(...) # doctest: +SKIP
|
||||
>>> from ray.rllib.algorithms.ppo import PPO
|
||||
>>> trainer = PPO(...) # doctest: +SKIP
|
||||
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
|
||||
>>> for _ in range(10): # doctest: +SKIP
|
||||
>>> trainer.train() # doctest: +SKIP
|
||||
|
|
|
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class A3CConfig(TrainerConfig):
|
||||
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
|
||||
"""Defines a configuration class from which a A3C Trainer can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray import tune
|
||||
|
|
|
@@ -12,7 +12,7 @@ from ray.actor import ActorHandle

from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
from ray.rllib.agents.trainer import Trainer
import ray.rllib.agents.ppo.appo as appo
import ray.rllib.algorithms.appo.appo as appo
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.parallel_requests import (
AsyncRequestsManager,

@@ -232,12 +232,12 @@ class AlphaStarConfig(appo.APPOConfig):

return self

class AlphaStarTrainer(appo.APPOTrainer):
_allow_unknown_subkeys = appo.APPOTrainer._allow_unknown_subkeys + [
class AlphaStarTrainer(appo.APPO):
_allow_unknown_subkeys = appo.APPO._allow_unknown_subkeys + [
"league_builder_config",
]
_override_all_subkeys_if_type_changes = (
appo.APPOTrainer._override_all_subkeys_if_type_changes
appo.APPO._override_all_subkeys_if_type_changes
+ [
"league_builder_config",
]

@@ -310,11 +310,11 @@ class AlphaStarTrainer(appo.APPOTrainer):

)

@classmethod
@override(appo.APPOTrainer)
@override(appo.APPO)
def get_default_config(cls) -> TrainerConfigDict:
return AlphaStarConfig().to_dict()

@override(appo.APPOTrainer)
@override(appo.APPO)
def validate_config(self, config: TrainerConfigDict):
# Create the LeagueBuilder object, allowing it to build the multiagent
# config as well.

@@ -323,7 +323,7 @@ class AlphaStarTrainer(appo.APPOTrainer):

)
super().validate_config(config)

@override(appo.APPOTrainer)
@override(appo.APPO)
def setup(self, config: PartialTrainerConfigDict):
# Call super's setup to validate config, create RolloutWorkers
# (train and eval), etc..

@@ -604,7 +604,7 @@ class AlphaStarTrainer(appo.APPOTrainer):

return train_results

@override(appo.APPOTrainer)
@override(appo.APPO)
def __getstate__(self) -> dict:
state = super().__getstate__()
state.update(

@@ -614,7 +614,7 @@ class AlphaStarTrainer(appo.APPOTrainer):

)
return state

@override(appo.APPOTrainer)
@override(appo.APPO)
def __setstate__(self, state: dict) -> None:
state_copy = state.copy()
self.league_builder.__setstate__(state.pop("league_builder", {}))
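The `__getstate__`/`__setstate__` overrides above fold the league builder's state into the Trainer's own checkpoint state. A minimal, hypothetical sketch of that pattern (stub classes only, not RLlib's real API):

import copy

class LeagueBuilderStub:
    """Stand-in for AlphaStarLeagueBuilder; only the (de)serialization hooks matter here."""

    def __init__(self):
        self.win_rates = {}

    def __getstate__(self) -> dict:
        return {"win_rates": copy.deepcopy(self.win_rates)}

    def __setstate__(self, state: dict) -> None:
        self.win_rates = state.get("win_rates", {})


class CheckpointableTrainerStub:
    """Minimal stand-in for a Trainer that owns an auxiliary component."""

    def __init__(self):
        self.league_builder = LeagueBuilderStub()

    def __getstate__(self) -> dict:
        # Base trainer state plus the league builder's own state, as in the hunk above.
        state = {"trainer": "base-state"}
        state.update(league_builder=self.league_builder.__getstate__())
        return state

    def __setstate__(self, state: dict) -> None:
        state_copy = state.copy()
        # Restore the auxiliary component first, then the rest of the trainer state.
        self.league_builder.__setstate__(state_copy.pop("league_builder", {}))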
34  rllib/algorithms/appo/README.md  Normal file

@@ -0,0 +1,34 @@

# Asynchronous Proximal Policy Optimization (APPO)

## Overview

[PPO](https://arxiv.org/abs/1707.06347) is a model-free on-policy RL algorithm that works
well for both discrete and continuous action space environments. PPO utilizes an
actor-critic framework, where there are two networks, an actor (policy network) and
a critic network (value function).

## Distributed PPO Algorithms

### Distributed baseline PPO
[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py)

### Asynchronous PPO (APPO) ..

.. opts to imitate IMPALA as its distributed execution plan.
Data collection nodes gather data asynchronously, which are collected in a circular replay
buffer. A target network and a doubly-importance-sampled surrogate objective are introduced
to enforce training stability in the asynchronous data-collection setting.
[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/appo/appo.py)

### Decentralized Distributed PPO (DDPPO)

[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py)

## Documentation & Implementation:

### [Asynchronous Proximal Policy Optimization (APPO)](https://arxiv.org/abs/1912.00167).

**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#appo)**

**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo.py)**
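With the module moved under `rllib/algorithms/appo`, a short usage sketch of the new config-builder API introduced by this commit (values are illustrative, not tuned defaults):

from ray.rllib.algorithms.appo import APPOConfig

# Chain the builder methods shown in the APPOConfig docstring below.
config = (
    APPOConfig()
    .training(lr=0.01, grad_clip=30.0)
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=2)
)

# Build the APPO Trainer directly from the config and run a few updates.
trainer = config.build(env="CartPole-v0")
for _ in range(3):
    print(trainer.train())
trainer.stop()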
12  rllib/algorithms/appo/__init__.py  Normal file

@@ -0,0 +1,12 @@

from ray.rllib.algorithms.appo.appo import APPO, APPOConfig, DEFAULT_CONFIG
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy, APPOTF2Policy
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy

__all__ = [
"APPO",
"APPOConfig",
"APPOTF1Policy",
"APPOTF2Policy",
"APPOTorchPolicy",
"DEFAULT_CONFIG",
]
@@ -13,7 +13,7 @@ from typing import Optional, Type

import logging

from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.agents.impala import ImpalaTrainer, ImpalaConfig
from ray.rllib.algorithms.impala import Impala, ImpalaConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated

@@ -33,10 +33,10 @@ logger = logging.getLogger(__name__)

class APPOConfig(ImpalaConfig):
"""Defines a APPOTrainer configuration class from which a new Trainer can be built.
"""Defines a configuration class from which an APPO Trainer can be built.

Example:
>>> from ray.rllib.agents.ppo import APPOConfig
>>> from ray.rllib.algorithms.appo import APPOConfig
>>> config = APPOConfig().training(lr=0.01, grad_clip=30.0)\
...     .resources(num_gpus=1)\
...     .rollouts(num_rollout_workers=16)

@@ -46,7 +46,7 @@ class APPOConfig(ImpalaConfig):

>>> trainer.train()

Example:
>>> from ray.rllib.agents.ppo import APPOConfig
>>> from ray.rllib.algorithms.appo import APPOConfig
>>> from ray import tune
>>> config = APPOConfig()
>>> # Print out some default values.

@@ -66,7 +66,7 @@ class APPOConfig(ImpalaConfig):

def __init__(self, trainer_class=None):
"""Initializes a APPOConfig instance."""
super().__init__(trainer_class=trainer_class or APPOTrainer)
super().__init__(trainer_class=trainer_class or APPO)

# fmt: off
# __sphinx_doc_begin__

@@ -166,8 +166,9 @@ class APPOConfig(ImpalaConfig):

return self

class APPOTrainer(ImpalaTrainer):
class APPO(Impala):
def __init__(self, config, *args, **kwargs):
"""Initializes a DDPPO instance."""
super().__init__(config, *args, **kwargs)

# After init: Initialize target net.

@@ -226,7 +227,7 @@ class APPOTrainer(ImpalaTrainer):

# Worker.
self.workers.local_worker().foreach_policy_to_train(update)

@override(ImpalaTrainer)
@override(Impala)
def training_iteration(self) -> ResultDict:
train_results = super().training_iteration()

@@ -236,36 +237,36 @@ class APPOTrainer(ImpalaTrainer):

return train_results

@classmethod
@override(ImpalaTrainer)
@override(Impala)
def get_default_config(cls) -> TrainerConfigDict:
return APPOConfig().to_dict()

@override(ImpalaTrainer)
@override(Impala)
def get_default_policy_class(
self, config: PartialTrainerConfigDict
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
from ray.rllib.agents.ppo.appo_torch_policy import APPOTorchPolicy
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy

return APPOTorchPolicy
elif config["framework"] == "tf":
from ray.rllib.agents.ppo.appo_tf_policy import APPOStaticGraphTFPolicy
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy

return APPOStaticGraphTFPolicy
return APPOTF1Policy
else:
from ray.rllib.agents.ppo.appo_tf_policy import APPOEagerTFPolicy
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy

return APPOEagerTFPolicy
return APPOTF2Policy

# Deprecated: Use ray.rllib.agents.ppo.APPOConfig instead!
# Deprecated: Use ray.rllib.algorithms.appo.APPOConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(APPOConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.ppo.appo.DEFAULT_CONFIG",
new="ray.rllib.agents.ppo.appo.APPOConfig(...)",
old="ray.rllib.agents.ppo.appo::DEFAULT_CONFIG",
new="ray.rllib.algorithms.appo.appo::APPOConfig(...)",
error=False,
)
def __getitem__(self, item):
@@ -11,8 +11,8 @@ import gym

from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.agents.impala.vtrace_tf_policy import (
from ray.rllib.algorithms.impala import vtrace_tf as vtrace
from ray.rllib.algorithms.impala.impala_tf_policy import (
_make_time_major,
VTraceClipGradients,
VTraceOptimizer,

@@ -123,7 +123,7 @@ def get_appo_tf_policy(base: type) -> type:

base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

Returns:
A TF Policy to be used with ImpalaTrainer.
A TF Policy to be used with Impala.
"""

class APPOTFPolicy(

@@ -147,7 +147,9 @@ def get_appo_tf_policy(base: type) -> type:

# First thing first, enable eager execution if necessary.
base.enable_eager_execution_if_necessary()

config = dict(ray.rllib.agents.ppo.appo.DEFAULT_CONFIG, **config)
config = dict(
    ray.rllib.algorithms.appo.appo.APPOConfig().to_dict(), **config
)

# Although this is a no-op, we call __init__ here to make it clear
# that base.__init__ will use the make_model() call.

@@ -470,5 +472,5 @@ def get_appo_tf_policy(base: type) -> type:

return APPOTFPolicy

APPOStaticGraphTFPolicy = get_appo_tf_policy(DynamicTFPolicyV2)
APPOEagerTFPolicy = get_appo_tf_policy(EagerTFPolicyV2)
APPOTF1Policy = get_appo_tf_policy(DynamicTFPolicyV2)
APPOTF2Policy = get_appo_tf_policy(EagerTFPolicyV2)
@@ -11,12 +11,12 @@ import logging

from typing import Any, Dict, List, Optional, Type, Union

import ray
import ray.rllib.agents.impala.vtrace_torch as vtrace
from ray.rllib.agents.impala.vtrace_torch_policy import (
from ray.rllib.algorithms.appo.appo_tf_policy import make_appo_model
import ray.rllib.algorithms.impala.vtrace_torch as vtrace
from ray.rllib.algorithms.impala.impala_torch_policy import (
make_time_major,
VTraceOptimizer,
)
from ray.rllib.agents.ppo.appo_tf_policy import make_appo_model
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,

@@ -65,10 +65,10 @@ class APPOTorchPolicy(

TargetNetworkMixin,
TorchPolicyV2,
):
"""PyTorch policy class used with APPOTrainer."""
"""PyTorch policy class used with APPO."""

def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ppo.appo.DEFAULT_CONFIG, **config)
config = dict(ray.rllib.algorithms.appo.appo.APPOConfig().to_dict(), **config)

# Although this is a no-op, we call __init__ here to make it clear
# that base.__init__ will use the make_model() call.
@@ -1,7 +1,7 @@

import unittest

import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.appo as appo
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.test_utils import (

@@ -21,8 +21,8 @@ class TestAPPO(unittest.TestCase):

ray.shutdown()

def test_appo_compilation(self):
"""Test whether an APPOTrainer can be built with both frameworks."""
config = ppo.appo.APPOConfig().rollouts(num_rollout_workers=1)
"""Test whether APPO can be built with both frameworks."""
config = appo.APPOConfig().rollouts(num_rollout_workers=1)
num_iterations = 2

for _ in framework_iterator(config, with_eager_tracing=True):

@@ -47,11 +47,9 @@ class TestAPPO(unittest.TestCase):

trainer.stop()

def test_appo_compilation_use_kl_loss(self):
"""Test whether an APPOTrainer can be built with kl_loss enabled."""
"""Test whether APPO can be built with kl_loss enabled."""
config = (
ppo.appo.APPOConfig()
.rollouts(num_rollout_workers=1)
.training(use_kl_loss=True)
appo.APPOConfig().rollouts(num_rollout_workers=1).training(use_kl_loss=True)
)
num_iterations = 2

@@ -68,7 +66,7 @@ class TestAPPO(unittest.TestCase):

# Not explicitly setting this should cause a warning, but not fail.
# config["_tf_policy_handles_more_than_one_loss"] = True
config = (
ppo.appo.APPOConfig()
appo.APPOConfig()
.rollouts(num_rollout_workers=1)
.training(_separate_vf_optimizer=True, _lr_vf=0.002)
)

@@ -91,7 +89,7 @@ class TestAPPO(unittest.TestCase):

def test_appo_entropy_coeff_schedule(self):
# Initial lr, doesn't really matter because of the schedule below.
config = (
ppo.appo.APPOConfig()
appo.APPOConfig()
.rollouts(
num_rollout_workers=1,
batch_mode="truncate_episodes",
@@ -2,8 +2,8 @@ from typing import List, Optional

from ray.actor import ActorHandle
from ray.rllib.agents import Trainer
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig, DDPGTrainer
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
35  rllib/algorithms/ddppo/README.md  Normal file

@@ -0,0 +1,35 @@

# Decentralized Distributed Proximal Policy Optimization (DDPPO)

## Overview

[PPO](https://arxiv.org/abs/1707.06347) is a model-free on-policy RL algorithm that works
well for both discrete and continuous action space environments. PPO utilizes an
actor-critic framework, where there are two networks, an actor (policy network) and
a critic network (value function).

## Distributed PPO Algorithms

### Distributed baseline PPO
[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py)

### Asynchronous PPO (APPO)
[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/appo/appo.py)

### Decentralized Distributed PPO (DDPPO) ..

.. removes the assumption that gradient updates must
be done on a central node. Instead, gradients are computed remotely on each data
collection node and all-reduced at each mini-batch using torch distributed. This allows
each worker’s GPU to be used both for sampling and for training.

[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py)

## Documentation & Implementation:

### [Decentralized Distributed Proximal Policy Optimization (DDPPO)](https://arxiv.org/abs/1911.00357)

**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#decentralized-distributed-proximal-policy-optimization-dd-ppo)**

**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py)**
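A minimal usage sketch of the renamed DDPPO config API, assuming a CPU-only box (the builder calls mirror the DDPPOConfig docstring and tests in the hunks below; values are illustrative):

from ray.rllib.algorithms.ddppo import DDPPOConfig

# DDPPO shifts gradient computation onto the rollout workers, so resources are
# configured per worker rather than for a central learner.
config = (
    DDPPOConfig()
    .training(lr=0.003, keep_local_weights_in_sync=True)
    .resources(num_gpus_per_worker=0)
    .rollouts(num_workers=2)
)

trainer = config.build(env="CartPole-v0")
print(trainer.train())
trainer.stop()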
7  rllib/algorithms/ddppo/__init__.py  Normal file

@@ -0,0 +1,7 @@

from ray.rllib.algorithms.ddppo.ddppo import DDPPOConfig, DDPPO, DEFAULT_CONFIG

__all__ = [
"DDPPOConfig",
"DDPPO",
"DEFAULT_CONFIG",
]
@@ -21,7 +21,7 @@ import time

from typing import Callable, Optional, Union

import ray
from ray.rllib.agents.ppo.ppo import PPOConfig, PPOTrainer
from ray.rllib.algorithms.ppo import PPOConfig, PPO
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import (

@@ -52,10 +52,10 @@ logger = logging.getLogger(__name__)

class DDPPOConfig(PPOConfig):
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
"""Defines a configuration class from which a DDPPO Trainer can be built.

Example:
>>> from ray.rllib.agents.ppo import DDPPOConfig
>>> from ray.rllib.algorithms.ddppo import DDPPOConfig
>>> config = DDPPOConfig().training(lr=0.003, keep_local_weights_in_sync=True)\
...     .resources(num_gpus=1)\
...     .rollouts(num_workers=10)

@@ -65,7 +65,7 @@ class DDPPOConfig(PPOConfig):

>>> trainer.train()

Example:
>>> from ray.rllib.agents.ppo import DDPPOConfig
>>> from ray.rllib.algorithms.ddppo import DDPPOConfig
>>> from ray import tune
>>> config = DDPPOConfig()
>>> # Print out some default values.

@@ -85,7 +85,7 @@ class DDPPOConfig(PPOConfig):

def __init__(self, trainer_class=None):
"""Initializes a DDPPOConfig instance."""
super().__init__(trainer_class=trainer_class or DDPPOTrainer)
super().__init__(trainer_class=trainer_class or DDPPO)

# fmt: off
# __sphinx_doc_begin__

@@ -157,7 +157,7 @@ class DDPPOConfig(PPOConfig):

return self

class DDPPOTrainer(PPOTrainer):
class DDPPO(PPO):
def __init__(
self,
config: Optional[PartialTrainerConfigDict] = None,

@@ -166,7 +166,7 @@ class DDPPOTrainer(PPOTrainer):

remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
):
"""Initializes a DDPPOTrainer instance.
"""Initializes a DDPPO instance.

Args:
config: Algorithm-specific configuration dict.

@@ -195,11 +195,11 @@ class DDPPOTrainer(PPOTrainer):

) * config.get("num_envs_per_worker", DEFAULT_CONFIG["num_envs_per_worker"])

@classmethod
@override(PPOTrainer)
@override(PPO)
def get_default_config(cls) -> TrainerConfigDict:
return DDPPOConfig().to_dict()

@override(PPOTrainer)
@override(PPO)
def validate_config(self, config):
"""Validates the Trainer's config dict.

@@ -250,7 +250,7 @@ class DDPPOTrainer(PPOTrainer):

if config["kl_coeff"] != 0.0 or config["kl_target"] != 0.0:
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")

@override(PPOTrainer)
@override(PPO)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)

@@ -281,7 +281,7 @@ class DDPPOTrainer(PPOTrainer):

ray_wait_timeout_s=0.03,
)

@override(PPOTrainer)
@override(PPO)
def training_iteration(self) -> ResultDict:
# Shortcut.
first_worker = self.workers.remote_workers()[0]

@@ -372,14 +372,14 @@ class DDPPOTrainer(PPOTrainer):

}

# Deprecated: Use ray.rllib.agents.ppo.DDPPOConfig instead!
# Deprecated: Use ray.rllib.algorithms.ddppo.DDPPOConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(DDPPOConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.ppo.ddppo.DEFAULT_CONFIG",
new="ray.rllib.agents.ppo.ddppo.DDPPOConfig(...)",
old="ray.rllib.agents.ppo.ddppo::DEFAULT_CONFIG",
new="ray.rllib.algorithms.ddppo.ddppo::DDPPOConfig(...)",
error=False,
)
def __getitem__(self, item):
@@ -2,7 +2,7 @@ import unittest

import pytest

import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ddppo as ddppo
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.test_utils import (

@@ -23,8 +23,8 @@ class TestDDPPO(unittest.TestCase):

ray.shutdown()

def test_ddppo_compilation(self):
"""Test whether a DDPPOTrainer can be built with both frameworks."""
config = ppo.DDPPOConfig().resources(num_gpus_per_worker=0)
"""Test whether DDPPO can be built with both frameworks."""
config = ddppo.DDPPOConfig().resources(num_gpus_per_worker=0)

num_iterations = 2

@@ -44,7 +44,7 @@ class TestDDPPO(unittest.TestCase):

def test_ddppo_schedule(self):
"""Test whether lr_schedule will anneal lr to 0"""
config = ppo.DDPPOConfig()
config = ddppo.DDPPOConfig()
config.resources(num_gpus_per_worker=0)
config.training(lr_schedule=[[0, config.lr], [1000, 0.0]])

@@ -64,7 +64,7 @@ class TestDDPPO(unittest.TestCase):

def test_validate_config(self):
"""Test if DDPPO will raise errors after invalid configs are passed."""
config = ppo.DDPPOConfig().training(kl_coeff=1.0)
config = ddppo.DDPPOConfig().training(kl_coeff=1.0)
msg = "DDPPO doesn't support KL penalties like PPO-1"
with pytest.raises(ValueError, match=msg):
config.build(env="CartPole-v0")
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)

class DREAMERConfig(TrainerConfig):
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
"""Defines a configuration class from which a Dreamer Trainer can be built.

Example:
>>> from ray.rllib.algorithms.dreamer import DREAMERConfig
15  rllib/algorithms/impala/__init__.py  Normal file

@@ -0,0 +1,15 @@

from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig, DEFAULT_CONFIG
from ray.rllib.algorithms.impala.impala_tf_policy import (
ImpalaTF1Policy,
ImpalaTF2Policy,
)
from ray.rllib.algorithms.impala.impala_torch_policy import ImpalaTorchPolicy

__all__ = [
"ImpalaConfig",
"Impala",
"ImpalaTF1Policy",
"ImpalaTF2Policy",
"ImpalaTorchPolicy",
"DEFAULT_CONFIG",
]
@@ -56,10 +56,10 @@ logger = logging.getLogger(__name__)

class ImpalaConfig(TrainerConfig):
"""Defines a configuration class from which an ImpalaTrainer can be built.
"""Defines a configuration class from which an Impala can be built.

Example:
>>> from ray.rllib.agents.impala import ImpalaConfig
>>> from ray.rllib.algorithms.impala import ImpalaConfig
>>> config = ImpalaConfig().training(lr=0.0003, train_batch_size=512)\
...     .resources(num_gpus=4)\
...     .rollouts(num_rollout_workers=64)

@@ -69,7 +69,7 @@ class ImpalaConfig(TrainerConfig):

>>> trainer.train()

Example:
>>> from ray.rllib.agents.impala import ImpalaConfig
>>> from ray.rllib.algorithms.impala import ImpalaConfig
>>> from ray import tune
>>> config = ImpalaConfig()
>>> # Print out some default values.

@@ -89,7 +89,7 @@ class ImpalaConfig(TrainerConfig):

def __init__(self, trainer_class=None):
"""Initializes a ImpalaConfig instance."""
super().__init__(trainer_class=trainer_class or ImpalaTrainer)
super().__init__(trainer_class=trainer_class or Impala)

# fmt: off
# __sphinx_doc_begin__

@@ -433,7 +433,7 @@ class BroadcastUpdateLearnerWeights:

self.workers.local_worker().set_global_vars(_get_global_vars())

class ImpalaTrainer(Trainer):
class Impala(Trainer):
"""Importance weighted actor/learner architecture (IMPALA) Trainer

== Overview of data flow in IMPALA ==

@@ -458,31 +458,31 @@ class ImpalaTrainer(Trainer):

) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
if config["vtrace"]:
from ray.rllib.agents.impala.vtrace_torch_policy import (
    VTraceTorchPolicy,
from ray.rllib.algorithms.impala.impala_torch_policy import (
    ImpalaTorchPolicy,
)

return VTraceTorchPolicy
return ImpalaTorchPolicy
else:
from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

return A3CTorchPolicy
elif config["framework"] == "tf":
if config["vtrace"]:
from ray.rllib.agents.impala.vtrace_tf_policy import (
    VTraceStaticGraphTFPolicy,
from ray.rllib.algorithms.impala.impala_tf_policy import (
    ImpalaTF1Policy,
)

return VTraceStaticGraphTFPolicy
return ImpalaTF1Policy
else:
from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

return A3CTFPolicy
else:
if config["vtrace"]:
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceEagerTFPolicy
from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF2Policy

return VTraceEagerTFPolicy
return ImpalaTF2Policy
else:
from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

@@ -940,14 +940,14 @@ class AggregatorWorker:

return platform.node()

# Deprecated: Use ray.rllib.agents.pg.PGConfig instead!
# Deprecated: Use ray.rllib.algorithms.impala.ImpalaConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(ImpalaConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.impala.default_config::DEFAULT_CONFIG",
new="ray.rllib.agents.impala.impala.IMPALAConfig(...)",
old="ray.rllib.agents.impala.impala::DEFAULT_CONFIG",
new="ray.rllib.algorithms.impala.impala::IMPALAConfig(...)",
error=False,
)
def __getitem__(self, item):
@@ -8,7 +8,7 @@ import gym

from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.algorithms.impala import vtrace_tf as vtrace
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2

@@ -23,6 +23,7 @@ from ray.rllib.utils.typing import (

LocalOptimizer,
ModelGradients,
TensorType,
TFPolicyV2Type,
)

tf1, tf, tfv = try_import_tf()

@@ -218,7 +219,7 @@ class VTraceOptimizer:

pass

# TODO: maybe standardize this function, so the choice of optimizers are more
# predictable for common agents.
# predictable for common algorithms.
def optimizer(
self,
) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:

@@ -253,19 +254,19 @@ class VTraceOptimizer:

# We need this builder function because we want to share the same
# custom logics between TF1 dynamic and TF2 eager policies.
def get_vtrace_tf_policy(base: type) -> type:
"""Construct an VTraceTFPolicy inheriting either dynamic or eager base policies.
def get_impala_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:
"""Construct an ImpalaTFPolicy inheriting either dynamic or eager base policies.

Args:
base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

Returns:
A TF Policy to be used with ImpalaTrainer.
A TF Policy to be used with Impala.
"""
# VTrace mixins are placed in front of more general mixins to make sure
# their functions like optimizer() overrides all the other implementations
# (e.g., LearningRateSchedule.optimizer())
class VTraceTFPolicy(
class ImpalaTFPolicy(
VTraceClipGradients,
VTraceOptimizer,
LearningRateSchedule,

@@ -283,7 +284,9 @@ def get_vtrace_tf_policy(base: type) -> type:

# First thing first, enable eager execution if necessary.
base.enable_eager_execution_if_necessary()

config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
config = dict(
    ray.rllib.algorithms.impala.impala.ImpalaConfig().to_dict(), **config
)

# Initialize base class.
base.__init__(

@@ -434,8 +437,8 @@ def get_vtrace_tf_policy(base: type) -> type:

def get_batch_divisibility_req(self) -> int:
return self.config["rollout_fragment_length"]

return VTraceTFPolicy
return ImpalaTFPolicy

VTraceStaticGraphTFPolicy = get_vtrace_tf_policy(DynamicTFPolicyV2)
VTraceEagerTFPolicy = get_vtrace_tf_policy(EagerTFPolicyV2)
ImpalaTF1Policy = get_impala_tf_policy(DynamicTFPolicyV2)
ImpalaTF2Policy = get_impala_tf_policy(EagerTFPolicyV2)
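The renamed `get_impala_tf_policy` above is a policy-builder function: one factory that subclasses whichever base it is handed (TF1 dynamic vs. TF2 eager), so the custom loss and optimizer logic is written only once. A framework-free, hypothetical sketch of that pattern (stub classes, not RLlib's real API):

class DynamicTFPolicyStub:
    """Stand-in for the TF1 (static-graph) policy base."""
    framework = "tf"


class EagerTFPolicyStub:
    """Stand-in for the TF2 (eager) policy base."""
    framework = "tf2"


def get_policy_class(base: type) -> type:
    """Return a policy class mixing the shared custom logic into the given base."""

    class SharedLogicPolicy(base):
        def loss(self, batch):
            # In the real code this is the V-trace loss; the point here is that
            # the same body is reused on top of either base class.
            return sum(batch)

    SharedLogicPolicy.__name__ = f"SharedLogicPolicy_{base.__name__}"
    return SharedLogicPolicy


# Mirrors the ImpalaTF1Policy / ImpalaTF2Policy assignments at the end of the hunk above.
TF1Policy = get_policy_class(DynamicTFPolicyStub)
TF2Policy = get_policy_class(EagerTFPolicyStub)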
@@ -4,7 +4,7 @@ import numpy as np

from typing import Dict, List, Type, Union

import ray
import ray.rllib.agents.impala.vtrace_torch as vtrace
import ray.rllib.algorithms.impala.vtrace_torch as vtrace
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.torch.torch_action_dist import TorchCategorical

@@ -188,16 +188,18 @@ class VTraceOptimizer:

# VTrace mixins are placed in front of more general mixins to make sure
# their functions like optimizer() overrides all the other implementations
# (e.g., LearningRateSchedule.optimizer())
class VTraceTorchPolicy(
class ImpalaTorchPolicy(
VTraceOptimizer,
LearningRateSchedule,
EntropyCoeffSchedule,
TorchPolicyV2,
):
"""PyTorch policy class used with ImpalaTrainer."""
"""PyTorch policy class used with Impala."""

def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
config = dict(
    ray.rllib.algorithms.impala.impala.ImpalaConfig().to_dict(), **config
)

VTraceOptimizer.__init__(self)
# Need to initialize learning rate variable before calling
@@ -1,7 +1,7 @@

import unittest

import ray
import ray.rllib.agents.impala as impala
import ray.rllib.algorithms.impala as impala
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY

@@ -25,7 +25,7 @@ class TestIMPALA(unittest.TestCase):

ray.shutdown()

def test_impala_compilation(self):
"""Test whether an ImpalaTrainer can be built with both frameworks."""
"""Test whether Impala can be built with both frameworks."""
config = (
impala.ImpalaConfig()
.resources(num_gpus=0)
@@ -24,8 +24,8 @@ from gym.spaces import Box

import numpy as np
import unittest

from ray.rllib.agents.impala import vtrace_tf as vtrace_tf
from ray.rllib.agents.impala import vtrace_torch as vtrace_torch
from ray.rllib.algorithms.impala import vtrace_tf as vtrace_tf
from ray.rllib.algorithms.impala import vtrace_torch as vtrace_torch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import softmax
from ray.rllib.utils.test_utils import check, framework_iterator

@@ -29,7 +29,7 @@ multi_from_logits method accepts lists of tensors instead of just

tensors.
"""

from ray.rllib.agents.impala.vtrace_tf import VTraceFromLogitsReturns, VTraceReturns
from ray.rllib.algorithms.impala.vtrace_tf import VTraceFromLogitsReturns, VTraceReturns
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.utils import force_list
from ray.rllib.utils.framework import try_import_torch
@@ -2,7 +2,7 @@ import logging

from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import (
Postprocessing,
compute_gae_for_sample_batch,

@@ -2,7 +2,7 @@ import logging

from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import (
Postprocessing,
compute_gae_for_sample_batch,
49  rllib/algorithms/ppo/README.md  Normal file

@@ -0,0 +1,49 @@

# Proximal Policy Optimization (PPO)

## Overview

[PPO](https://arxiv.org/abs/1707.06347) is a model-free on-policy RL algorithm that works
well for both discrete and continuous action space environments. PPO utilizes an
actor-critic framework, where there are two networks, an actor (policy network) and
a critic network (value function).

There are two formulations of PPO, which are both implemented in RLlib. The first
formulation of PPO imitates the prior paper [TRPO](https://arxiv.org/abs/1502.05477)
without the complexity of second-order optimization. In this formulation, for every
iteration, an old version of the actor network is saved and the agent seeks to optimize
the RL objective while staying close to the old policy. This makes sure that the agent
does not destabilize during training. In the second formulation, to mitigate destructive
large policy updates, an issue discovered for vanilla policy gradient methods, PPO
introduces a clipped surrogate objective, which clips large action-probability ratios between
the current and old policy. Clipping has been shown in the paper to significantly
improve training stability and speed.

## Distributed PPO Algorithms

PPO is a core algorithm in RLlib due to its ability to scale well with the number of nodes.

In RLlib, we provide various implementations of distributed PPO, with different underlying
execution plans, as shown below:

### Distributed baseline PPO ..
.. is a synchronous distributed RL algorithm (this algo here).
Data collection nodes, which represent the old policy, gather data synchronously to
create a large pool of on-policy data on which the agent performs minibatch
gradient descent.

### Asynchronous PPO (APPO)

[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/appo/appo.py)

### Decentralized Distributed PPO (DDPPO)

[See implementation here](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py)

## Documentation & Implementation:

### Proximal Policy Optimization (PPO).

**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#ppo)**

**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py)**
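For reference, the clipped surrogate objective the README describes can be written out in a few lines of NumPy. This is a standalone illustration of the formula L = E[min(r·A, clip(r, 1-ε, 1+ε)·A)], not code from the commit:

import numpy as np

def clipped_surrogate_objective(ratio, advantage, clip_eps=0.2):
    """PPO's clipped surrogate objective over a batch of samples.

    ratio: pi_new(a|s) / pi_old(a|s) per sample.
    advantage: advantage estimate per sample.
    """
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    # Take the pessimistic minimum so overly large policy steps earn no extra credit.
    return np.mean(np.minimum(unclipped, clipped))

# Toy numbers: a ratio of 1.5 with positive advantage is clipped back to 1.2.
print(clipped_surrogate_objective(np.array([1.5, 0.9]), np.array([1.0, -0.5])))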
12  rllib/algorithms/ppo/__init__.py  Normal file

@@ -0,0 +1,12 @@

from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO, DEFAULT_CONFIG
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

__all__ = [
"PPOConfig",
"PPOTF1Policy",
"PPOTF2Policy",
"PPOTorchPolicy",
"PPO",
"DEFAULT_CONFIG",
]
@@ -40,10 +40,10 @@ logger = logging.getLogger(__name__)

class PPOConfig(TrainerConfig):
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
"""Defines a configuration class from which a PPO Trainer can be built.

Example:
>>> from ray.rllib.agents.ppo import PPOConfig
>>> from ray.rllib.algorithms.ppo import PPOConfig
>>> config = PPOConfig().training(gamma=0.9, lr=0.01, kl_coeff=0.3)\
...     .resources(num_gpus=0)\
...     .rollouts(num_workers=4)

@@ -53,7 +53,7 @@ class PPOConfig(TrainerConfig):

>>> trainer.train()

Example:
>>> from ray.rllib.agents.ppo import PPOConfig
>>> from ray.rllib.algorithms.ppo import PPOConfig
>>> from ray import tune
>>> config = PPOConfig()
>>> # Print out some default values.

@@ -73,7 +73,7 @@ class PPOConfig(TrainerConfig):

def __init__(self, trainer_class=None):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=trainer_class or PPOTrainer)
super().__init__(trainer_class=trainer_class or PPO)

# fmt: off
# __sphinx_doc_begin__

@@ -267,7 +267,7 @@ def warn_about_bad_reward_scales(config, result):

return result

class PPOTrainer(Trainer):
class PPO(Trainer):
# TODO: Change the return value of this method to return a TrainerConfig object
# instead.
@classmethod

@@ -365,17 +365,17 @@ class PPOTrainer(Trainer):

@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

return PPOTorchPolicy
elif config["framework"] == "tf":
from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy

return PPOStaticGraphTFPolicy
return PPOTF1Policy
else:
from ray.rllib.agents.ppo.ppo_tf_policy import PPOEagerTFPolicy
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy

return PPOEagerTFPolicy
return PPOTF2Policy

@ExperimentalAPI
def training_iteration(self) -> ResultDict:

@@ -455,14 +455,14 @@ class PPOTrainer(Trainer):

return train_results

# Deprecated: Use ray.rllib.agents.ppo.PPOConfig instead!
# Deprecated: Use ray.rllib.algorithms.ppo.PPOConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(PPOConfig().to_dict())

@Deprecated(
old="ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG",
new="ray.rllib.agents.ppo.ppo.PPOConfig(...)",
old="ray.rllib.agents.ppo.ppo::DEFAULT_CONFIG",
new="ray.rllib.algorithms.ppo.ppo::PPOConfig(...)",
error=False,
)
def __getitem__(self, item):
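The `_deprecated_default_config` hunk above keeps the legacy `DEFAULT_CONFIG` dict working while steering users to the new config classes. A simplified, hypothetical sketch of that backward-compatibility pattern (the decorator below is a stand-in, not RLlib's real `@Deprecated`):

import warnings

def deprecated_access(old: str, new: str):
    """Illustrative stand-in for RLlib's @Deprecated decorator."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            warnings.warn(f"{old} is deprecated; use {new} instead.", DeprecationWarning)
            return func(*args, **kwargs)
        return wrapper
    return decorator


class _deprecated_default_config(dict):
    """A dict that still serves old-style DEFAULT_CONFIG lookups, but warns on access."""

    def __init__(self, defaults: dict):
        super().__init__(defaults)

    @deprecated_access(
        old="ray.rllib.agents.ppo.ppo::DEFAULT_CONFIG",
        new="ray.rllib.algorithms.ppo.ppo::PPOConfig(...)",
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


# Illustrative default values only.
DEFAULT_CONFIG = _deprecated_default_config({"lr": 5e-5, "train_batch_size": 4000})
print(DEFAULT_CONFIG["lr"])  # emits a DeprecationWarning, then returns 5e-5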
@@ -60,7 +60,7 @@ def get_ppo_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:

base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

Returns:
A TF Policy to be used with PPOTrainer.
A TF Policy to be used with PPO.
"""

class PPOTFPolicy(

@@ -81,7 +81,7 @@ def get_ppo_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:

# First thing first, enable eager execution if necessary.
base.enable_eager_execution_if_necessary()

config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config)
validate_config(config)

# Initialize base class.

@@ -241,5 +241,5 @@ def get_ppo_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:

return PPOTFPolicy

PPOStaticGraphTFPolicy = get_ppo_tf_policy(DynamicTFPolicyV2)
PPOEagerTFPolicy = get_ppo_tf_policy(EagerTFPolicyV2)
PPOTF1Policy = get_ppo_tf_policy(DynamicTFPolicyV2)
PPOTF2Policy = get_ppo_tf_policy(EagerTFPolicyV2)
@@ -2,7 +2,7 @@ import logging

from typing import Dict, List, Type, Union

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import validate_config
from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import (
Postprocessing,
compute_gae_for_sample_batch,

@@ -39,10 +39,10 @@ class PPOTorchPolicy(

KLCoeffMixin,
TorchPolicyV2,
):
"""PyTorch policy class used with PPOTrainer."""
"""PyTorch policy class used with PPO."""

def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config)
validate_config(config)

TorchPolicyV2.__init__(
@@ -3,9 +3,9 @@ import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import PPOEagerTFPolicy
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,

@@ -89,7 +89,7 @@ class TestPPO(unittest.TestCase):

ray.shutdown()

def test_ppo_compilation_and_schedule_mixins(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
"""Test whether PPO can be built with all frameworks."""

# Build a PPOConfig object.
config = (

@@ -172,7 +172,7 @@ class TestPPO(unittest.TestCase):

# Test against all frameworks.
for fw in framework_iterator(config):
# Default Agent should be setup with StochasticSampling.
trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v1")
trainer = ppo.PPO(config=config, env="FrozenLake-v1")
# explore=False, always expect the same (deterministic) action.
a_ = trainer.compute_single_action(
obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0)

@@ -223,7 +223,7 @@ class TestPPO(unittest.TestCase):

)

for fw, sess in framework_iterator(config, session=True):
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
trainer = ppo.PPO(config=config, env="CartPole-v0")
policy = trainer.get_policy()

# Check the free log std var is created.

@@ -265,7 +265,7 @@ class TestPPO(unittest.TestCase):

# Expect warning.
print(f"Accessing learning-rate from legacy config dict: {ppo_config['lr']}")
# Build Trainer.
ppo_trainer = ppo.PPOTrainer(config=ppo_config, env="CartPole-v1")
ppo_trainer = ppo.PPO(config=ppo_config, env="CartPole-v1")
print(ppo_trainer.train())

def test_ppo_loss_function(self):

@@ -286,7 +286,7 @@ class TestPPO(unittest.TestCase):

)

for fw, sess in framework_iterator(config, session=True):
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
trainer = ppo.PPO(config=config, env="CartPole-v0")
policy = trainer.get_policy()

# Check no free log std var by default.

@@ -313,7 +313,7 @@ class TestPPO(unittest.TestCase):

# Calculate actual PPO loss.
if fw in ["tf2", "tfe"]:
PPOEagerTFPolicy.loss(policy, policy.model, Categorical, train_batch)
PPOTF2Policy.loss(policy, policy.model, Categorical, train_batch)
elif fw == "torch":
PPOTorchPolicy.loss(
policy, policy.model, policy.dist_class, train_batch
@@ -7,7 +7,7 @@ import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.algorithms.dqn as dqn
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
from ray.rllib.evaluation.rollout_worker import RolloutWorker

@@ -104,7 +104,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):

config["model"]["lstm_use_prev_reward"] = True

for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config, env="CartPole-v0")
trainer = ppo.PPO(config, env="CartPole-v0")
policy = trainer.get_policy()
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements

@@ -168,7 +168,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):

config["env_config"] = {"config": {"start_at_t": 1}}  # first obs is [1.0]

for _ in framework_iterator(config, frameworks="tf2"):
trainer = ppo.PPOTrainer(
trainer = ppo.PPO(
config,
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
)

@@ -322,7 +322,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):

print(batch)

def test_counting_by_agent_steps(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)

num_agents = 3

@@ -342,7 +341,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):

config["env_config"] = {"num_agents": num_agents}

num_iterations = 2
trainer = ppo.PPOTrainer(config=config)
trainer = ppo.PPO(config=config)
results = None
for i in range(num_iterations):
results = trainer.train()
@@ -147,7 +147,7 @@ if __name__ == "__main__":

raise ValueError("This example only supports APPO and PPO.")
ppo_config = ppo.DEFAULT_CONFIG.copy()
ppo_config.update(config)
trainer = ppo.PPOTrainer(config=ppo_config, env=ActionMaskEnv)
trainer = ppo.PPO(config=ppo_config, env=ActionMaskEnv)
# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = trainer.train()

@@ -170,7 +170,7 @@ if __name__ == "__main__":

raise ValueError("Only support --run PPO with --no-tune.")
ppo_config = ppo.DEFAULT_CONFIG.copy()
ppo_config.update(config)
trainer = ppo.PPOTrainer(config=ppo_config, env=args.env)
trainer = ppo.PPO(config=ppo_config, env=args.env)
# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = trainer.train()

@@ -163,7 +163,7 @@ if __name__ == "__main__":

raise ValueError("Only support --run PPO with --no-tune.")
ppo_config = ppo.DEFAULT_CONFIG.copy()
ppo_config.update(config)
trainer = ppo.PPOTrainer(config=ppo_config, env=CorrelatedActionsEnv)
trainer = ppo.PPO(config=ppo_config, env=CorrelatedActionsEnv)
# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = trainer.train()
@@ -86,9 +86,9 @@ if __name__ == "__main__":

# Example (use `config` from the above code):
# >> import numpy as np
# >> from ray.rllib.agents.ppo import PPOTrainer
# >> from ray.rllib.algorithms.ppo import PPO
# >>
# >> trainer = PPOTrainer(config)
# >> trainer = PPO(config)
# >> lstm_cell_size = config["model"]["lstm_cell_size"]
# >> env = StatelessCartPole()
# >> obs = env.reset()
@@ -20,12 +20,12 @@ import os

import ray
from ray import tune
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import (
PPOStaticGraphTFPolicy,
PPOEagerTFPolicy,
from ray.rllib.algorithms.ppo.ppo import PPO
from ray.rllib.algorithms.ppo.ppo_tf_policy import (
PPOTF1Policy,
PPOTF2Policy,
)
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.examples.models.centralized_critic_models import (

@@ -209,8 +209,8 @@ def get_ccppo_policy(base):

return CCPPOTFPolicy

CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOStaticGraphTFPolicy)
CCPPOEagerTFPolicy = get_ccppo_policy(PPOEagerTFPolicy)
CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOTF1Policy)
CCPPOEagerTFPolicy = get_ccppo_policy(PPOTF2Policy)

class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy):

@@ -231,8 +231,8 @@ class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy):

)

class CCTrainer(PPOTrainer):
@override(PPOTrainer)
class CCTrainer(PPO):
@override(PPO)
def get_default_policy_class(self, config):
if config["framework"] == "torch":
return CCPPOTorchPolicy
@@ -7,7 +7,7 @@ import os

import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.coin_game_non_vectorized_env import CoinGame, AsymCoinGame

parser = argparse.ArgumentParser()

@@ -71,7 +71,7 @@ def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):

}

tune_analysis = tune.run(
PPOTrainer,
PPO,
config=rllib_config,
stop=stop,
checkpoint_freq=0,
@@ -187,7 +187,7 @@ if __name__ == "__main__":

ppo_config.update(config)
# use fixed learning rate instead of grid search (needs tune)
ppo_config["lr"] = 1e-3
trainer = ppo.PPOTrainer(config=ppo_config, env=SimpleCorridor)
trainer = ppo.PPO(config=ppo_config, env=SimpleCorridor)
# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = trainer.train()
@@ -3,7 +3,7 @@ import argparse

import ray
from ray import tune
from ray.rllib.agents import ppo
import ray.rllib.algorithms.ppo as ppo

parser = argparse.ArgumentParser()
parser.add_argument("--train-iterations", type=int, default=10)

@@ -11,7 +11,7 @@ parser.add_argument("--train-iterations", type=int, default=10)

def experiment(config):
iterations = config.pop("train-iterations")
train_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
train_agent = ppo.PPO(config=config, env="CartPole-v0")
checkpoint = None
train_results = {}

@@ -25,7 +25,7 @@ def experiment(config):

# Manual Eval
config["num_workers"] = 0
eval_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
eval_agent = ppo.PPO(config=config, env="CartPole-v0")
eval_agent.restore(checkpoint)
env = eval_agent.workers.local_worker().env

@@ -53,5 +53,5 @@ if __name__ == "__main__":

tune.run(
experiment,
config=config,
resources_per_trial=ppo.PPOTrainer.default_resource_request(config),
resources_per_trial=ppo.PPO.default_resource_request(config),
)
@@ -90,9 +90,9 @@ if __name__ == "__main__":

# Example (use `config` from the above code):
# >> import numpy as np
# >> from ray.rllib.agents.ppo import PPOTrainer
# >> from ray.rllib.algorithms.ppo import PPO
# >>
# >> trainer = PPOTrainer(config)
# >> trainer = PPO(config)
# >> lstm_cell_size = config["model"]["custom_model_config"]["cell_size"]
# >> env = RepeatAfterMeEnv({})
# >> obs = env.reset()
@@ -10,7 +10,7 @@ import os

import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO

parser = argparse.ArgumentParser()
parser.add_argument(

@@ -25,7 +25,7 @@ def my_train_fn(config, reporter):

iterations = config.pop("train-iterations", 10)

# Train for n iterations with high LR
agent1 = PPOTrainer(env="CartPole-v0", config=config)
agent1 = PPO(env="CartPole-v0", config=config)
for _ in range(iterations):
result = agent1.train()
result["phase"] = 1

@@ -36,7 +36,7 @@ def my_train_fn(config, reporter):

# Train for n iterations with low LR
config["lr"] = 0.0001
agent2 = PPOTrainer(env="CartPole-v0", config=config)
agent2 = PPO(env="CartPole-v0", config=config)
agent2.restore(state)
for _ in range(iterations):
result = agent2.train()

@@ -58,5 +58,5 @@ if __name__ == "__main__":

"num_workers": 0,
"framework": args.framework,
}
resources = PPOTrainer.default_resource_request(config)
resources = PPO.default_resource_request(config)
tune.run(my_train_fn, resources_per_trial=resources, config=config)
@@ -2,7 +2,7 @@

import gym

import ray
from ray.rllib.agents import ppo
import ray.rllib.algorithms.ppo as ppo

class SimpleCorridor(gym.Env):

@@ -35,7 +35,7 @@ config = {

},
}

trainer = ppo.PPOTrainer(config=config)
trainer = ppo.PPO(config=config)
for _ in range(3):
print(trainer.train())
# __rllib-custom-gym-env-end__
@@ -1,6 +1,6 @@

# __rllib-in-60s-begin__
# Import the RL algorithm (Trainer) we would like to use.
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO

# Configure the algorithm.
config = {

@@ -28,7 +28,7 @@ config = {

}

# Create our RLlib Trainer.
trainer = PPOTrainer(config=config)
trainer = PPO(config=config)

# Run it for n training iterations. A training iteration includes
# parallel sample collection by the environment workers as well as
@@ -1,6 +1,6 @@

# __quick_start_begin__
import gym
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO


# Define your problem using python and openAI's gym API:

@@ -49,7 +49,7 @@ class SimpleCorridor(gym.Env):


# Create an RLlib Trainer instance.
trainer = PPOTrainer(
trainer = PPO(
config={
# Env class to use (here: our gym.Env sub-class from above).
"env": SimpleCorridor,
@@ -1,5 +1,5 @@

import gym
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO


# Define your problem using python and openAI's gym API:

@@ -51,7 +51,7 @@ class ParrotEnv(gym.Env):

# Create an RLlib Trainer instance to learn how to act in the above
# environment.
trainer = PPOTrainer(
trainer = PPO(
config={
# Env class to use (here: our gym.Env sub-class from above).
"env": ParrotEnv,
@@ -1,6 +1,6 @@

import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
import onnxruntime
import os
import shutil

@@ -24,7 +24,7 @@ test_data = {

# Start Ray and initialize a PPO trainer
ray.init()
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
trainer = ppo.PPO(config=config, env="CartPole-v0")

# You could train the model here
# trainer.train()
@@ -2,7 +2,7 @@ from distutils.version import LooseVersion

import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
import onnxruntime
import os
import shutil

@@ -28,7 +28,7 @@ test_data = {

# Start Ray and initialize a PPO trainer
ray.init()
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
trainer = ppo.PPO(config=config, env="CartPole-v0")

# You could train the model here
# trainer.train()

@@ -104,8 +104,8 @@ if __name__ == "__main__":
# Note: The above GPU settings should also work in case you are not
# running via tune.run(), but instead do:

# >> from ray.rllib.agents.ppo import PPOTrainer
# >> trainer = PPOTrainer(config=config)
# >> from ray.rllib.algorithms.ppo import PPO
# >> trainer = PPO(config=config)
# >> for _ in range(10):
# >> results = trainer.train()
# >> print(results)

@@ -1,7 +1,7 @@
import numpy as np

import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.framework import try_import_torch

@@ -40,7 +40,7 @@ if __name__ == "__main__":
ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)

# Create the Trainer.
trainer = ppo.PPOTrainer(
trainer = ppo.PPO(
env="CartPole-v0",
config={
"framework": "torch",

@@ -14,10 +14,10 @@ import os

import ray
from ray.rllib.algorithms.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
from ray.rllib.agents.ppo import (
PPOTrainer,
PPOStaticGraphTFPolicy,
PPOEagerTFPolicy,
from ray.rllib.algorithms.ppo import (
PPO,
PPOTF1Policy,
PPOTF2Policy,
PPOTorchPolicy,
)
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

@@ -66,9 +66,9 @@ if __name__ == "__main__":
if framework == "torch":
return PPOTorchPolicy
elif framework == "tf":
return PPOStaticGraphTFPolicy
return PPOTF1Policy
else:
return PPOEagerTFPolicy
return PPOTF2Policy
elif algorithm == "DQN":
if framework == "torch":
return DQNTorchPolicy

@@ -100,7 +100,7 @@ if __name__ == "__main__":
else:
return "dqn_policy"

ppo_trainer = PPOTrainer(
ppo_trainer = PPO(
env="multi_agent_cartpole",
config={
"multiagent": {

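Alongside the Trainer rename, the PPO policy classes get framework-explicit names: `PPOStaticGraphTFPolicy` becomes `PPOTF1Policy`, `PPOEagerTFPolicy` becomes `PPOTF2Policy`, and `PPOTorchPolicy` keeps its name. A small sketch of the framework dispatch this example performs, using only the renamed imports (the helper function itself is illustrative):

    from ray.rllib.algorithms.ppo import PPO, PPOTF1Policy, PPOTF2Policy, PPOTorchPolicy


    def select_ppo_policy(framework: str):
        # torch -> PPOTorchPolicy, static-graph TF -> PPOTF1Policy (ex PPOStaticGraphTFPolicy),
        # eager TF2 -> PPOTF2Policy (ex PPOEagerTFPolicy).
        if framework == "torch":
            return PPOTorchPolicy
        elif framework == "tf":
            return PPOTF1Policy
        return PPOTF2Policy
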
@@ -13,7 +13,7 @@ import argparse
import os

import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.agents.trainer import Trainer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import check_learning_achieved

@@ -76,10 +76,10 @@ def get_cli_args():


# The modified Trainer class we will use. This is the exact same
# as a PPOTrainer, but with the additional default_resource_request
# as a PPO, but with the additional default_resource_request
# override, telling tune that it's ok (not mandatory) to place our
# n remote envs on a different node (each env using 1 CPU).
class PPOTrainerRemoteInference(PPOTrainer):
class PPOTrainerRemoteInference(PPO):
@classmethod
@override(Trainer)
def default_resource_request(cls, config):

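Subclasses migrate one-to-one as well: the example's `PPOTrainerRemoteInference` now derives from `PPO` instead of `PPOTrainer` and only overrides the class-level resource request. A stripped-down sketch of that pattern; the class name here is hypothetical and the body simply delegates to the parent rather than reproducing the example's actual placement logic:

    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.algorithms.ppo import PPO
    from ray.rllib.utils.annotations import override


    class RemoteInferencePPO(PPO):  # hypothetical subclass name
        @classmethod
        @override(Trainer)
        def default_resource_request(cls, config):
            # Placeholder body: the real example builds a custom request so that
            # remote envs may be placed on other nodes; here we just delegate.
            return super().default_resource_request(config)
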
@@ -16,7 +16,7 @@ import random

import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

@@ -102,7 +102,7 @@ if __name__ == "__main__":
print(f".. best checkpoint was: {best_checkpoint}")

# Create a new dummy Trainer to "fix" our checkpoint.
new_trainer = PPOTrainer(config=config)
new_trainer = PPO(config=config)
# Get untrained weights for all policies.
untrained_weights = new_trainer.get_weights()
# Restore all policies from checkpoint.

@@ -7,7 +7,7 @@ Run example: python sb2rllib_rllib_example.py
"""
import gym
import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo

# settings used for both stable baselines and rllib
env_name = "CartPole-v1"

@@ -30,7 +30,7 @@ checkpoint_path = analysis.get_best_checkpoint(trial=analysis.get_best_trial())
print(f"Trained model saved at {checkpoint_path}")

# load and restore model
agent = ppo.PPOTrainer(env=env_name)
agent = ppo.PPO(env=env_name)
agent.restore(checkpoint_path)
print(f"Agent loaded from saved model at {checkpoint_path}")

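The module-alias style of import migrates the same way: `import ray.rllib.agents.ppo as ppo` becomes `import ray.rllib.algorithms.ppo as ppo`, after which `ppo.PPO` replaces `ppo.PPOTrainer`. A short restore-and-evaluate sketch in the spirit of this example; the checkpoint path is a placeholder, and `compute_single_action` is the standard Trainer inference call assumed to be unchanged by this PR:

    import gym
    import ray
    import ray.rllib.algorithms.ppo as ppo  # was: import ray.rllib.agents.ppo as ppo

    ray.init()
    env_name = "CartPole-v1"
    agent = ppo.PPO(env=env_name)
    agent.restore("/tmp/ppo_checkpoint")  # placeholder path

    # Roll the restored policy in the environment for one step.
    env = gym.make(env_name)
    obs = env.reset()
    action = agent.compute_single_action(obs)
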
@@ -40,7 +40,7 @@ import re
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.self_play_with_open_spiel import ask_user_for_action
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv

@@ -340,7 +340,7 @@ if __name__ == "__main__":
# human on command line.
if args.num_episodes_human_play > 0:
num_episodes = 0
trainer = PPOTrainer(config=dict(config, **{"explore": False}))
trainer = PPO(config=dict(config, **{"explore": False}))
if args.from_checkpoint:
trainer.restore(args.from_checkpoint)
else:

@@ -28,7 +28,7 @@ import sys
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.policy.policy import PolicySpec

@@ -242,7 +242,7 @@ if __name__ == "__main__":
# human on command line.
if args.num_episodes_human_play > 0:
num_episodes = 0
trainer = PPOTrainer(config=dict(config, **{"explore": False}))
trainer = PPO(config=dict(config, **{"explore": False}))
if args.from_checkpoint:
trainer.restore(args.from_checkpoint)
else:

@@ -19,7 +19,7 @@ from pprint import pformat
import ray
from ray import tune

from ray.rllib.agents.ppo import ppo
from ray.rllib.algorithms.ppo import ppo
from ray.rllib.examples.simulators.sumo import marlenvironment
from ray.rllib.utils.test_utils import check_learning_achieved

@@ -78,7 +78,7 @@ if __name__ == "__main__":
tune.register_env("sumo_test_env", marlenvironment.env_creator)

# Algorithm.
policy_class = ppo.PPOStaticGraphTFPolicy
policy_class = ppo.PPOTF1Policy
config = ppo.DEFAULT_CONFIG
config["framework"] = "tf"
config["gamma"] = 0.99

@@ -2,7 +2,7 @@ import argparse
import numpy as np

import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.examples.models.trajectory_view_utilizing_models import (
FrameStackingCartPoleModel,

@@ -94,7 +94,7 @@ if __name__ == "__main__":
)

checkpoint_path = checkpoints[0][0]
trainer = PPOTrainer(config)
trainer = PPO(config)
trainer.restore(checkpoint_path)

# Inference loop.

@@ -15,9 +15,9 @@ from ray.rllib.agents.trainer import Trainer
from ray.rllib.algorithms.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.algorithms.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step

@@ -179,9 +179,7 @@ if __name__ == "__main__":
# policy configs, we have to explicitly set it in the multiagent config:
policies = {
"ppo_policy": (
PPOTorchPolicy
if args.torch or args.mixed_torch_tf
else PPOStaticGraphTFPolicy,
PPOTorchPolicy if args.torch or args.mixed_torch_tf else PPOTF1Policy,
None,
None,
ppo_config,

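Code that reaches into the sub-modules directly moves in lockstep: `ray.rllib.agents.ppo.ppo`, `...ppo_tf_policy`, and `...ppo_torch_policy` become `ray.rllib.algorithms.ppo.ppo`, `...ppo_tf_policy`, and `...ppo_torch_policy`. A minimal sketch of building a multiagent policies dict with the renamed deep imports, mirroring the workflow example above (the policy ID and the torch/tf switch are illustrative):

    from ray.rllib.algorithms.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
    from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
    from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

    use_torch = False
    ppo_config = dict(PPO_CONFIG)  # copy before mutating

    policies = {
        # (policy_cls, obs_space, action_space, config); None lets RLlib infer the spaces.
        "ppo_policy": (
            PPOTorchPolicy if use_torch else PPOTF1Policy,
            None,
            None,
            ppo_config,
        ),
    }
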
@@ -283,8 +283,8 @@ class ModelV2:
h5_file: The h5 file name to import weights from.

Example:
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(...) # doctest: +SKIP
>>> from ray.rllib.algorithms.ppo import PPO
>>> trainer = PPO(...) # doctest: +SKIP
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
>>> for _ in range(10): # doctest: +SKIP
>>> trainer.train() # doctest: +SKIP

@@ -3,7 +3,7 @@ import numpy as np
import unittest

import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.examples.models.modelv3 import RNNModel
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork

@@ -71,7 +71,7 @@ class TestModels(unittest.TestCase):
},
"num_workers": 0,
}
trainer = ppo.PPOTrainer(config=config)
trainer = ppo.PPO(config=config)
for _ in range(2):
results = trainer.train()
print(results)

@@ -4,7 +4,7 @@ import numpy as np
import unittest

import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.preprocessors import (
DictFlatteningPreprocessor,

@@ -67,7 +67,7 @@ class TestPreprocessors(unittest.TestCase):
num_iterations = 1
# Only supported for tf so far.
for _ in framework_iterator(config):
trainer = ppo.PPOTrainer(config=config)
trainer = ppo.PPO(config=config)
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)

@@ -5,7 +5,7 @@ import unittest
import ray
import ray.rllib.algorithms.dqn as dqn
import ray.rllib.algorithms.pg as pg
import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
import ray.rllib.algorithms.sac as sac
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check, framework_iterator

@@ -164,14 +164,14 @@ class TestComputeLogLikelihood(unittest.TestCase):
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
prev_a = np.array([0.0])
do_test_log_likelihood(ppo.PPOTrainer, config, prev_a, continuous=True)
do_test_log_likelihood(ppo.PPO, config, prev_a, continuous=True)

def test_ppo_discr(self):
"""Tests PPO's (discr. actions) compute_log_likelihoods method."""
config = ppo.DEFAULT_CONFIG.copy()
config["seed"] = 42
prev_a = np.array(0)
do_test_log_likelihood(ppo.PPOTrainer, config, prev_a)
do_test_log_likelihood(ppo.PPO, config, prev_a)

def test_sac_cont(self):
"""Tests SAC's (cont. actions) compute_log_likelihoods method."""

@@ -5,7 +5,7 @@ import queue
import unittest

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import PPOStaticGraphTFPolicy
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER

@@ -38,13 +38,13 @@ def iter_list(values):
def make_workers(n):
local = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_spec=PPOStaticGraphTFPolicy,
policy_spec=PPOTF1Policy,
rollout_fragment_length=100,
)
remotes = [
RolloutWorker.as_remote().remote(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_spec=PPOStaticGraphTFPolicy,
policy_spec=PPOTF1Policy,
rollout_fragment_length=100,
)
for _ in range(n)

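The renamed TF1 policy also surfaces wherever a policy class is passed around directly, e.g. when constructing `RolloutWorker`s by hand in these execution tests. A trimmed sketch of that constructor call with the new name, using only the arguments shown in the test above:

    import gym

    from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

    worker = RolloutWorker(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_spec=PPOTF1Policy,  # was: PPOStaticGraphTFPolicy
        rollout_fragment_length=100,
    )
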
@@ -3,7 +3,7 @@ import pickle
import unittest

import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
from ray.rllib.examples.models.rnn_spy_model import RNNSpyModel
from ray.rllib.models import ModelCatalog

@@ -176,7 +176,7 @@ class TestRNNSequencing(unittest.TestCase):
def test_simple_optimizer_sequencing(self):
ModelCatalog.register_custom_model("rnn", RNNSpyModel)
register_env("counter", lambda _: DebugCounterEnv())
ppo = PPOTrainer(
ppo = PPO(
env="counter",
config={
"num_workers": 0,

@@ -244,7 +244,7 @@ class TestRNNSequencing(unittest.TestCase):
def test_minibatch_sequencing(self):
ModelCatalog.register_custom_model("rnn", RNNSpyModel)
register_env("counter", lambda _: DebugCounterEnv())
ppo = PPOTrainer(
ppo = PPO(
env="counter",
config={
"shuffle_sequences": False,  # for deterministic testing

@@ -2,7 +2,7 @@
import os
import pytest

import ray.rllib.agents.ppo as ppo
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.utils.test_utils import framework_iterator


@@ -18,7 +18,7 @@ def test_dont_import_tf_error():
with pytest.raises(
ImportError, match="However, there was no installation found."
):
ppo.PPOTrainer(config, env="CartPole-v1")
ppo.PPO(config, env="CartPole-v1")


def test_dont_import_torch_error():

@@ -29,7 +29,7 @@ def test_dont_import_torch_error():
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
config = {"framework": "torch"}
with pytest.raises(ImportError, match="However, there was no installation found."):
ppo.PPOTrainer(config, env="CartPole-v1")
ppo.PPO(config, env="CartPole-v1")


if __name__ == "__main__":

@@ -5,7 +5,7 @@ import unittest
import pytest
import ray
from ray import tune
from ray.rllib.agents import ppo
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.util.client.ray_client_helpers import ray_start_client_server

@@ -29,7 +29,7 @@ class TestRayClient(unittest.TestCase):
"num_workers": 0,
"framework": "tf",
}
resources = ppo.PPOTrainer.default_resource_request(config)
resources = ppo.PPO.default_resource_request(config)
from ray.rllib.examples.custom_train_fn import my_train_fn

tune.run(my_train_fn, resources_per_trial=resources, config=config)

@@ -63,7 +63,7 @@ class TestRayClient(unittest.TestCase):
tune.run(
experiment,
config=config,
resources_per_trial=ppo.PPOTrainer.default_resource_request(config),
resources_per_trial=ppo.PPO.default_resource_request(config),
)

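In Tune-driven setups the same one-line change applies to the resource-request helper: `ppo.PPOTrainer.default_resource_request(config)` becomes `ppo.PPO.default_resource_request(config)`. A sketch of that pattern; the function trainable below is a stand-in for the `my_train_fn` the test imports, and the config contents are illustrative:

    from ray import tune
    import ray.rllib.algorithms.ppo as ppo

    config = {"env": "CartPole-v0", "num_workers": 0, "framework": "tf"}


    def my_train_fn(config):
        # Stand-in function trainable: build the renamed Trainer and run a couple of iterations.
        trainer = ppo.PPO(config=config)
        for _ in range(2):
            print(trainer.train())


    tune.run(
        my_train_fn,
        config=config,
        resources_per_trial=ppo.PPO.default_resource_request(config),
    )
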
@@ -23,10 +23,10 @@ Training example via RLlib CLI:
rllib train --run DQN --env CartPole-v0

Grid search example via RLlib CLI:
rllib train -f tuned_examples/cartpole-grid-search-example.yaml
rllib train -f tuned_examples/cartpole-ppo-grid-search-example.yaml

Grid search example via executable:
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
./train.py -f tuned_examples/cartpole-ppo-grid-search-example.yaml

Note that -f overrides all other trial-specific command-line options.
"""

Some files were not shown because too many files have changed in this diff.