[RLlib] Trainer to Algorithm renaming. (#25539)

Sven Mika 2022-06-11 15:10:39 +02:00 committed by GitHub
parent 0c527b4502
commit 130b7eeaba
240 changed files with 6667 additions and 6124 deletions

View file

@ -2,7 +2,7 @@
.. include:: /_includes/rllib/we_are_hiring.rst
.. TODO: We need trainers, environments, algorithms, policies, models here. Likely in that order.
.. TODO: We need algorithms, environments, policies, models here. Likely in that order.
Execution plans are not a "core" concept for users. Sample batches should probably also be left out.
.. _rllib-core-concepts:
@ -10,31 +10,59 @@
Key Concepts
============
On this page, we'll cover the key concepts to help you understand how RLlib works and how to use it.
In RLlib you use `trainers` to train `algorithms`.
These algorithms use `policies` to select actions for your agents.
Given a policy, `evaluation` of a policy produces `sample batches` of experiences.
You can also customize the `execution plans` of your RL experiments.
On this page, we'll cover the key concepts to help you understand how RLlib works and
how to use it. In RLlib you use ``algorithms`` to learn in problem environments.
These algorithms use ``policies`` to select actions for your agents.
Given a policy, ``evaluation`` of a policy produces ``sample batches`` of experiences.
You can also customize the ``training_iteration``\s of your RL experiments.
Trainers
--------
Algorithms
----------
Trainers bring all RLlib components together, making algorithms accessible via RLlib's Python API and its command line interface (CLI).
They manage algorithm configuration, setup of the rollout workers and optimizer, and collection of training metrics.
Trainers also implement the :ref:`Tune Trainable API <tune-60-seconds>` for easy experiment management.
Algorithms bring all RLlib components together, making learning of different tasks
accessible via RLlib's Python API and its command line interface (CLI).
Each ``Algorithm`` class is managed by its respective ``AlgorithmConfig``; for example, to
configure a ``PPO`` instance, you use the ``PPOConfig`` class.
An ``Algorithm`` sets up its rollout workers and optimizers, and collects training metrics.
``Algorithms`` also implement the :ref:`Tune Trainable API <tune-60-seconds>` for
easy experiment management.
You have three ways to interact with a trainer. You can use the basic Python API or the command line to train it, or you
You have three ways to interact with an algorithm. You can use the basic Python API or the command line to train it, or you
can use Ray Tune to tune hyperparameters of your reinforcement learning algorithm.
The following example shows three equivalent ways of interacting with the ``PPO`` Trainer,
The following example shows three equivalent ways of interacting with ``PPO``,
which implements the proximal policy optimization algorithm in RLlib.
.. tabbed:: Basic RLlib Trainer
.. tabbed:: Basic RLlib Algorithm
.. code-block:: python
trainer = PPO(env="CartPole-v0", config={"train_batch_size": 4000})
# Configure.
from ray.rllib.algorithms import PPOConfig
config = PPOConfig().environment("CartPole-v0").training(train_batch_size=4000)
# Build.
algo = config.build()
# Train.
while True:
print(trainer.train())
print(algo.train())
.. tabbed:: RLlib Algorithms and Tune
.. code-block:: python
from ray import tune
# Configure.
from ray.rllib.algorithms import PPOConfig
config = PPOConfig().environment("CartPole-v0").training(train_batch_size=4000)
# Train via Ray Tune.
# Note that Ray Tune does not yet support AlgorithmConfig objects, hence
# we need to convert back to old-style config dicts.
tune.run("PPO", config=config.to_dict())
.. tabbed:: RLlib Command Line
@ -42,17 +70,9 @@ which implements the proximal policy optimization algorithm in RLlib.
rllib train --run=PPO --env=CartPole-v0 --config='{"train_batch_size": 4000}'
.. tabbed:: RLlib Tune Trainer
.. code-block:: python
from ray import tune
tune.run(PPO, config={"env": "CartPole-v0", "train_batch_size": 4000})
RLlib `Trainer classes <rllib-concepts.html#trainers>`__ coordinate the distributed workflow of running rollouts and optimizing policies.
Trainer classes leverage parallel iterators to implement the desired computation pattern.
RLlib `Algorithm classes <rllib-concepts.html#algorithms>`__ coordinate the distributed workflow of running rollouts and optimizing policies.
Algorithm classes leverage parallel iterators to implement the desired computation pattern.
The following figure shows *synchronous sampling*, the simplest of `these patterns <rllib-algorithms.html>`__:
.. figure:: images/a2c-arch.svg
@ -180,13 +200,13 @@ of a sequence of repeating steps, or *dataflow*, of:
2. ``ConcatBatches``: The experiences are concatenated into one batch for training.
3. ``TrainOneStep``: Take a gradient step with respect to the policy loss, and update the worker weights.
In code, this dataflow can be expressed as the following execution plan, which is a static method that can be overridden in your custom Trainer sub-classes to define new algorithms.
In code, this dataflow can be expressed as the following execution plan, which is a static method that can be overridden in your custom Algorithm sub-classes to define new algorithms.
It takes in a ``WorkerSet`` and config, and returns an iterator over training results:
.. code-block:: python
@staticmethod
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
# type: LocalIterator[SampleBatchType]
rollouts = ParallelRollouts(workers, mode="bulk_sync")
@ -204,7 +224,7 @@ As you can see, each step returns an *iterator* over objects (if you're unfamili
The reason it is a ``LocalIterator`` is that, though it is based on a parallel computation, the iterator has been turned into one that can be consumed locally in sequence by the program.
A couple other points to note:
- The reason the plan returns an iterator over training results, is that ``trainer.train()`` is pulling results from this iterator to return as the result of the train call.
- The reason the plan returns an iterator over training results, is that ``algorithm.train()`` is pulling results from this iterator to return as the result of the train call.
- The rollout workers have been already created ahead of time in the ``WorkerSet``, so the execution plan function is only defining a sequence of operations over the results of the rollouts.
These iterators represent the infinite stream of data items that can be produced from the dataflow.
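The snippet below sketches how the three steps listed above could be chained inside such an ``execution_plan`` for the *synchronous sampling* pattern. It is a minimal illustration only, not RLlib's actual PPO/A2C code; the operators (``ParallelRollouts``, ``ConcatBatches``, ``TrainOneStep``, ``StandardMetricsReporting``) are the classic RLlib execution operators, and exact import paths and keyword arguments may differ slightly between versions.

.. code-block:: python

    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.execution.metric_ops import StandardMetricsReporting
    from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep


    class MySyncSamplingAlgo(Algorithm):
        @staticmethod
        def execution_plan(workers, config):
            # 1. ParallelRollouts: iterator over SampleBatches from all workers.
            rollouts = ParallelRollouts(workers, mode="bulk_sync")

            # 2. ConcatBatches + 3. TrainOneStep: build train batches of the
            # configured size, take one gradient step per batch, and broadcast
            # the updated weights back to the workers.
            train_op = rollouts.combine(
                ConcatBatches(min_batch_size=config["train_batch_size"])
            ).for_each(TrainOneStep(workers))

            # Return an iterator over training result dicts for train() to pull from.
            return StandardMetricsReporting(train_op, workers, config)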
@ -236,7 +256,8 @@ You'll see output like this on the console:
(pid=6555) I saw <class 'ray.rllib.policy.sample_batch.SampleBatch'>
(pid=6555) I saw <class 'ray.rllib.policy.sample_batch.SampleBatch'>
It is important to understand that the iterators of an execution plan are evaluated *lazily*. This means that no computation happens until the `trainer <#trainers>`__ tries to read the next item from the iterator (i.e., get the next training result for a ``Trainer.train()`` call).
It is important to understand that the iterators of an execution plan are evaluated *lazily*. This means that no computation happens until the `algorithm <#algorithms>`__ tries to read the next item from the iterator
(i.e., get the next training result for an ``Algorithm.train()`` call).
Execution Plan Concepts
~~~~~~~~~~~~~~~~~~~~~~~
@ -282,7 +303,7 @@ Examples
.. code-block:: python
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
# type: LocalIterator[(ModelGradients, int)]
grads = AsyncGradients(workers)
@ -300,7 +321,7 @@ Examples
.. code-block:: python
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
# Construct a replay buffer.
replay_buffer = LocalReplayBuffer(...)

View file

@ -24,8 +24,8 @@ prep.transform(env.reset()).shape
import numpy as np
from ray.rllib.algorithms.ppo import PPO
trainer = PPO(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
policy = trainer.get_policy()
algo = PPO(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
policy = algo.get_policy()
# <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>
# Run a forward pass to get model output logits. Note that complex observations
@ -82,8 +82,8 @@ _____________________________________________________________________
import numpy as np
from ray.rllib.algorithms.dqn import DQN
trainer = DQN(env="CartPole-v0", config={"framework": "tf2"})
model = trainer.get_policy().model
algo = DQN(env="CartPole-v0", config={"framework": "tf2"})
model = algo.get_policy().model
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>
# List of all model variables

View file

@ -67,7 +67,7 @@ To be able to run our Atari examples, you should also install:
After these quick pip installs, you can start coding against RLlib.
Here is an example of running a PPO Trainer on the "`Taxi domain <https://www.gymlibrary.ml/environments/toy_text/taxi/>`_"
Here is an example of running PPO on the "`Taxi domain <https://www.gymlibrary.ml/environments/toy_text/taxi/>`_"
for a few training iterations, followed by a single evaluation loop
(with rendering enabled):

View file

@ -0,0 +1,56 @@
.. _algorithm-reference-docs:
Algorithm API
=============
The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` class is the highest-level API in RLlib.
It allows you to train and evaluate policies, save an experiment's progress and restore from
a prior saved experiment when continuing an RL run.
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` is a sub-class
of :py:class:`~ray.tune.Trainable`
and thus fully supports distributed hyperparameter tuning for RL.
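As a quick illustration of this, the sketch below builds a ``PPO`` instance from its config, trains and checkpoints it, then restores it again. It is a minimal example under the assumption of a standard Gym environment (``CartPole-v0``) and only uses the generic ``train()``/``save()``/``restore()`` calls inherited from ``Trainable``.

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    # Build a PPO Algorithm instance from its config object.
    algo = PPOConfig().environment("CartPole-v0").build()

    # Train for a few iterations and write a checkpoint.
    for _ in range(3):
        print(algo.train())
    checkpoint_path = algo.save()

    # Restore the saved state (e.g. in a later run) and continue training.
    algo.restore(checkpoint_path)
    print(algo.train())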
.. https://docs.google.com/drawings/d/1J0nfBMZ8cBff34e-nSPJZMM1jKOuUL11zFJm6CmWtJU/edit
.. figure:: ../images/trainer_class_overview.svg
:align: left
**A typical RLlib Algorithm object:** The components sitting inside an Algorithm are
normally N :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`s
(each consisting of one local :py:class:`~ray.rllib.evaluation.RolloutWorker`
and zero or more \@ray.remote
:py:class:`~ray.rllib.evaluation.RolloutWorker`s),
a set of :py:class:`~ray.rllib.policy.Policy`(ies)
and their NN models per worker, and an (already vectorized)
RLlib :py:class:`~ray.rllib.env.base_env.BaseEnv` per worker.
Building Custom Algorithm Classes
---------------------------------
.. warning::
As of Ray >= 1.9, it is no longer recommended to use the `build_trainer()` utility
function for creating custom Algorithm sub-classes.
Instead, follow the simple guidelines here for directly sub-classing from
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm`.
In order to create a custom Algorithm, sub-class the
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` class
and override one or more of its methods. Those are in particular:
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.get_default_config`
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.validate_config`
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.get_default_policy_class`
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.setup`
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.training_iteration`
`See here for an example of how to override Algorithm <https://github.com/ray-project/ray/blob/master/rllib/algorithms/pg/pg.py>`_.
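As a rough sketch of this pattern (not the PG implementation linked above), a custom sub-class could look like the following. ``MyTFPolicy`` stands in for any Policy class you provide, and the ``training_iteration`` body is deliberately minimal.

.. code-block:: python

    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.utils.annotations import override


    class MyAlgo(Algorithm):
        @override(Algorithm)
        def get_default_policy_class(self, config):
            # Use your own Policy class here (`MyTFPolicy` is a placeholder).
            return MyTFPolicy

        @override(Algorithm)
        def training_iteration(self):
            # Collect a batch on the local worker and learn on it, then return
            # the resulting (learner) metrics dict as this iteration's results.
            train_batch = self.workers.local_worker().sample()
            return self.workers.local_worker().learn_on_batch(train_batch)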
Algorithm base class (ray.rllib.algorithms.algorithm.Algorithm)
---------------------------------------------------------------
.. autoclass:: ray.rllib.algorithms.algorithm.Algorithm
:special-members: __init__
:members:
.. static-members: get_default_config, execution_plan

View file

@ -6,7 +6,7 @@ Evaluation and Environment Rollout
Data ingest via either environment rollouts or other data-generating methods
(e.g. reading from offline files) is done in RLlib by :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s,
which sit inside a :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`
(together with other parallel ``RolloutWorkers``) in the RLlib :py:class:`~ray.rllib.agents.trainer.Trainer`
(together with other parallel ``RolloutWorkers``) in the RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
(under the ``self.workers`` property):
@ -15,7 +15,7 @@ which sit inside a :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`
:width: 600
:align: left
**A typical RLlib WorkerSet setup inside an RLlib Trainer:** Each :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` contains
**A typical RLlib WorkerSet setup inside an RLlib Algorithm:** Each :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` contains
exactly one local :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` object and n ray remote
:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` (ray actors).
The workers contain a policy map (with one or more policies), and - in case a simulator
@ -23,7 +23,7 @@ which sit inside a :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`
(containing m sub-environments) and a :py:class:`~ray.rllib.evaluation.sampler.SamplerInput` (either synchronous or asynchronous) which controls
the environment data collection loop.
In the online (environment is available) as well as the offline case (no environment),
:py:class:`~ray.rllib.agents.trainer.Trainer` uses the :py:meth:`~ray.rllib.evaluation.rollout_worker.RolloutWorker.sample` method to
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` uses the :py:meth:`~ray.rllib.evaluation.rollout_worker.RolloutWorker.sample` method to
get :py:class:`~ray.rllib.policy.sample_batch.SampleBatch` objects for training.

View file

@ -9,11 +9,11 @@ The Policies are used to calculate actions for the next environment steps, losse
model updates, and other functionalities covered by RLlib's :py:class:`~ray.rllib.policy.policy.Policy` API.
A mapping function is used by episode objects to map AgentIDs produced by the environment to one of the PolicyIDs.
It is possible to add and remove policies to/from the :py:class:`~ray.rllib.agents.trainer.Trainer`'s workers at any given time
It is possible to add and remove policies to/from the :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`'s workers at any given time
(even within an ongoing episode) as well as to change the policy mapping function.
See the Trainer's methods: :py:meth:`~ray.rllib.agents.trainer.Trainer.add_policy`,
:py:meth:`~ray.rllib.agents.trainer.Trainer.remove_policy`, and
:py:meth:`~ray.rllib.agents.trainer.Trainer.change_policy_mapping_fn` for more details.
See the Algorithm's methods: :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.add_policy`,
:py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.remove_policy`, and
:py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.change_policy_mapping_fn` for more details.
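For illustration, adding and later removing a policy on a running Algorithm could look roughly like this; ``algo`` is an already built multi-agent Algorithm, ``MyPolicyClass`` is a Policy sub-class you supply, and keyword names may vary slightly across versions.

.. code-block:: python

    # Add a new policy (and re-route one agent to it) while training is ongoing.
    algo.add_policy(
        policy_id="new_policy",
        policy_cls=MyPolicyClass,
        policy_mapping_fn=(
            lambda agent_id, episode, worker, **kw: "new_policy"
            if agent_id == "agent_1"
            else "main_policy"
        ),
    )

    # ... train some more, then drop the policy from all workers again.
    algo.remove_policy("new_policy")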
.. autoclass:: ray.rllib.policy.policy_map.PolicyMap
:members:

View file

@ -4,13 +4,13 @@ RolloutWorker
=============
RolloutWorkers are used as ``@ray.remote`` actors to collect and return samples
from environments or offline files in parallel. An RLlib :py:class:`~ray.rllib.agents.trainer.Trainer` usually has
from environments or offline files in parallel. An RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` usually has
``num_workers`` :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s plus a
single "local" :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` (not ``@ray.remote``) in
its :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` under ``self.workers``.
Depending on its evaluation config settings, an additional :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` with
:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s for evaluation may be present in the :py:class:`~ray.rllib.agents.trainer.Trainer`
:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s for evaluation may be present in the :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
under ``self.evaluation_workers``.
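For example, a minimal sketch of accessing these workers on a built Algorithm (using the plain config-dict style shown elsewhere in these docs) could look like:

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPO

    # Two remote RolloutWorkers plus the local one, all inside `algo.workers`.
    algo = PPO(env="CartPole-v0", config={"num_workers": 2})

    # The local (non-@ray.remote) RolloutWorker:
    local_worker = algo.workers.local_worker()
    print(local_worker.sample())  # collect one SampleBatch locally

    # Handles to the remote RolloutWorker actors:
    print(algo.workers.remote_workers())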
.. autoclass:: ray.rllib.evaluation.rollout_worker.RolloutWorker

View file

@ -20,7 +20,7 @@ If you think there is anything missing, please open an issue on `Github`_.
:maxdepth: 2
env.rst
trainer.rst
algorithm.rst
policy.rst
models.rst
evaluation.rst

View file

@ -1,57 +0,0 @@
.. _trainer-reference-docs:
Trainer API
===========
The :py:class:`~ray.rllib.agents.trainer.Trainer` class is the highest-level API in RLlib.
It allows you to train and evaluate policies, save an experiment's progress and restore from
a prior saved experiment when continuing an RL run.
:py:class:`~ray.rllib.agents.trainer.Trainer` is a sub-class
of :py:class:`~ray.tune.Trainable`
and thus fully supports distributed hyperparameter tuning for RL.
.. https://docs.google.com/drawings/d/1J0nfBMZ8cBff34e-nSPJZMM1jKOuUL11zFJm6CmWtJU/edit
.. figure:: ../images/trainer_class_overview.svg
:align: left
**A typical RLlib Trainer object:** The components sitting inside a Trainer are
normally N :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`s
(each consisting of one local :py:class:`~ray.rllib.evaluation.RolloutWorker`
and zero or more \@ray.remote
:py:class:`~ray.rllib.evaluation.RolloutWorker`s),
a set of :py:class:`~ray.rllib.policy.Policy`(ies)
and their NN models per worker, and a (already vectorized)
RLlib :py:class:`~ray.rllib.env.base_env.BaseEnv` per worker.
Building Custom Trainer Classes
-------------------------------
.. warning::
As of Ray >= 1.9, it is no longer recommended to use the `build_trainer()` utility
function for creating custom Trainer sub-classes.
Instead, follow the simple guidelines here for directly sub-classing from
:py:class:`~ray.rllib.agents.trainer.Trainer`.
In order to create a custom Trainer, sub-class the
:py:class:`~ray.rllib.agents.trainer.Trainer` class
and override one or more of its methods. Those are in particular:
* :py:meth:`~ray.rllib.agents.trainer.Trainer.get_default_config`
* :py:meth:`~ray.rllib.agents.trainer.Trainer.validate_config`
* :py:meth:`~ray.rllib.agents.trainer.Trainer.get_default_policy_class`
* :py:meth:`~ray.rllib.agents.trainer.Trainer.setup`
* :py:meth:`~ray.rllib.agents.trainer.Trainer.step_attempt`
* :py:meth:`~ray.rllib.agents.trainer.Trainer.execution_plan`
`See here for an example on how to override Trainer <https://github.com/ray-project/ray/blob/master/rllib/algorithms/pg/pg.py>`_.
Trainer base class (ray.rllib.agents.trainer.Trainer)
-----------------------------------------------------
.. autoclass:: ray.rllib.agents.trainer.Trainer
:special-members: __init__
:members:
.. static-members: get_default_config, execution_plan

View file

@ -207,7 +207,7 @@ Decentralized Distributed Proximal Policy Optimization (DD-PPO)
|pytorch|
`[paper] <https://arxiv.org/abs/1911.00357>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py>`__
Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centralized in the trainer process. Instead, gradients are computed remotely on each rollout worker and all-reduced at each mini-batch using `torch distributed <https://pytorch.org/docs/stable/distributed.html>`__. This allows each worker's GPU to be used both for sampling and for training.
Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centralized in the algorithm process. Instead, gradients are computed remotely on each rollout worker and all-reduced at each mini-batch using `torch distributed <https://pytorch.org/docs/stable/distributed.html>`__. This allows each worker's GPU to be used both for sampling and for training.
.. tip::
@ -895,7 +895,7 @@ Tuned examples:
**Activating Curiosity**
The curiosity plugin can be easily activated by specifying it as the Exploration class to-be-used
in the main Trainer config. Most of its parameters usually do not have to be specified
in the main Algorithm config. Most of its parameters usually do not have to be specified
as the module uses the values from the paper by default. For example:
.. code-block:: python
@ -933,7 +933,7 @@ In such environments, agents have to navigate (and change the underlying state o
For example, the task could be to find a key in some room, pick it up, find a matching door (matching the color of the key), and eventually unlock this door with the key to reach a goal state,
all the while not seeing any rewards.
Such problems are impossible to solve with standard RL exploration methods like epsilon-greedy or stochastic sampling.
The Curiosity module - when configured as the Exploration class to use via the Trainer's config (see above on how to do this) - automatically adds three simple models to the Policy's ``self.model``:
The Curiosity module - when configured as the Exploration class to use via the Algorithm's config (see above on how to do this) - automatically adds three simple models to the Policy's ``self.model``:
a) a latent space learning ("feature") model, taking an environment observation and outputting a latent vector, which represents this observation and
b) a "forward" model, predicting the next latent vector, given the current observation vector and an action to take next.
c) a so-called "inverse" net, only used to train the "feature" net. The inverse net tries to predict the action taken between two latent vectors (obs and next obs).
@ -963,7 +963,7 @@ Examples:
**Activating RE3**
The RE3 plugin can be easily activated by specifying it as the Exploration class to-be-used
in the main Trainer config and inheriting the `RE3UpdateCallbacks` as shown in this `example <https://github.com/ray-project/ray/blob/c9c3f0745a9291a4de0872bdfa69e4ffdfac3657/rllib/utils/exploration/tests/test_random_encoder.py#L35>`__. Most of its parameters usually do not have to be specified as the module uses the values from the paper by default. For example:
in the main Algorithm config and inheriting the `RE3UpdateCallbacks` as shown in this `example <https://github.com/ray-project/ray/blob/c9c3f0745a9291a4de0872bdfa69e4ffdfac3657/rllib/utils/exploration/tests/test_random_encoder.py#L35>`__. Most of its parameters usually do not have to be specified as the module uses the values from the paper by default. For example:
.. code-block:: python

View file

@ -146,20 +146,20 @@ In the above snippet, ``actions`` is a Tensor placeholder of shape ``[batch_size
name="MyTFPolicy",
loss_fn=policy_gradient_loss)
We can create a `Trainer <#trainers>`__ and try running this policy on a toy env with two parallel rollout workers:
We can create an `Algorithm <#algorithms>`__ and try running this policy on a toy env with two parallel rollout workers:
.. code-block:: python
import ray
from ray import tune
from ray.rllib.agents.trainer import Trainer
from ray.rllib.algorithms.algorithm import Algorithm
class MyTrainer(Trainer):
class MyAlgo(Algorithm):
def get_default_policy_class(self, config):
return MyTFPolicy
ray.init()
tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})
tune.run(MyAlgo, config={"env": "CartPole-v0", "num_workers": 2})
If you run the above snippet `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_tf_policy.py>`__, you'll probably notice that CartPole doesn't learn so well:
@ -202,7 +202,7 @@ Let's modify our policy loss to include rewards summed over time. To enable this
loss_fn=policy_gradient_loss,
postprocess_fn=postprocess_advantages)
The ``postprocess_advantages()`` function above uses calls RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the trainer with this improved policy, you'll find that it quickly achieves the max reward of 200.
The ``postprocess_advantages()`` function above calls RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the algorithm with this improved policy, you'll find that it quickly achieves the max reward of 200.
You might be wondering how RLlib makes the advantages placeholder automatically available as ``train_batch[Postprocessing.ADVANTAGES]``. When building your policy, RLlib will create a "dummy" trajectory batch where all observations, actions, rewards, etc. are zeros. It then calls your ``postprocess_fn``, and generates TF placeholders based on the numpy shapes of the postprocessed batch. RLlib tracks which placeholders that ``loss_fn`` and ``stats_fn`` access, and then feeds the corresponding sample data into those placeholders during loss optimization. You can also access these placeholders via ``policy.get_placeholder(<name>)`` after loss initialization.
@ -210,38 +210,37 @@ You might be wondering how RLlib makes the advantages placeholder automatically
In the above section you saw how to compose a simple policy gradient algorithm with RLlib.
In this example, we'll dive into how PPO is defined within RLlib and how you can modify it.
First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__:
First, check out the `PPO definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__:
.. code-block:: python
class PPO(Trainer):
class PPO(Algorithm):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return DEFAULT_CONFIG
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
...
@override(Trainer)
@override(Algorithm)
def get_default_policy_class(self, config):
return PPOTFPolicy
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
@override(Algorithm)
def training_iteration(self):
...
Besides some boilerplate for defining the PPO configuration and some warnings, the most important method to take note of is the ``execution_plan``.
The trainer's `execution plan <#execution-plans>`__ defines the distributed training workflow.
Depending on the ``simple_optimizer`` trainer config,
The algorithm's `execution plan <#execution-plans>`__ defines the distributed training workflow.
Depending on the ``simple_optimizer`` config key,
PPO can switch between a simple synchronous plan, or a multi-GPU plan that implements minibatch SGD (the default):
.. code-block:: python
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# Collect large batches of relevant experiences & standardize.
@ -316,7 +315,7 @@ Now let's look at each PPO policy definition:
before_loss_init=setup_mixins,
mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin])
``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric which is used by the trainer to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics:
``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric which is used by the algorithm to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics:
.. code-block:: python
@ -464,7 +463,8 @@ Here's an example of using eager ops embedded
Building Policies in PyTorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_policy.py>`__:
Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining an algorithm given a Torch policy is exactly the same).
Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_policy.py>`__:
.. code-block:: python

View file

@ -55,13 +55,13 @@ Contributing Algorithms
These are the guidelines for merging new algorithms into RLlib:
* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__):
- must subclass Trainer and implement the ``step()`` method
- must subclass Algorithm and implement the ``step()`` method
- must include a lightweight test (`example <https://github.com/ray-project/ray/blob/6bb110393008c9800177490688c6ed38b2da52a9/test/jenkins_tests/run_multi_node_tests.sh#L45>`__) to ensure the algorithm runs
- should include tuned hyperparameter examples and documentation
- should offer functionality not present in existing algorithms
* Fully integrated algorithms (`rllib/agents <https://github.com/ray-project/ray/tree/master/rllib/agents>`__) have the following additional requirements:
- must fully implement the Trainer API
- must fully implement the Algorithm API
- must offer substantial new functionality not possible to add to other algorithms
- should support custom models and preprocessors
- should use RLlib abstractions and support distributed execution
@ -70,14 +70,14 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a
How to add an algorithm to ``contrib``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Trainer <https://github.com/ray-project/ray/commits/master/rllib/agents/trainer.py>`__ and implement the ``_init`` and ``step`` methods:
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Algorithm <https://github.com/ray-project/ray/commits/master/rllib/algorithms/algorithm.py>`__ and implement the ``_init`` and ``step`` methods:
.. literalinclude:: ../../../rllib/contrib/random_agent/random_agent.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
Second, register the trainer with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/rllib/contrib/registry.py>`__.
Second, register the algorithm with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/rllib/contrib/registry.py>`__.
.. code-block:: python
@ -110,7 +110,7 @@ Finding Memory Leaks In Workers
Keeping the memory usage of long running workers stable can be challenging. The ``MemoryTrackingCallbacks`` class can be used to track memory usage of workers.
.. autoclass:: ray.rllib.agents.callbacks.MemoryTrackingCallbacks
.. autoclass:: ray.rllib.algorithms.callbacks.MemoryTrackingCallbacks
The objects with the top 20 memory usage in the workers will be added as custom metrics. These can then be monitored using tensorboard or other metrics integrations like Weights and Biases:

View file

@ -16,12 +16,13 @@ RLlib works with several different types of environments, including `OpenAI Gym
Configuring Environments
------------------------
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://www.gymlibrary.ml/>`__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://www.gymlibrary.ml/>`__.
Custom env classes passed directly to the algorithm must take a single ``env_config`` parameter in their constructor:
.. code-block:: python
import gym, ray
from ray.rllib.agents import ppo
from ray.rllib.algorithms import ppo
class MyEnv(gym.Env):
def __init__(self, env_config):
@ -33,12 +34,12 @@ You can pass either a string name or a Python class to specify an environment. B
return <obs>, <reward: float>, <done: bool>, <info: dict>
ray.init()
trainer = ppo.PPO(env=MyEnv, config={
algo = ppo.PPO(env=MyEnv, config={
"env_config": {}, # config to pass to env class
})
while True:
print(trainer.train())
print(algo.train())
You can also register a custom env creator function with a string name. This function must take a single ``env_config`` (dict) parameter and return an env instance:
@ -50,7 +51,7 @@ You can also register a custom env creator function with a string name. This fun
return MyEnv(...) # return an env instance
register_env("my_env", env_creator)
trainer = ppo.PPO(env="my_env")
algo = ppo.PPO(env="my_env")
For a full runnable code example using the custom environment API, see `custom_env.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__.
@ -58,7 +59,10 @@ For a full runnable code example using the custom environment API, see `custom_e
The gym registry is not compatible with Ray. Instead, always use the registration flows documented above to ensure Ray workers can access the environment.
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object.
This is a dict containing options passed in through your algorithm.
You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``).
This can be useful if you want to train over an ensemble of different environments, for example:
.. code-block:: python
@ -225,12 +229,12 @@ If all the agents will be using the same algorithm class to train, then you can
.. code-block:: python
trainer = pg.PGAgent(env="my_multiagent_env", config={
algo = pg.PGAgent(env="my_multiagent_env", config={
"multiagent": {
"policies": {
# Use the PolicySpec namedtuple to specify an individual policy:
"car1": PolicySpec(
policy_class=None, # infer automatically from Trainer
policy_class=None, # infer automatically from Algorithm
observation_space=None, # infer automatically from env
action_space=None, # infer automatically from env
config={"gamma": 0.85}, # use main config plus <- this override here
@ -238,7 +242,7 @@ If all the agents will be using the same algorithm class to train, then you can
# Deprecated way: Tuple specifying class, obs-/action-spaces,
# config-overrides for each policy as a tuple.
# If class is None -> Uses Trainer's default policy class.
# If class is None -> Uses Algorithm's default policy class.
"car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}),
# New way: Use PolicySpec() with keywords: `policy_class`,
@ -257,7 +261,7 @@ If all the agents will be using the same algorithm class to train, then you can
})
while True:
print(trainer.train())
print(algo.train())
To exclude some policies in your ``multiagent.policies`` dictionary, you can use the ``multiagent.policies_to_train`` setting.
For example, you may want to have one or more random (non learning) policies interact with your learning ones:
@ -275,7 +279,7 @@ For example, you may want to have one or more random (non learning) policies int
# (start player) and sometimes player2 (player to move 2nd).
return "learning_policy" if episode.episode_id % 2 == agent_idx else "random_policy"
trainer = pg.PGAgent(env="two_player_game", config={
algo = pg.PGAgent(env="two_player_game", config={
"multiagent": {
"policies": {
"learning_policy": PolicySpec(), # <- use default class & infer obs-/act-spaces from env.
@ -505,11 +509,11 @@ External Application Clients
For applications that are running entirely outside the Ray cluster (i.e., cannot be packaged into a Python environment of any form), RLlib provides the ``PolicyServerInput`` application connector, which can be connected to over the network using ``PolicyClient`` instances.
You can configure any Trainer to launch a policy server with the following config:
You can configure any Algorithm to launch a policy server with the following config:
.. code-block:: python
trainer_config = {
config = {
# An environment class is still required, but it doesn't need to be runnable.
# You only need to define its action and observation space attributes.
# See examples/serving/unity3d_server.py for an example using a RandomMultiAgentEnv stub.
@ -518,13 +522,16 @@ You can configure any Trainer to launch a policy server with the following confi
"input": (
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
),
# Use the existing trainer process to run the server.
# Use the existing algorithm process to run the server.
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": {},
}
Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server.
Clients can then connect in either *local* or *remote* inference mode.
In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time.
This allows actions to be computed by the client without requiring a network round trip each time.
In remote inference mode, each computed action requires a network call to the server.
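A rough sketch of the client side is shown below; ``my_external_app`` is a placeholder for your own application loop, and the address and port are assumptions matching the server config above.

.. code-block:: python

    from ray.rllib.env.policy_client import PolicyClient

    # Connect to the policy server in local inference mode.
    client = PolicyClient("http://localhost:9900", inference_mode="local")

    episode_id = client.start_episode(training_enabled=True)
    obs = my_external_app.reset()  # placeholder for your application

    while True:
        action = client.get_action(episode_id, obs)
        obs, reward, done, info = my_external_app.step(action)
        client.log_returns(episode_id, reward)
        if done:
            client.end_episode(episode_id, obs)
            break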
Example:

View file

@ -38,7 +38,7 @@ Environments and Adapters
- `Registering a custom env and model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__:
Example of defining and registering a gym env and model for use with RLlib.
- `Local Unity3D multi-agent environment example <https://github.com/ray-project/ray/tree/master/rllib/examples/unity3d_env_local.py>`__:
Example of how to setup an RLlib Trainer against a locally running Unity3D editor instance to
Example of how to set up an RLlib Algorithm against a locally running Unity3D editor instance to
learn any Unity3D game (including support for multi-agent).
Use this example to try things out and watch the game and the learning progress live in the editor.
Providing a compiled game, this example could also run in distributed fashion with `num_workers > 0`.
@ -50,7 +50,7 @@ Environments and Adapters
- `DMLab Watermaze example <https://github.com/ray-project/ray/blob/master/rllib/examples/dmlab_watermaze.py>`__:
Example for how to use a DMLab environment (Watermaze).
- `RecSym environment example (for recommender systems) using the SlateQ algorithm <https://github.com/ray-project/ray/blob/master/rllib/examples/recommender_system_with_recsim_and_slateq.py>`__:
Script showing how to train a SlateQTrainer on a RecSym environment.
Script showing how to train SlateQ on a RecSym environment.
- `SUMO (Simulation of Urban MObility) environment example <https://github.com/ray-project/ray/blob/master/rllib/examples/sumo_env_local.py>`__:
Example demonstrating how to use the SUMO simulator in connection with RLlib.
- `VizDoom example script using RLlib's auto-attention wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/vizdoom_with_attention_net.py>`__:
@ -107,7 +107,7 @@ Training Workflows
- `Using rollout workers directly for control over the whole training workflow <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__:
Example of how to use RLlib's lower-level building blocks to implement a fully customized training workflow.
- `Custom execution plan function handling two different Policies (DQN and PPO) at the same time <https://github.com/ray-project/ray/blob/master/rllib/examples/two_trainer_workflow.py>`__:
Example of how to use the exec. plan of a Trainer to trin two different policies in parallel (also using multi-agent API).
Example of how to use the exec. plan of an Algorithm to train two different policies in parallel (also using multi-agent API).
- `Custom tune experiment <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_experiment.py>`__:
How to run a custom Ray Tune experiment with RLlib with custom training- and evaluation phases.
@ -164,8 +164,8 @@ Multi-Agent and Hierarchical
Example of running a custom hand-coded policy alongside trainable policies.
- `Weight sharing between policies <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_cartpole.py>`__:
Example of how to define weight-sharing layers between two different policies.
- `Multiple trainers <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_two_trainers.py>`__:
Example of alternating training between two DQN and PPO trainers.
- `Multiple algorithms <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_two_trainers.py>`__:
Example of alternating training between DQN and PPO.
- `Hierarchical training <https://github.com/ray-project/ray/blob/master/rllib/examples/hierarchical_training.py>`__:
Example of hierarchical training using the multi-agent API.
- `Iterated Prisoner's Dilemma environment example <https://github.com/ray-project/ray/blob/master/rllib/examples/iterated_prisoners_dilemma_env.py>`__:

View file

@ -43,7 +43,7 @@ observation space. Thereby, the following simple rules apply:
observations: ``dict_or_tuple_obs = restore_original_dimensions(input_dict["obs"], self.obs_space, "tf|torch")``
For Atari observation spaces, RLlib defaults to using the `DeepMind preprocessors <https://github.com/ray-project/ray/blob/master/rllib/env/wrappers/atari_wrappers.py>`__
(``preprocessor_pref=deepmind``). However, if the Trainer's config key ``preprocessor_pref`` is set to "rllib",
(``preprocessor_pref=deepmind``). However, if the Algorithm's config key ``preprocessor_pref`` is set to "rllib",
the following mappings apply for Atari-type observation spaces:
- Images of shape ``(210, 160, 3)`` are downscaled to ``dim x dim``, where
@ -75,7 +75,7 @@ and some special options for Atari environments:
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
The dict above (or an overriding sub-set) is handed to the Trainer via the ``model`` key within
The dict above (or an overriding sub-set) is handed to the Algorithm via the ``model`` key within
the main config dict like so:
.. code-block:: python
@ -90,7 +90,7 @@ the main config dict like so:
"fcnet_activation": "relu",
},
# ... other Trainer config keys, e.g. "lr" ...
# ... other Algorithm config keys, e.g. "lr" ...
"lr": 0.00001,
}
@ -107,7 +107,7 @@ based on simple heuristics:
- A fully connected network (`TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/fcnet.py>`__ or `Torch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/fcnet.py>`__)
for everything else.
These default model types can further be configured via the ``model`` config key inside your Trainer config (as discussed above).
These default model types can further be configured via the ``model`` config key inside your Algorithm config (as discussed above).
Available settings are `listed above <#default-model-config-settings>`__ and also documented in the `model catalog file <https://github.com/ray-project/ray/blob/master/rllib/models/catalog.py>`__.
Note that for the vision network case, you'll probably have to configure ``conv_filters``, if your environment observations
@ -227,7 +227,7 @@ Once implemented, your TF model can then be registered and used in place of a bu
ModelCatalog.register_custom_model("my_tf_model", MyModelClass)
ray.init()
trainer = ppo.PPO(env="CartPole-v0", config={
algo = ppo.PPO(env="CartPole-v0", config={
"model": {
"custom_model": "my_tf_model",
# Extra kwargs to be passed to your model's c'tor.
@ -270,7 +270,7 @@ Once implemented, your PyTorch model can then be registered and used in place of
import torch.nn as nn
import ray
from ray.rllib.agents import ppo
from ray.rllib.algorithms import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
@ -282,7 +282,7 @@ Once implemented, your PyTorch model can then be registered and used in place of
ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)
ray.init()
trainer = ppo.PPO(env="CartPole-v0", config={
algo = ppo.PPO(env="CartPole-v0", config={
"framework": "torch",
"model": {
"custom_model": "my_torch_model",
@ -368,7 +368,7 @@ Custom Model APIs (on Top of Default- or Custom Models)
```````````````````````````````````````````````````````
So far we talked about a) the default models that are built into RLlib and are being provided
automatically if you don't specify anything in your Trainer's config and b) custom Models through
automatically if you don't specify anything in your Algorithm's config and b) custom Models through
which you can define any arbitrary forward passes.
Another typical situation in which you would have to customize a model would be to
@ -508,7 +508,7 @@ Similar to custom models and preprocessors, you can also specify a custom action
ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)
ray.init()
trainer = ppo.PPO(env="CartPole-v0", config={
algo = ppo.PPO(env="CartPole-v0", config={
"model": {
"custom_action_dist": "my_dist",
},

View file

@ -83,13 +83,13 @@ This example plot shows the Q-value metric in addition to importance sampling (I
.. code-block:: python
trainer = DQN(...)
algo = DQN(...)
... # train policy offline
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
estimator = WeightedImportanceSamplingEstimator(trainer.get_policy(), gamma=0.99)
estimator = WeightedImportanceSamplingEstimator(algo.get_policy(), gamma=0.99)
reader = JsonReader("/path/to/data")
for _ in range(1000):
batch = reader.next()
@ -246,11 +246,11 @@ Input API
You can configure experience input for an agent using the following options:
.. tip::
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.agents.trainer_config.TrainerConfig`
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
objects, which have the advantage of being type safe, allowing users to set different config settings within
meaningful sub-categories (e.g. ``my_config.offline_data(input_=[xyz])``), and offer the ability to
construct a Trainer instance from these config objects (via their ``.build()`` method).
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
construct an Algorithm instance from these config objects (via their ``.build()`` method).
So far, this is only supported for some Algorithm classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
but we are rolling this out right now across all RLlib.
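As a small, hedged illustration of that config-object style for offline input (the JSON directory path is a placeholder):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    # Read experiences from offline JSON files instead of sampling an env.
    config = (
        PPOConfig()
        .environment("CartPole-v0")
        .offline_data(input_=["/tmp/my-offline-data"])
    )
    algo = config.build()
    print(algo.train())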
@ -343,11 +343,11 @@ Output API
You can configure experience output for an agent using the following options:
.. tip::
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.agents.trainer_config.TrainerConfig`
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
objects, which have the advantage of being type safe, allowing users to set different config settings within
meaningful sub-categories (e.g. ``my_config.offline_data(input_=[xyz])``), and offer the ability to
construct a Trainer instance from these config objects (via their ``.build()`` method).
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
construct an Algorithm instance from these config objects (via their ``.build()`` method).
So far, this is only supported for some Algorithm classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
but we are rolling this out right now across all RLlib.
.. code-block:: python

View file

@ -29,7 +29,7 @@ This is done using a dict that maps strings (column names) to `ViewRequirement`
The exact behavior for a single such rollout and the number of environment transitions therein
are determined by the following Trainer config keys:
are determined by the following Algorithm config keys:
**batch_mode [truncate_episodes|complete_episodes]**:
*truncated_episodes (default value)*:
@ -65,7 +65,7 @@ of each episode (arrow heads). This way, RLlib makes sure that the
**multiagent.count_steps_by [env_steps|agent_steps]**:
Within the Trainer's ``multiagent`` config dict, you can set the unit, by which RLlib will count a) rollout fragment lengths as well as b) the size of the final train_batch (see below). The two supported values are:
Within the Algorithm's ``multiagent`` config dict, you can set the unit by which RLlib will count a) rollout fragment lengths as well as b) the size of the final train_batch (see below). The two supported values are:
*env_steps (default)*:
Each call to ``[Env].step()`` is counted as one. It does not
@ -109,7 +109,7 @@ RLlib's default ``SampleCollector`` class is the ``SimpleListCollector``, which
to lists, then builds SampleBatches from these and sends them to the downstream processing functions.
It thereby tries to avoid collecting duplicate data separately (OBS and NEXT_OBS use the same underlying list).
If you want to implement your own collection logic and data structures, you can sub-class ``SampleCollector``
and specify that new class under the Trainer's "sample_collector" config key.
and specify that new class under the Algorithm's "sample_collector" config key.
Let's now look at how the Policy's Model lets the RolloutWorker and its SampleCollector
know, what data in the ongoing episode/trajectory to use for the different required method calls

View file

@ -8,13 +8,13 @@ Training APIs
Getting Started
---------------
At a high level, RLlib provides an ``Trainer`` class which
holds a policy for environment interaction. Through the trainer interface, the policy can
be trained, checkpointed, or an action computed. In multi-agent training, the trainer manages the querying and optimization of multiple policies at once.
At a high level, RLlib provides an ``Algorithm`` class which
holds a policy for environment interaction. Through the algorithm's interface, the policy can
be trained, checkpointed, or used to compute actions. In multi-agent training, the algorithm manages the querying and optimization of multiple policies at once.
.. image:: images/rllib-api.svg
You can train a simple DQN trainer with the following commands:
You can train DQN with the following commands:
.. code-block:: bash
@ -75,8 +75,8 @@ Specifying Parameters
~~~~~~~~~~~~~~~~~~~~~
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of
`common hyperparameters <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__
(soon to be replaced by `TrainerConfig objects <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer_config.py>`__).
`common hyperparameters <https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm.py>`__
(soon to be replaced by `AlgorithmConfig objects <https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm_config.py>`__).
See the `algorithms documentation <rllib-algorithms.html>`__ for more information.
@ -90,10 +90,10 @@ Specifying Resources
~~~~~~~~~~~~~~~~~~~~
You can control the degree of parallelism used by setting the ``num_workers``
hyperparameter for most algorithms. The Trainer will construct that many
hyperparameter for most algorithms. The Algorithm will construct that many
"remote worker" instances (`see RolloutWorker class <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__)
that are constructed as ray.remote actors, plus exactly one "local worker", a ``RolloutWorker`` object that is not a
ray actor, but lives directly inside the Trainer.
ray actor, but lives directly inside the Algorithm.
For most algorithms, learning updates are performed on the local worker and sample collection from
one or more environments is performed by the remote workers (in parallel).
For example, setting ``num_workers=0`` will only create the local worker, in which case both
@ -106,7 +106,7 @@ to that worker via the ``num_gpus`` setting.
Similarly, the resource allocation to remote workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``.
The number of GPUs can be fractional quantities (e.g. 0.5) to allocate only a fraction
of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting
of a GPU. For example, with DQN you can pack five algorithms onto one GPU by setting
``num_gpus: 0.2``. Check out `this fractional GPU example here <https://github.com/ray-project/ray/blob/master/rllib/examples/fractional_gpus.py>`__
as well that also demonstrates how environments (running on the remote workers) that
require a GPU can benefit from the ``num_gpus_per_worker`` setting.
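For instance, a fractional-GPU setup along these lines (the exact fractions are illustrative only):

.. code-block:: python

    config = {
        # Five such Algorithms can share a single GPU for learning updates.
        "num_gpus": 0.2,
        # Give each (GPU-requiring) rollout worker a GPU slice as well.
        "num_gpus_per_worker": 0.1,
    }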
@ -150,8 +150,8 @@ Here are some rules of thumb for scaling training with RLlib.
In case you are using lots of workers (``num_workers >> 10``) and you observe worker failures for whatever reasons, which normally interrupt your RLlib training runs, consider using
the config settings ``ignore_worker_failures=True`` or ``recreate_failed_workers=True``:
``ignore_worker_failures=True`` allows your Trainer to not crash due to a single worker error, but to continue for as long as there is at least one functional worker remaining.
``recreate_failed_workers=True`` will have your Trainer attempt to replace/recreate any failed worker(s) with a new one.
``ignore_worker_failures=True`` allows your Algorithm to not crash due to a single worker error, but to continue for as long as there is at least one functional worker remaining.
``recreate_failed_workers=True`` will have your Algorithm attempt to replace/recreate any failed worker(s) with a new one.
Both these settings will make your training runs much more stable and more robust against occasional OOM or other similar "once in a while" errors.
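A brief example of such a fault-tolerant config (the worker count is chosen arbitrarily):

.. code-block:: python

    config = {
        "num_workers": 50,
        # Keep training as long as at least one worker is still healthy.
        "ignore_worker_failures": True,
        # Try to replace any failed worker with a freshly created one.
        "recreate_failed_workers": True,
    }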
@ -160,11 +160,11 @@ Common Parameters
~~~~~~~~~~~~~~~~~
.. tip::
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.agents.trainer_config.TrainerConfig`
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
objects, which have the advantage of being type safe, allowing users to set different config settings within
meaningful sub-categories (e.g. ``my_config.training(lr=0.0003)``), and offer the ability to
construct a Trainer instance from these config objects (via their ``build()`` method).
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
construct an Algorithm instance from these config objects (via their ``build()`` method).
So far, this is only supported for some Algorithm classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
but we are rolling this out right now across all RLlib.
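To make the tip above concrete, here is a brief sketch mirroring the examples earlier in these docs (the ``lr`` value is taken from the tip; Tune still receives a plain dict via ``to_dict()``):

.. code-block:: python

    from ray import tune
    from ray.rllib.algorithms.ppo import PPOConfig

    config = PPOConfig().environment("CartPole-v0").training(lr=0.0003)

    # Either build and train the Algorithm directly ...
    algo = config.build()
    print(algo.train())

    # ... or hand the equivalent dict over to Ray Tune.
    tune.run("PPO", config=config.to_dict())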
The following is a list of the common algorithm hyper-parameters:
@ -173,7 +173,7 @@ The following is a list of the common algorithm hyper-parameters:
# === Settings for Rollout Worker processes ===
# Number of rollout worker actors to create for parallel sampling. Setting
# this to 0 will force rollouts to be done in the trainer actor.
# this to 0 will force rollouts to be done in the algorithm's actor.
"num_workers": 2,
# Number of environments to evaluate vector-wise per worker. This enables
# model inference batching, which can improve performance for inference
@ -214,7 +214,7 @@ The following is a list of the common algorithm hyper-parameters:
# terminates or a configured horizon (hard or soft) is hit.
"batch_mode": "truncate_episodes",
# === Settings for the Trainer process ===
# === Settings for the Algorithm process ===
# Discount factor of the MDP.
"gamma": 0.99,
# The default learning rate.
@ -257,7 +257,7 @@ The following is a list of the common algorithm hyper-parameters:
# a PyBullet env, a ViZDoomGym env, or a fully qualified classpath to an
# Env class, e.g. "ray.rllib.examples.env.random_env.RandomEnv".
"env": None,
# The observation- and action spaces for the Policies of this Trainer.
# The observation- and action spaces for the Policies of this Algorithm.
# Use None for automatically inferring these from the given env.
"observation_space": None,
"action_space": None,
@ -399,10 +399,10 @@ The following is a list of the common algorithm hyper-parameters:
# The unit, with which to count the evaluation duration. Either "episodes"
# (default) or "timesteps".
"evaluation_duration_unit": "episodes",
# Whether to run evaluation in parallel to a Trainer.train() call
# Whether to run evaluation in parallel to an Algorithm.train() call
# using threading. Default=False.
# E.g. evaluation_interval=2 -> For every other training iteration,
# the Trainer.train() and Trainer.evaluate() calls run in parallel.
# the Algorithm.train() and Algorithm.evaluate() calls run in parallel.
# Note: This is experimental. Possible pitfalls could be race conditions
# for weight synching at the beginning of the evaluation loop.
"evaluation_parallel_to_training": False,
@ -420,16 +420,16 @@ The following is a list of the common algorithm hyper-parameters:
},
# Number of parallel workers to use for evaluation. Note that this is set
# to zero by default, which means evaluation will be run in the trainer
# to zero by default, which means evaluation will be run in the algorithm
# process (only if evaluation_interval is not None). If you increase this,
# it will increase the Ray resource usage of the trainer since evaluation
# it will increase the Ray resource usage of the algorithm since evaluation
# workers are created separately from rollout workers (used to sample data
# for training).
"evaluation_num_workers": 0,
# Customize the evaluation method. This must be a function of signature
# (trainer: Trainer, eval_workers: WorkerSet) -> metrics: dict. See the
# Trainer.evaluate() method to see the default implementation.
# The Trainer guarantees all eval workers have the latest policy state
# (algorithm: Algorithm, eval_workers: WorkerSet) -> metrics: dict. See the
# Algorithm.evaluate() method to see the default implementation.
# The Algorithm guarantees all eval workers have the latest policy state
# before this function is called.
"custom_eval_function": None,
# Make sure the latest available evaluation results are always attached to
@ -504,15 +504,15 @@ The following is a list of the common algorithm hyper-parameters:
# each worker, so that identically configured trials will have identical
# results. This makes experiments reproducible.
"seed": None,
# Any extra python env vars to set in the trainer process, e.g.,
# Any extra python env vars to set in the algorithm process, e.g.,
# {"OMP_NUM_THREADS": "16"}
"extra_python_environs_for_driver": {},
# Any extra python env vars to set for worker processes.
"extra_python_environs_for_worker": {},
# === Resource Settings ===
# Number of GPUs to allocate to the trainer process. Note that not all
# algorithms can take advantage of trainer GPUs. Support for multi-GPU
# Number of GPUs to allocate to the algorithm process. Note that not all
# algorithms can take advantage of GPUs. Support for multi-GPU
# is currently only available for tf-[PPO/IMPALA/DQN/PG].
# This can be fractional (e.g., 0.3 GPUs).
"num_gpus": 0,
@ -528,13 +528,13 @@ The following is a list of the common algorithm hyper-parameters:
"num_gpus_per_worker": 0,
# Any custom Ray resources to allocate per worker.
"custom_resources_per_worker": {},
# Number of CPUs to allocate for the trainer. Note: this only takes effect
# when running in Tune. Otherwise, the trainer runs in the main program.
# Number of CPUs to allocate for the algorithm. Note: this only takes effect
# when running in Tune. Otherwise, the algorithm runs in the main program.
"num_cpus_for_driver": 1,
# The strategy for the placement group factory returned by
# `Trainer.default_resource_request()`. A PlacementGroup defines, which
# `Algorithm.default_resource_request()`. A PlacementGroup defines which
# devices (resources) should always be co-located on the same node.
# For example, a Trainer with 2 rollout workers, running with
# For example, an Algorithm with 2 rollout workers, running with
# num_gpus=1 will request a placement group with the bundles:
# [{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle is
# for the driver and the other 2 bundles are for the two workers.
@ -655,7 +655,7 @@ The following is a list of the common algorithm hyper-parameters:
# === API deprecations/simplifications/changes ===
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_step()` method will be called on each
# an Algorithm's `training_step()` method will be called on each
# training iteration.
"_disable_execution_plan_api": True,
@ -716,17 +716,17 @@ Here is an example of the basic usage (for a more complete example, see `custom_
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 1
trainer = ppo.PPO(config=config, env="CartPole-v0")
algo = ppo.PPO(config=config, env="CartPole-v0")
# Can optionally call trainer.restore(path) to load a checkpoint.
# Can optionally call algo.restore(path) to load a checkpoint.
for i in range(1000):
# Perform one iteration of training the policy with PPO
result = trainer.train()
result = algo.train()
print(pretty_print(result))
if i % 100 == 0:
checkpoint = trainer.save()
checkpoint = algo.save()
print("checkpoint saved at", checkpoint)
# Also, in case you have trained a model outside of ray/RLlib and have created
@ -734,9 +734,9 @@ Here is an example of the basic usage (for a more complete example, see `custom_
# my_keras_model_trained_outside_rllib.save_weights("model.h5")
# (see: https://keras.io/models/about-keras-models/)
# ... you can load the h5-weights into your Trainer's Policy's ModelV2
# ... you can load the h5-weights into your Algorithm's Policy's ModelV2
# (tf or torch) by doing:
trainer.import_model("my_weights.h5")
algo.import_model("my_weights.h5")
# NOTE: In order for this to work, your (custom) model needs to implement
# the `import_from_h5` method.
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
@ -744,9 +744,9 @@ Here is an example of the basic usage (for a more complete example, see `custom_
.. note::
It's recommended that you run RLlib trainers with :doc:`Tune <../tune/index>`, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.
It's recommended that you run RLlib algorithms with :doc:`Tune <../tune/index>`, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.
All RLlib trainers are compatible with the :ref:`Tune API <tune-60-seconds>`. This enables them to be easily used in experiments with :doc:`Tune <../tune/index>`. For example, the following code performs a simple hyperparam sweep of PPO:
All RLlib algorithms are compatible with the :ref:`Tune API <tune-60-seconds>`. This enables them to be easily used in experiments with :doc:`Tune <../tune/index>`. For example, the following code performs a simple hyperparam sweep of PPO:
.. code-block:: python
@ -818,7 +818,7 @@ Loading and restoring a trained agent from a checkpoint is simple:
Computing Actions
~~~~~~~~~~~~~~~~~
The simplest way to programmatically compute actions from a trained agent is to use ``trainer.compute_action()``.
The simplest way to programmatically compute actions from a trained agent is to use ``Algorithm.compute_action()``.
This method preprocesses and filters the observation before passing it to the agent policy.
Here is a simple example of testing a trained agent for one episode:
@ -836,12 +836,12 @@ Here is a simple example of testing a trained agent for one episode:
obs, reward, done, info = env.step(action)
episode_reward += reward
For more advanced usage, you can access the ``workers`` and policies held by the trainer
For more advanced usage, you can access the ``workers`` and policies held by the algorithm
directly as ``compute_action()`` does:
.. code-block:: python
class Trainer(Trainable):
class Algorithm(Trainable):
@PublicAPI
def compute_action(self,
@ -857,7 +857,7 @@ directly as ``compute_action()`` does:
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Arguments:
Args:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any. If state is not None,
then all of compute_single_action(...) is returned
@ -905,23 +905,31 @@ directly as ``compute_action()`` does:
Accessing Policy State
~~~~~~~~~~~~~~~~~~~~~~
It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *rollout workers* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.workers.foreach_worker()`` or ``trainer.workers.foreach_worker_with_index()``. These functions take a lambda function that is applied with the worker as an arg. You can also return values from these functions and those will be returned as a list.
It is common to need to access an algorithm's internal state, e.g., to set or get internal weights.
In RLlib, algorithm state is replicated across multiple *rollout workers* (Ray actors) in the cluster.
However, you can easily get and update this state between calls to ``train()`` via ``Algorithm.workers.foreach_worker()`` or ``Algorithm.workers.foreach_worker_with_index()``.
These functions take a lambda function that is applied with the worker as an arg.
You can also return values from these functions; the results are collected into a list (one entry per worker).
You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.workers.local_worker()``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.workers.local_worker().policy_map["default_policy"].get_weights()``:
You can also access just the "master" copy of the algorithm state through ``Algorithm.get_policy()`` or
``Algorithm.workers.local_worker()``, but note that updates here may not be immediately reflected in
remote replicas if you have configured ``num_workers > 0``.
For example, to access the weights of a local TF policy, you can run ``Algorithm.get_policy().get_weights()``.
This is also equivalent to ``Algorithm.workers.local_worker().policy_map["default_policy"].get_weights()``:
.. code-block:: python
# Get weights of the default local policy
trainer.get_policy().get_weights()
algo.get_policy().get_weights()
# Same as above
trainer.workers.local_worker().policy_map["default_policy"].get_weights()
algo.workers.local_worker().policy_map["default_policy"].get_weights()
# Get list of weights of each worker, including remote replicas
trainer.workers.foreach_worker(lambda ev: ev.get_policy().get_weights())
algo.workers.foreach_worker(lambda ev: ev.get_policy().get_weights())
# Same as above
trainer.workers.foreach_worker_with_index(lambda ev, i: ev.get_policy().get_weights())
algo.workers.foreach_worker_with_index(lambda ev, i: ev.get_policy().get_weights())
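Conversely, to push locally modified weights out to all remote replicas, you can use the same
``foreach_worker()`` utility (a minimal sketch, assuming ``algo`` is an already built Algorithm):

.. code-block:: python

    # Broadcast the local ("master") weights to every worker, including remote ones.
    weights = algo.get_policy().get_weights()
    algo.workers.foreach_worker(
        lambda ev: ev.get_policy().set_weights(weights))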
Accessing Model State
~~~~~~~~~~~~~~~~~~~~~
@ -967,7 +975,9 @@ Advanced Python APIs
Custom Training Workflows
~~~~~~~~~~~~~~~~~~~~~~~~~
In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per training iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports :ref:`custom trainable functions <trainable-docs>` that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.
In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your algorithm once per training iteration and report the new training results.
Sometimes, it is desirable to have full control over training, but still run inside Tune.
Tune supports :ref:`custom trainable functions <trainable-docs>` that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.
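A minimal sketch of such a custom trainable function (the environment name and stopping criterion are illustrative only) could look like this:

.. code-block:: python

    from ray import tune
    from ray.rllib.algorithms.ppo import PPO

    def my_train_fn(config, reporter):
        # Full control over the training loop, while still running under Tune.
        algo = PPO(config=config, env="CartPole-v0")
        while True:
            result = algo.train()
            reporter(**result)
            if result["episode_reward_mean"] > 150:
                break
        algo.stop()

    tune.run(my_train_fn, config={"num_workers": 1})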
For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__.
@ -1003,7 +1013,7 @@ You can provide callbacks to be called at points during policy evaluation. These
User-defined state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__.
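For instance, here is a small sketch of a callback that records a custom per-episode metric (the metric name is made up for illustration):

.. code-block:: python

    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    class EpisodeLenCallback(DefaultCallbacks):
        def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
            # Custom metrics get aggregated and show up under `custom_metrics`
            # in the training results.
            episode.custom_metrics["episode_len"] = episode.length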
.. autoclass:: ray.rllib.agents.callbacks.DefaultCallbacks
.. autoclass:: ray.rllib.algorithms.callbacks.DefaultCallbacks
:members:
@ -1012,7 +1022,7 @@ Chaining Callbacks
Use the ``MultiCallbacks`` class to chain multiple callbacks together.
.. autoclass:: ray.rllib.agents.callbacks.MultiCallbacks
.. autoclass:: ray.rllib.algorithms.callbacks.MultiCallbacks
Visualizing Custom Metrics
@ -1032,19 +1042,19 @@ exploration behavior, including the decisions (how and whether) to sample
actions from distributions (stochastically or deterministically).
The setup can be done via using built-in Exploration classes
(see `this package <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/>`__),
which are specified (and further configured) inside ``Trainer.config["exploration_config"]``.
which are specified (and further configured) inside ``Algorithm.config["exploration_config"]``.
Besides using one of the available classes, one can sub-class any of
these built-ins, add custom behavior to it, and use that new class in
the config instead.
Every policy has-an Exploration object, which is created from the Trainers
Every policy has-an Exploration object, which is created from the Algorithm's
``config["exploration_config"]`` dict, which specifies the class to use via the
special "type" key, as well as constructor arguments via all other keys,
e.g.:
.. code-block:: python
# in Trainer.config:
# in Algorithm.config:
"exploration_config": {
"type": "StochasticSampling", # <- Special `type` key provides class information
"[c'tor arg]" : "[value]", # <- Add any needed constructor args here.
@ -1070,19 +1080,19 @@ b) log-likelihood:
:start-after: __sphinx_doc_begin_get_exploration_action__
:end-before: __sphinx_doc_end_get_exploration_action__
On the highest level, the ``Trainer.compute_action`` and ``Policy.compute_action(s)``
On the highest level, the ``Algorithm.compute_actions`` and ``Policy.compute_actions``
methods have a boolean ``explore`` switch, which is passed into
``Exploration.get_exploration_action``. If ``explore=None``, the value of
``Trainer.config[“explore”]`` is used, which thus serves as a main switch for
``Algorithm.config["explore"]`` is used, which thus serves as a main switch for
exploratory behavior, allowing e.g. turning off any exploration easily for
evaluation purposes (see :ref:`CustomEvaluation`).
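For example (a minimal sketch, assuming ``algo`` is a trained Algorithm and ``obs`` a valid observation):

.. code-block:: python

    # Deterministic action, e.g. for evaluation.
    action = algo.compute_single_action(obs, explore=False)
    # Stochastic (exploratory) action; with explore=None, the config's
    # "explore" setting would be used instead.
    action = algo.compute_single_action(obs, explore=True)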
The following are example excerpts from different Trainers' configs
(see rllib/agents/trainer.py) to setup different exploration behaviors:
The following are example excerpts from different Algorithms' configs
(see rllib/algorithms/algorithm.py) to set up different exploration behaviors:
.. code-block:: python
# All of the following configs go into Trainer.config.
# All of the following configs go into Algorithm.config.
# 1) Switching *off* exploration by default.
# Behavior: Calling `compute_action(s)` without explicitly setting its `explore`
@ -1119,10 +1129,10 @@ The following are example excerpts from different Trainers' configs
"temperature": 1.0,
},
# c) All policy-gradient algos and SAC: see rllib/agents/trainer.py
# c) All policy-gradient algos and SAC: see rllib/algorithms/algorithm.py
# Behavior: The algo samples stochastically from the
# model-parameterized distribution. This is the global Trainer default
# setting defined in trainer.py and used by all PG-type algos (plus SAC).
# model-parameterized distribution. This is the global Algorithm default
# setting defined in algorithm.py and used by all PG-type algos (plus SAC).
"explore": True,
"exploration_config": {
"type": "StochasticSampling",
@ -1137,13 +1147,13 @@ Customized Evaluation During Training
RLlib will report online training rewards; however, in some cases you may want to compute
rewards with different settings (e.g., with exploration turned off, or on a specific set
of environment configurations). You can activate evaluating policies during training (``Trainer.train()``) by setting
the ``evaluation_interval`` to an int value (> 0) indicating every how many ``Trainer.train()``
of environment configurations). You can activate evaluating policies during training (``Algorithm.train()``) by setting
the ``evaluation_interval`` to an int value (> 0), indicating after how many ``Algorithm.train()``
calls an "evaluation step" is run:
.. code-block:: python
# Run one evaluation step on every 3rd `Trainer.train()` call.
# Run one evaluation step on every 3rd `Algorithm.train()` call.
{
"evaluation_interval": 3,
}
@ -1186,7 +1196,7 @@ roughly as long as the train step:
.. code-block:: python
# Run eval and train at the same time via threading and make sure they roughly
# take the same time, such that the next `Trainer.train()` call can execute
# take the same time, such that the next `Algorithm.train()` call can execute
# immediately and not have to wait for a still ongoing (e.g. very long episode)
# evaluation step:
{
@ -1204,8 +1214,8 @@ do:
.. code-block:: python
# Switching off exploration behavior for evaluation workers
# (see rllib/agents/trainer.py). Use any keys in this sub-dict that are
# also supported in the main Trainer config.
# (see rllib/algorithms/algorithm.py). Use any keys in this sub-dict that are
# also supported in the main Algorithm config.
"evaluation_config": {
"explore": False
}
@ -1223,8 +1233,8 @@ run as much in parallel as possible. For example, if your ``evaluation_duration=
only has to run 1 episode in each eval step.
In case you would like to entirely customize the evaluation step, set ``custom_eval_function`` in your
config to a callable taking the Trainer object and a WorkerSet object (the evaluation WorkerSet)
and returning a metrics dict. See `trainer.py <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__
config to a callable taking the Algorithm object and a WorkerSet object (the evaluation WorkerSet)
and returning a metrics dict. See `algorithm.py <https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm.py>`__
for further documentation.
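A minimal sketch of such a custom evaluation function (the metric name is made up for illustration):

.. code-block:: python

    def my_custom_eval(algorithm, eval_workers):
        # Run one sampling round on all evaluation workers ...
        eval_workers.foreach_worker(lambda w: w.sample())
        # ... then compute and return any metrics dict you like.
        return {"my_eval_metric": 1.0}

    config = {
        "evaluation_interval": 1,
        "evaluation_num_workers": 2,
        "custom_eval_function": my_custom_eval,
    }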
There is an end-to-end example of how to set up custom online evaluation in `custom_eval.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_eval.py>`__. Note that if you only want to evaluate your policy at the end of training, you can set ``evaluation_interval: N``, where ``N`` is the number of training iterations before stopping.
@ -1237,12 +1247,12 @@ Below are some examples of how the custom evaluation metrics are reported nested
Sample output for `python custom_eval.py`
------------------------------------------------------------------------
INFO trainer.py:623 -- Evaluating current policy for 10 episodes.
INFO trainer.py:650 -- Running round 0 of parallel evaluation (2/10 episodes)
INFO trainer.py:650 -- Running round 1 of parallel evaluation (4/10 episodes)
INFO trainer.py:650 -- Running round 2 of parallel evaluation (6/10 episodes)
INFO trainer.py:650 -- Running round 3 of parallel evaluation (8/10 episodes)
INFO trainer.py:650 -- Running round 4 of parallel evaluation (10/10 episodes)
INFO algorithm.py:623 -- Evaluating current policy for 10 episodes.
INFO algorithm.py:650 -- Running round 0 of parallel evaluation (2/10 episodes)
INFO algorithm.py:650 -- Running round 1 of parallel evaluation (4/10 episodes)
INFO algorithm.py:650 -- Running round 2 of parallel evaluation (6/10 episodes)
INFO algorithm.py:650 -- Running round 3 of parallel evaluation (8/10 episodes)
INFO algorithm.py:650 -- Running round 4 of parallel evaluation (10/10 episodes)
Result for PG_SimpleCorridor_2c6b27dc:
...
@ -1326,17 +1336,17 @@ which receives the last training results and returns a new task for the env to b
new_task = current_task + 1
return new_task
# Setup your Trainer's config like so:
# Setup your Algorithm's config like so:
config = {
"env": MyEnv,
"env_task_fn": curriculum_fn,
}
# Train using `tune.run` or `Trainer.train()` and the above config stub.
# Train using `tune.run` or `Algorithm.train()` and the above config stub.
# ...
There are two more ways to use RLlib's other APIs to implement `curriculum learning <https://bair.berkeley.edu/blog/2017/12/20/reverse-curriculum/>`__.
Use the Trainer API and update the environment between calls to ``train()``. This example shows the trainer being run inside a Tune function.
Use the Algorithm API and update the environment between calls to ``train()``. This example shows the algorithm being run inside a Tune function.
This is basically the same as what the built-in `env_task_fn` API described above already does under the hood, but allows you to do even more
customizations to your training loop.
@ -1347,9 +1357,9 @@ customizations to your training loop.
from ray.rllib.algorithms.ppo import PPO
def train(config, reporter):
trainer = PPO(config=config, env=YourEnv)
algo = PPO(config=config, env=YourEnv)
while True:
result = trainer.train()
result = algo.train()
reporter(**result)
if result["episode_reward_mean"] > 200:
task = 2
@ -1357,7 +1367,7 @@ customizations to your training loop.
task = 1
else:
task = 0
trainer.workers.foreach_worker(
algo.workers.foreach_worker(
lambda ev: ev.foreach_env(
lambda env: env.set_task(task)))
@ -1382,28 +1392,26 @@ You could also use RLlib's callbacks API to update the environment on new traini
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
def on_train_result(info):
result = info["result"]
if result["episode_reward_mean"] > 200:
task = 2
elif result["episode_reward_mean"] > 100:
task = 1
else:
task = 0
trainer = info["trainer"]
trainer.workers.foreach_worker(
lambda ev: ev.foreach_env(
lambda env: env.set_task(task)))
class MyCallbacks(DefaultCallbacks):
def on_train_result(self, algorithm, result, **kwargs):
if result["episode_reward_mean"] > 200:
task = 2
elif result["episode_reward_mean"] > 100:
task = 1
else:
task = 0
algorithm.workers.foreach_worker(
lambda ev: ev.foreach_env(
lambda env: env.set_task(task)))
ray.init()
tune.run(
"PPO",
config={
"env": YourEnv,
"callbacks": {
"on_train_result": on_train_result,
},
"callbacks": MyCallbacks,
},
)
@ -1444,8 +1452,8 @@ However, eager can be slower than graph mode unless tracing is enabled.
Using PyTorch
~~~~~~~~~~~~~
Trainers that have an implemented TorchPolicy, will allow you to run
`rllib train` using the command line ``--torch`` flag.
Algorithms that have an implemented TorchPolicy will allow you to run
`rllib train` using the command line ``--framework=torch`` flag.
Algorithms that do not have a torch version yet will complain with an error in
this case.
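For example, a hypothetical command line invocation might look like this:

.. code-block:: bash

    rllib train --run=PPO --env=CartPole-v0 --framework=torch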
@ -1467,7 +1475,10 @@ You can use the `data output API <rllib-offline.html>`__ to save episode traces
Log Verbosity
~~~~~~~~~~~~~
You can control the trainer log level via the ``"log_level"`` flag. Valid values are "DEBUG", "INFO", "WARN" (default), and "ERROR". This can be used to increase or decrease the verbosity of internal logging. You can also use the ``-v`` and ``-vv`` flags. For example, the following two commands are about equivalent:
You can control the log level via the ``"log_level"`` flag. Valid values are "DEBUG",
"INFO", "WARN" (default), and "ERROR". This can be used to increase or decrease the
verbosity of internal logging. You can also use the ``-v`` and ``-vv`` flags.
For example, the following two commands are about equivalent:
.. code-block:: bash

View file

@ -11,12 +11,12 @@ from ray.air.preprocessor import Preprocessor
from ray.air.checkpoint import Checkpoint
from ray.train.rl import RLTrainer
from ray.rllib.agents import Trainer
from ray.rllib.algorithms import Algorithm
from ray.rllib.policy import Policy
from ray.tune.utils.trainable import TrainableUtil
class _DummyTrainer(Trainer):
class _DummyAlgo(Algorithm):
train_exec_impl = None
def setup(self, config):
@ -59,7 +59,7 @@ def create_checkpoint(
preprocessor: Optional[Preprocessor] = None, config: Optional[dict] = None
) -> Checkpoint:
rl_trainer = RLTrainer(
algorithm=_DummyTrainer,
algorithm=_DummyAlgo,
config=config or {},
preprocessor=preprocessor,
)

View file

@ -10,9 +10,9 @@ from ray.air._internal.checkpointing import (
load_preprocessor_from_dir,
save_preprocessor_to_dir,
)
from ray.rllib.agents.trainer import Trainer as RLlibTrainer
from ray.rllib.algorithms.algorithm import Algorithm as RLlibAlgo
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import PartialTrainerConfigDict, EnvType
from ray.rllib.utils.typing import PartialAlgorithmConfigDict, EnvType
from ray.tune import Trainable, PlacementGroupFactory
from ray.tune.logger import Logger
from ray.tune.registry import get_trainable_cls
@ -86,7 +86,7 @@ class RLTrainer(BaseTrainer):
import ray
from ray.air.config import RunConfig
from ray.train.rl import RLTrainer
from ray.rllib.agents.marwil.bc import BCTrainer
from ray.rllib.algorithms.bc.bc import BC
dataset = ray.data.read_json(
"/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
@ -114,7 +114,7 @@ class RLTrainer(BaseTrainer):
def __init__(
self,
algorithm: Union[str, Type[RLlibTrainer]],
algorithm: Union[str, Type[RLlibAlgo]],
config: Optional[Dict[str, Any]] = None,
scaling_config: Optional[ScalingConfig] = None,
run_config: Optional[RunConfig] = None,
@ -137,8 +137,7 @@ class RLTrainer(BaseTrainer):
super(RLTrainer, self)._validate_attributes()
if not isinstance(self._algorithm, str) and not (
inspect.isclass(self._algorithm)
and issubclass(self._algorithm, RLlibTrainer)
inspect.isclass(self._algorithm) and issubclass(self._algorithm, RLlibAlgo)
):
raise ValueError(
f"`algorithm` should be either a string or a RLlib trainer class, "
@ -201,7 +200,7 @@ class RLTrainer(BaseTrainer):
class AIRRLTrainer(rllib_trainer):
def __init__(
self,
config: Optional[PartialTrainerConfigDict] = None,
config: Optional[PartialAlgorithmConfigDict] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
@ -241,7 +240,7 @@ class RLTrainer(BaseTrainer):
@classmethod
def default_resource_request(
cls, config: PartialTrainerConfigDict
cls, config: PartialAlgorithmConfigDict
) -> Union[Resources, PlacementGroupFactory]:
resolved_config = merge_dicts(base_config, config)
param_dict["config"] = resolved_config

View file

@ -56,7 +56,13 @@ class _CallbackMeta(ABCMeta):
@classmethod
def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool:
return (attr_name.startswith("on_") or attr_name == "setup") and callable(attr)
return (
(
attr_name.startswith("on_")
and not attr_name.startswith("on_trainer_init")
)
or attr_name == "setup"
) and callable(attr)
@PublicAPI(stability="beta")

View file

@ -8,7 +8,7 @@ import uuid
import ray._private.utils
from ray.rllib.agents.mock import _MockTrainer
from ray.rllib.algorithms.mock import _MockTrainer
from ray.tune import Trainable
from ray.tune.callback import Callback
from ray.tune.sync_client import get_sync_client

View file

@ -7,7 +7,7 @@ import time
import ray
from ray import tune
from ray.rllib.agents import DefaultCallbacks
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPO
@ -34,7 +34,7 @@ def fn_trainable(config, checkpoint_dir=None):
class RLlibCallback(DefaultCallbacks):
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
def on_train_result(self, *, algorithm, result: dict, **kwargs) -> None:
result["internal_iter"] = result["training_iteration"]

View file

@ -1796,6 +1796,13 @@ py_test(
# for `tests/test_all_stuff.py`.
# --------------------------------------------------------------------
py_test(
name = "tests/backward_compat/test_backward_compat",
tags = ["team:rllib", "tests_dir", "tests_dir_B"],
size = "medium",
srcs = ["tests/backward_compat/test_backward_compat.py"]
)
py_test(
name = "tests/test_algorithm_imports",
tags = ["team:rllib", "tests_dir", "tests_dir_C"],
@ -2626,7 +2633,7 @@ py_test(
tags = ["team:rllib", "exclusive", "multi_gpu", "examples"],
size = "medium",
srcs = ["examples/deterministic_training.py"],
args = ["--as-test", "--stop-iters=1", "--framework=tf", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
args = ["--as-test", "--stop-iters=1", "--framework=tf", "--num-gpus=1", "--num-gpus-per-worker=1"]
)
py_test(
@ -2635,7 +2642,7 @@ py_test(
tags = ["team:rllib", "exclusive", "multi_gpu", "examples"],
size = "medium",
srcs = ["examples/deterministic_training.py"],
args = ["--as-test", "--stop-iters=1", "--framework=tf2", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
args = ["--as-test", "--stop-iters=1", "--framework=tf2", "--num-gpus=1", "--num-gpus-per-worker=1"]
)
py_test(
@ -2644,7 +2651,7 @@ py_test(
tags = ["team:rllib", "exclusive", "multi_gpu", "examples"],
size = "medium",
srcs = ["examples/deterministic_training.py"],
args = ["--as-test", "--stop-iters=1", "--framework=torch", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
args = ["--as-test", "--stop-iters=1", "--framework=torch", "--num-gpus=1", "--num-gpus-per-worker=1"]
)
py_test(

View file

@ -174,9 +174,9 @@ Quick First Experiment
return self.cur_obs, reward, done, {}
# Create an RLlib Trainer instance to learn how to act in the above
# Create an RLlib Algorithm instance to learn how to act in the above
# environment.
trainer = PPO(
algo = PPO(
config={
# Env class to use (here: our gym.Env sub-class from above).
"env": ParrotEnv,
@ -193,7 +193,7 @@ Quick First Experiment
# (exact match between observation and action value),
# we can expect to reach an optimal episode reward of 0.0.
for i in range(5):
results = trainer.train()
results = algo.train()
print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")
@ -220,7 +220,7 @@ and `attention nets <https://github.com/ray-project/ray/blob/master/rllib/exampl
while not done:
# Compute a single action, given the current observation
# from the environment.
action = trainer.compute_single_action(obs)
action = algo.compute_single_action(obs)
# Apply the computed action in the environment.
obs, reward, done, info = env.step(action)
# Sum up rewards for reporting purposes.

View file

@ -28,8 +28,8 @@ def _setup_logger():
def _register_all():
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.registry import ALGORITHMS, get_trainer_class
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
for key in (
@ -38,12 +38,12 @@ def _register_all():
+ ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
):
logging.warning(key)
register_trainable(key, get_trainer_class(key))
register_trainable(key, get_algorithm_class(key))
def _see_contrib(name):
"""Returns dummy agent class warning algo is in contrib/."""
class _SeeContrib(Trainer):
class _SeeContrib(Algorithm):
def setup(self, config):
raise NameError("Please run `contrib/{}` instead.".format(name))

View file

@ -1,15 +1,7 @@
from ray.rllib.agents.callbacks import (
DefaultCallbacks,
MemoryTrackingCallbacks,
MultiCallbacks,
)
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm as Trainer, with_common_config
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig as TrainerConfig
__all__ = [
"DefaultCallbacks",
"MemoryTrackingCallbacks",
"MultiCallbacks",
"Trainer",
"TrainerConfig",
"with_common_config",

View file

@ -1,3 +1,4 @@
import ray.rllib.agents.a3c.a2c as a2c # noqa
from ray.rllib.algorithms.a2c.a2c import (
A2CConfig,
A2C as A2CTrainer,

4
rllib/agents/a3c/a2c.py Normal file
View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.a2c import ( # noqa
A2C as A2CTrainer,
A2C_DEFAULT_CONFIG,
)

6
rllib/agents/a3c/a3c.py Normal file
View file

@ -0,0 +1,6 @@
from ray.rllib.algorithms.a3c import ( # noqa
a3c_tf_policy,
a3c_torch_policy,
A3C as A3CTrainer,
DEFAULT_CONFIG,
)

View file

@ -1,4 +1,4 @@
from ray.rllib.algorithms.ars.ars import ARSTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.ars.ars import ARS as ARSTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.ars.ars_tf_policy import ARSTFPolicy
from ray.rllib.algorithms.ars.ars_torch_policy import ARSTorchPolicy

View file

@ -1,570 +1,13 @@
import numpy as np
import os
import tracemalloc
from typing import Dict, Optional, Tuple, TYPE_CHECKING
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.exploration.random_encoder import (
_MovingMeanStd,
compute_states_entropy,
update_beta,
from ray.rllib.algorithms.callbacks import ( # noqa
DefaultCallbacks,
MemoryTrackingCallbacks,
MultiCallbacks,
RE3UpdateCallbacks,
)
from ray.rllib.utils.typing import AgentID, EnvType, PolicyID
from ray.tune.callback import _CallbackMeta
from ray.rllib.utils.deprecation import deprecation_warning
# Import psutil after ray so the packaged version is used.
import psutil
if TYPE_CHECKING:
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation import RolloutWorker
@PublicAPI
class DefaultCallbacks(metaclass=_CallbackMeta):
"""Abstract base class for RLlib callbacks (similar to Keras callbacks).
These callbacks can be used for custom metrics and custom postprocessing.
By default, all of these callbacks are no-ops. To configure custom training
callbacks, subclass DefaultCallbacks and then set
{"callbacks": YourCallbacksClass} in the trainer config.
"""
def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
if legacy_callbacks_dict:
deprecation_warning(
"callbacks dict interface",
"a class extending rllib.agents.callbacks.DefaultCallbacks",
)
self.legacy_callbacks = legacy_callbacks_dict or {}
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
**kwargs,
) -> None:
"""Callback run when a new sub-environment has been created.
This method gets called after each sub-environment (usually a
gym.Env) has been created, validated (RLlib built-in validation
+ possible custom validation function implemented by overriding
`Trainer.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
Args:
worker: Reference to the current rollout worker.
sub_environment: The sub-environment instance that has been
created. This is usually a gym.Env object.
env_context: The `EnvContext` object that has been passed to
the env's constructor.
kwargs: Forward compatibility placeholder.
"""
pass
def on_trainer_init(
self,
*,
trainer: "Trainer",
**kwargs,
) -> None:
"""Callback run when a new trainer instance has finished setup.
This method gets called at the end of Trainer.setup() after all
the initialization is done, and before actually training starts.
Args:
trainer: Reference to the trainer instance.
kwargs: Forward compatibility placeholder.
"""
pass
def on_episode_start(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
**kwargs,
) -> None:
"""Callback run on the rollout worker before each episode starts.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode: Episode object which contains the episode's
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_start"):
self.legacy_callbacks["on_episode_start"](
{
"env": base_env,
"policy": policies,
"episode": episode,
}
)
def on_episode_step(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: Episode,
**kwargs,
) -> None:
"""Runs on each episode step.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects.
In single agent mode there will only be a single
"default_policy".
episode: Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_step"):
self.legacy_callbacks["on_episode_step"](
{"env": base_env, "episode": episode}
)
def on_episode_end(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
**kwargs,
) -> None:
"""Runs when an episode is done.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy
objects. In single agent mode there will only be a single
"default_policy".
episode: Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_end"):
self.legacy_callbacks["on_episode_end"](
{
"env": base_env,
"policy": policies,
"episode": episode,
}
)
def on_postprocess_trajectory(
self,
*,
worker: "RolloutWorker",
episode: Episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: Dict[PolicyID, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
**kwargs,
) -> None:
"""Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy,
including looking at the trajectory data of other agents in multi-agent
settings.
Args:
worker: Reference to the current rollout worker.
episode: Episode object.
agent_id: Id of the current agent.
policy_id: Id of the current policy for the agent.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default_policy".
postprocessed_batch: The postprocessed sample batch
for this agent. You can mutate this object to apply your own
trajectory postprocessing.
original_batches: Mapping of agents to their unpostprocessed
trajectory data. You should not mutate this object.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_postprocess_traj"):
self.legacy_callbacks["on_postprocess_traj"](
{
"episode": episode,
"agent_id": agent_id,
"pre_batch": original_batches[agent_id],
"post_batch": postprocessed_batch,
"all_pre_batches": original_batches,
}
)
def on_sample_end(
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
) -> None:
"""Called at the end of RolloutWorker.sample().
Args:
worker: Reference to the current rollout worker.
samples: Batch to be returned. You can mutate this
object to modify the samples generated.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_sample_end"):
self.legacy_callbacks["on_sample_end"](
{
"worker": worker,
"samples": samples,
}
)
def on_learn_on_batch(
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
) -> None:
"""Called at the beginning of Policy.learn_on_batch().
Note: This is called before 0-padding via
`pad_batch_to_sequences_of_same_size`.
Also note, SampleBatch.INFOS column will not be available on
train_batch within this callback if framework is tf1, due to
the fact that tf1 static graph would mistake it as part of the
input dict if present.
It is available though, for tf2 and torch frameworks.
Args:
policy: Reference to the current Policy object.
train_batch: SampleBatch to be trained on. You can
mutate this object to modify the samples generated.
result: A results dict to add custom metrics to.
kwargs: Forward compatibility placeholder.
"""
pass
def on_train_result(self, *, trainer: "Trainer", result: dict, **kwargs) -> None:
"""Called at the end of Trainable.train().
Args:
trainer: Current trainer instance.
result: Dict of results returned from trainer.train() call.
You can mutate this object to add additional metrics.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_train_result"):
self.legacy_callbacks["on_train_result"](
{
"trainer": trainer,
"result": result,
}
)
class MemoryTrackingCallbacks(DefaultCallbacks):
"""MemoryTrackingCallbacks can be used to trace and track memory usage
in rollout workers.
The Memory Tracking Callbacks uses tracemalloc and psutil to track
python allocations during rollouts,
in training or evaluation.
The tracking data is logged to the custom_metrics of an episode and
can therefore be viewed in tensorboard
(or in WandB etc..)
Add MemoryTrackingCallbacks callback to the tune config
e.g. { ...'callbacks': MemoryTrackingCallbacks ...}
Note:
This class is meant for debugging and should not be used
in production code as tracemalloc incurs
a significant slowdown in execution speed.
"""
def __init__(self):
super().__init__()
# Will track the top 10 lines where memory is allocated
tracemalloc.start(10)
def on_episode_end(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics("lineno")
for stat in top_stats[:10]:
count = stat.count
size = stat.size
trace = str(stat.traceback)
episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
episode.custom_metrics[f"tracemalloc/{trace}/count"] = count
process = psutil.Process(os.getpid())
worker_rss = process.memory_info().rss
worker_data = process.memory_info().data
worker_vms = process.memory_info().vms
episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss
episode.custom_metrics["tracemalloc/worker/data"] = worker_data
episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
class MultiCallbacks(DefaultCallbacks):
"""MultiCallbacks allows multiple callbacks to be registered at
the same time in the config of the environment.
Example:
.. code-block:: python
'callbacks': MultiCallbacks([
MyCustomStatsCallbacks,
MyCustomVideoCallbacks,
MyCustomTraceCallbacks,
....
])
"""
IS_CALLBACK_CONTAINER = True
def __init__(self, callback_class_list):
super().__init__()
self._callback_class_list = callback_class_list
self._callback_list = []
def __call__(self, *args, **kwargs):
self._callback_list = [
callback_class() for callback_class in self._callback_class_list
]
return self
def on_trainer_init(self, *, trainer: "Trainer", **kwargs) -> None:
for callback in self._callback_list:
callback.on_trainer_init(trainer=trainer, **kwargs)
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_sub_environment_created(
worker=worker,
sub_environment=sub_environment,
env_context=env_context,
**kwargs,
)
def on_episode_start(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_episode_start(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
env_index=env_index,
**kwargs,
)
def on_episode_step(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_episode_step(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
env_index=env_index,
**kwargs,
)
def on_episode_end(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_episode_end(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
env_index=env_index,
**kwargs,
)
def on_postprocess_trajectory(
self,
*,
worker: "RolloutWorker",
episode: Episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: Dict[PolicyID, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_postprocess_trajectory(
worker=worker,
episode=episode,
agent_id=agent_id,
policy_id=policy_id,
policies=policies,
postprocessed_batch=postprocessed_batch,
original_batches=original_batches,
**kwargs,
)
def on_sample_end(
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
) -> None:
for callback in self._callback_list:
callback.on_sample_end(worker=worker, samples=samples, **kwargs)
def on_learn_on_batch(
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
) -> None:
for callback in self._callback_list:
callback.on_learn_on_batch(
policy=policy, train_batch=train_batch, result=result, **kwargs
)
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
for callback in self._callback_list:
callback.on_train_result(trainer=trainer, result=result, **kwargs)
# This Callback is used by the RE3 exploration strategy.
# See rllib/examples/re3_exploration.py for details.
class RE3UpdateCallbacks(DefaultCallbacks):
"""Update input callbacks to mutate batch with states entropy rewards."""
_step = 0
def __init__(
self,
*args,
embeds_dim: int = 128,
k_nn: int = 50,
beta: float = 0.1,
rho: float = 0.0001,
beta_schedule: str = "constant",
**kwargs,
):
self.embeds_dim = embeds_dim
self.k_nn = k_nn
self.beta = beta
self.rho = rho
self.beta_schedule = beta_schedule
self._rms = _MovingMeanStd()
super().__init__(*args, **kwargs)
def on_learn_on_batch(
self,
*,
policy: Policy,
train_batch: SampleBatch,
result: dict,
**kwargs,
):
super().on_learn_on_batch(
policy=policy, train_batch=train_batch, result=result, **kwargs
)
states_entropy = compute_states_entropy(
train_batch[SampleBatch.OBS_EMBEDS], self.embeds_dim, self.k_nn
)
states_entropy = update_beta(
self.beta_schedule, self.beta, self.rho, RE3UpdateCallbacks._step
) * np.reshape(
self._rms(states_entropy),
train_batch[SampleBatch.OBS_EMBEDS].shape[:-1],
)
train_batch[SampleBatch.REWARDS] = (
train_batch[SampleBatch.REWARDS] + states_entropy
)
if Postprocessing.ADVANTAGES in train_batch:
train_batch[Postprocessing.ADVANTAGES] = (
train_batch[Postprocessing.ADVANTAGES] + states_entropy
)
train_batch[Postprocessing.VALUE_TARGETS] = (
train_batch[Postprocessing.VALUE_TARGETS] + states_entropy
)
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
# TODO(gjoliver): Remove explicit _step tracking and pass
# trainer._iteration as a parameter to on_learn_on_batch() call.
RE3UpdateCallbacks._step = result["training_iteration"]
super().on_train_result(trainer=trainer, result=result, **kwargs)
deprecation_warning(
old="ray.rllib.agents.callbacks",
new="ray.rllib.algorithms.callbacks",
error=False,
)

View file

@ -1,3 +1,5 @@
import ray.rllib.agents.ddpg.apex as apex # noqa
import ray.rllib.agents.ddpg.td3 as td3 # noqa
from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPG as ApexDDPGTrainer
from ray.rllib.algorithms.ddpg.ddpg import (
DDPGConfig,

View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.apex_ddpg import ( # noqa
ApexDDPG as ApexDDPGTrainer,
APEX_DDPG_DEFAULT_CONFIG,
)

View file

@ -0,0 +1,6 @@
from ray.rllib.algorithms.ddpg import ( # noqa
ddpg_tf_policy,
ddpg_torch_policy,
DDPG as DDPGTrainer,
DEFAULT_CONFIG,
)

4
rllib/agents/ddpg/td3.py Normal file
View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.td3 import ( # noqa
TD3 as TD3Trainer,
TD3_DEFAULT_CONFIG,
)

View file

@ -1,3 +1,5 @@
import ray.rllib.agents.dqn.apex as apex # noqa
import ray.rllib.agents.dqn.simple_q as simple_q # noqa
from ray.rllib.algorithms.apex_dqn.apex_dqn import (
ApexDQNConfig,
ApexDQN as ApexTrainer,

4
rllib/agents/dqn/apex.py Normal file
View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.apex_dqn import ( # noqa
ApexDQN as ApexTrainer,
APEX_DEFAULT_CONFIG,
)

6
rllib/agents/dqn/dqn.py Normal file
View file

@ -0,0 +1,6 @@
from ray.rllib.algorithms.dqn import ( # noqa
dqn_tf_policy,
dqn_torch_policy,
DQN as DQNTrainer,
DEFAULT_CONFIG,
)

View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.simple_q import ( # noqa
SimpleQ as SimpleQTrainer,
DEFAULT_CONFIG,
)

View file

@ -1,4 +1,4 @@
from ray.rllib.algorithms.es.es import ESTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.es.es import ES as ESTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.es.es_tf_policy import ESTFPolicy
from ray.rllib.algorithms.es.es_torch_policy import ESTorchPolicy

View file

@ -1,158 +1,13 @@
import os
import pickle
import numpy as np
from ray.rllib.algorithms.mock import ( # noqa
_MockTrainer,
_ParameterTuningTrainer,
_SigmoidFakeData,
)
from ray.tune import result as tune_result
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import deprecation_warning
class _MockTrainer(Trainer):
"""Mock trainer for use in tests"""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return with_common_config(
{
"mock_error": False,
"persistent_error": False,
"test_variable": 1,
"num_workers": 0,
"user_checkpoint_freq": 0,
"framework": "tf",
}
)
@classmethod
def default_resource_request(cls, config):
return None
@override(Trainer)
def setup(self, config):
# Setup our config: Merge the user-supplied config (which could
# be a partial config dict with the class' default).
self.config = self.merge_trainer_configs(
self.get_default_config(), config, self._allow_unknown_configs
)
self.config["env"] = self._env_id
self.validate_config(self.config)
self.callbacks = self.config["callbacks"]()
# Add needed properties.
self.info = None
self.restored = False
@override(Trainer)
def step(self):
if (
self.config["mock_error"]
and self.iteration == 1
and (self.config["persistent_error"] or not self.restored)
):
raise Exception("mock error")
result = dict(
episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}
)
if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
if self.iteration % self.config["user_checkpoint_freq"] == 0:
result.update({tune_result.SHOULD_CHECKPOINT: True})
return result
@override(Trainer)
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
with open(path, "wb") as f:
pickle.dump(self.info, f)
return path
@override(Trainer)
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path, "rb") as f:
info = pickle.load(f)
self.info = info
self.restored = True
@staticmethod
@override(Trainer)
def _get_env_id_and_creator(env_specifier, config):
# No env to register.
return None, None
def set_info(self, info):
self.info = info
return info
def get_info(self, sess=None):
return self.info
class _SigmoidFakeData(_MockTrainer):
"""Trainer that returns sigmoid learning curves.
This can be helpful for evaluating early stopping algorithms."""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return with_common_config(
{
"width": 100,
"height": 100,
"offset": 0,
"iter_time": 10,
"iter_timesteps": 1,
"num_workers": 0,
}
)
def step(self):
i = max(0, self.iteration - self.config["offset"])
v = np.tanh(float(i) / self.config["width"])
v *= self.config["height"]
return dict(
episode_reward_mean=v,
episode_len_mean=v,
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={},
)
class _ParameterTuningTrainer(_MockTrainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return with_common_config(
{
"reward_amt": 10,
"dummy_param": 10,
"dummy_param2": 15,
"iter_time": 10,
"iter_timesteps": 1,
"num_workers": 0,
}
)
def step(self):
return dict(
episode_reward_mean=self.config["reward_amt"] * self.iteration,
episode_len_mean=self.config["reward_amt"],
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={},
)
def _trainer_import_failed(trace):
"""Returns dummy agent class for if PyTorch etc. is not installed."""
class _TrainerImportFailed(Trainer):
_name = "TrainerImportFailed"
def setup(self, config):
raise ImportError(trace)
return _TrainerImportFailed
deprecation_warning(
old="ray.rllib.agents.callbacks",
new="ray.rllib.algorithms.callbacks",
error=False,
)

View file

@ -1,3 +1,4 @@
import ray.rllib.agents.ppo.appo as appo # noqa
from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO as PPOTrainer, DEFAULT_CONFIG
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

4
rllib/agents/ppo/appo.py Normal file
View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.appo import ( # noqa
APPO as APPOTrainer,
DEFAULT_CONFIG,
)

View file

@ -0,0 +1,4 @@
from ray.rllib.algorithms.ddppo import ( # noqa
DDPPO as DDPPOTrainer,
DEFAULT_CONFIG,
)

6
rllib/agents/ppo/ppo.py Normal file
View file

@ -0,0 +1,6 @@
from ray.rllib.algorithms.ppo import ( # noqa
ppo_tf_policy,
ppo_torch_policy,
PPO as PPOTrainer,
DEFAULT_CONFIG,
)

View file

@ -1,7 +1,7 @@
import unittest
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks, MultiCallbacks
from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.utils.test_utils import framework_iterator

View file

@ -12,14 +12,14 @@ import ray.rllib.algorithms.a3c as a3c
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.algorithms.bc import BC, BCConfig
import ray.rllib.algorithms.pg as pg
from ray.rllib.agents.trainer import COMMON_CONFIG
from ray.rllib.algorithms.algorithm import COMMON_CONFIG
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.parallel_evaluation_and_training import AssertEvalCallback
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.test_utils import check, framework_iterator
class TestTrainer(unittest.TestCase):
class TestAlgorithm(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=6)
@ -35,20 +35,20 @@ class TestTrainer(unittest.TestCase):
"""
# Given:
standard_config = copy.deepcopy(COMMON_CONFIG)
trainer = pg.PG(env="CartPole-v0", config=standard_config)
algo = pg.PG(env="CartPole-v0", config=standard_config)
# When (we validate config 2 times).
# Try deprecated `Trainer._validate_config()` method (static).
trainer._validate_config(standard_config, trainer)
algo._validate_config(standard_config, algo)
config_v1 = copy.deepcopy(standard_config)
# Try new method: `Trainer.validate_config()` (non-static).
trainer.validate_config(standard_config)
algo.validate_config(standard_config)
config_v2 = copy.deepcopy(standard_config)
# Make sure nothing changed.
self.assertEqual(config_v1, config_v2)
trainer.stop()
algo.stop()
def test_add_delete_policy(self):
config = pg.DEFAULT_CONFIG.copy()
@ -84,9 +84,9 @@ class TestTrainer(unittest.TestCase):
)
for _ in framework_iterator(config):
trainer = pg.PG(config=config)
pol0 = trainer.get_policy("p0")
r = trainer.train()
algo = pg.PG(config=config)
pol0 = algo.get_policy("p0")
r = algo.train()
self.assertTrue("p0" in r["info"][LEARNER_INFO])
for i in range(1, 3):
@ -95,21 +95,21 @@ class TestTrainer(unittest.TestCase):
# Add a new policy.
pid = f"p{i}"
new_pol = trainer.add_policy(
new_pol = algo.add_policy(
pid,
trainer.get_default_policy_class(config),
algo.get_default_policy_class(config),
# Test changing the mapping fn.
policy_mapping_fn=new_mapping_fn,
# Change the list of policies to train.
policies_to_train=[f"p{i}", f"p{i-1}"],
)
pol_map = trainer.workers.local_worker().policy_map
pol_map = algo.workers.local_worker().policy_map
self.assertTrue(new_pol is not pol0)
for j in range(i + 1):
self.assertTrue(f"p{j}" in pol_map)
self.assertTrue(len(pol_map) == i + 1)
trainer.train()
checkpoint = trainer.save()
algo.train()
checkpoint = algo.save()
# Test restoring from the checkpoint (which has more policies
# than what's defined in the config dict).
@ -124,7 +124,7 @@ class TestTrainer(unittest.TestCase):
all(test.evaluation_workers.foreach_worker(_has_policy))
)
# Make sure trainer can continue training the restored policy.
# Make sure algorithm can continue training the restored policy.
pol0 = test.get_policy("p0")
test.train()
# Test creating an action with the added (and restored) policy.
@ -134,9 +134,9 @@ class TestTrainer(unittest.TestCase):
self.assertTrue(pol0.action_space.contains(a))
test.stop()
# Delete all added policies again from trainer.
# Delete all added policies again from Algorithm.
for i in range(2, 0, -1):
trainer.remove_policy(
algo.remove_policy(
f"p{i}",
# Note that the complete signature of a policy_mapping_fn
# is: `agent_id, episode, worker, **kwargs`.
@ -144,7 +144,7 @@ class TestTrainer(unittest.TestCase):
policies_to_train=[f"p{i - 1}"],
)
trainer.stop()
algo.stop()
def test_evaluation_option(self):
# Use a custom callback that asserts that we are running the
@ -164,18 +164,18 @@ class TestTrainer(unittest.TestCase):
)
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = config.build()
algo = config.build()
# Given evaluation_interval=2, r0, r2, r4 should not contain
# evaluation metrics, while r1, r3 should.
r0 = trainer.train()
r0 = algo.train()
print(r0)
r1 = trainer.train()
r1 = algo.train()
print(r1)
r2 = trainer.train()
r2 = algo.train()
print(r2)
r3 = trainer.train()
r3 = algo.train()
print(r3)
trainer.stop()
algo.stop()
self.assertFalse("evaluation" in r0)
self.assertTrue("evaluation" in r1)
@ -202,13 +202,13 @@ class TestTrainer(unittest.TestCase):
.callbacks(callbacks_class=AssertEvalCallback)
)
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = config.build()
algo = config.build()
# Should always see latest available eval results.
r0 = trainer.train()
r1 = trainer.train()
r2 = trainer.train()
r3 = trainer.train()
trainer.stop()
r0 = algo.train()
r1 = algo.train()
r2 = algo.train()
r3 = algo.train()
algo.stop()
# Eval results are not available at step 0.
# But step 3 should still have it, even though no eval was
@ -228,26 +228,26 @@ class TestTrainer(unittest.TestCase):
)
for _ in framework_iterator(frameworks=("tf", "torch")):
# Setup trainer w/o evaluation worker set and still call
# Setup algorithm w/o evaluation worker set and still call
# evaluate() -> Expect error.
trainer_wo_env_on_driver = config.build()
algo_wo_env_on_driver = config.build()
self.assertRaisesRegex(
ValueError,
"Cannot evaluate w/o an evaluation worker set",
trainer_wo_env_on_driver.evaluate,
algo_wo_env_on_driver.evaluate,
)
trainer_wo_env_on_driver.stop()
algo_wo_env_on_driver.stop()
# Try again using `create_env_on_driver=True`.
# This force-adds the env on the local-worker, so this Trainer
# can `evaluate` even though it doesn't have an evaluation-worker
# set.
config.create_env_on_local_worker = True
trainer_w_env_on_driver = config.build()
results = trainer_w_env_on_driver.evaluate()
algo_w_env_on_driver = config.build()
results = algo_w_env_on_driver.evaluate()
assert "evaluation" in results
assert "episode_reward_mean" in results["evaluation"]
trainer_w_env_on_driver.stop()
algo_w_env_on_driver.stop()
config.create_env_on_local_worker = False
def test_space_inference_from_remote_workers(self):
@ -264,20 +264,20 @@ class TestTrainer(unittest.TestCase):
# No env on driver -> expect longer build time due to space
# lookup from remote worker.
t0 = time.time()
trainer = config.build()
algo = config.build()
w_lookup = time.time() - t0
print(f"No env on learner: {w_lookup}sec")
trainer.stop()
algo.stop()
# Env on driver -> expect shorter build time due to no space
# lookup required from remote worker.
config.create_env_on_local_worker = True
t0 = time.time()
trainer = config.build()
algo = config.build()
wo_lookup = time.time() - t0
print(f"Env on learner: {wo_lookup}sec")
self.assertLess(wo_lookup, w_lookup)
trainer.stop()
algo.stop()
# Spaces given -> expect shorter build time due to no space
# lookup required from remote worker.
@ -287,11 +287,11 @@ class TestTrainer(unittest.TestCase):
action_space=env.action_space,
)
t0 = time.time()
trainer = config.build()
algo = config.build()
wo_lookup = time.time() - t0
print(f"Spaces given manually in config: {wo_lookup}sec")
self.assertLess(wo_lookup, w_lookup)
trainer.stop()
algo.stop()
def test_worker_validation_time(self):
"""Tests the time taken by `validate_workers_after_construction=True`."""
@ -302,17 +302,17 @@ class TestTrainer(unittest.TestCase):
# >> 1 workers.
config.num_workers = 1
t0 = time.time()
trainer = config.build()
algo = config.build()
total_time_1 = time.time() - t0
print(f"Validating w/ 1 worker: {total_time_1}sec")
trainer.stop()
algo.stop()
config.num_workers = 5
t0 = time.time()
trainer = config.build()
algo = config.build()
total_time_5 = time.time() - t0
print(f"Validating w/ 5 workers: {total_time_5}sec")
trainer.stop()
algo.stop()
check(total_time_5 / total_time_1, 1.0, atol=1.0)
@ -343,9 +343,9 @@ class TestTrainer(unittest.TestCase):
.offline_data(input_=[input_file])
)
bc_trainer = BC(config=offline_rl_config)
bc_trainer.train()
bc_trainer.stop()
bc = BC(config=offline_rl_config)
bc.train()
bc.stop()
if __name__ == "__main__":

View file

@ -3,7 +3,7 @@ import unittest
import ray
from ray.rllib import _register_all
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env
@ -66,7 +66,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
def _do_test_fault_ignore(self, alg: str, config: dict):
register_env("fault_env", lambda c: FaultInjectEnv(c))
agent_cls = get_trainer_class(alg)
agent_cls = get_algorithm_class(alg)
# Test fault handling
config["num_workers"] = 2
@ -82,7 +82,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
def _do_test_fault_fatal(self, alg, config):
register_env("fault_env", lambda c: FaultInjectEnv(c))
agent_cls = get_trainer_class(alg)
agent_cls = get_algorithm_class(alg)
# Test raises real error when out of workers
config["num_workers"] = 2
@ -97,7 +97,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
def _do_test_fault_fatal_but_recreate(self, alg, config):
register_env("fault_env", lambda c: FaultInjectEnv(c))
agent_cls = get_trainer_class(alg)
agent_cls = get_algorithm_class(alg)
# Test raises real error when out of workers
config["num_workers"] = 2

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -0,0 +1,8 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
__all__ = [
"Algorithm",
"AlgorithmConfig",
]
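The new top-level package exports only the two generic building blocks. A rough, hedged sketch of a user-defined skeleton built on them (the subclass name and its trivial `training_step` are illustrative only, not a working algorithm):

from ray.rllib.algorithms import Algorithm

class MyCustomAlgo(Algorithm):
    # A real implementation would also override `get_default_config()` and
    # return proper sampling/learning metrics from `training_step()`.
    def training_step(self):
        return {}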

View file

@ -2,7 +2,7 @@ import logging
import math
from typing import Optional
from ray.rllib.agents.trainer import Trainer
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.a3c.a3c import A3CConfig, A3C
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
@ -22,9 +22,9 @@ from ray.rllib.utils.metrics import (
WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
logger = logging.getLogger(__name__)
@ -39,7 +39,7 @@ class A2CConfig(A3CConfig):
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=2)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -62,7 +62,7 @@ class A2CConfig(A3CConfig):
def __init__(self):
"""Initializes a A2CConfig instance."""
super().__init__(trainer_class=A2C)
super().__init__(algo_class=A2C)
# fmt: off
# __sphinx_doc_begin__
@ -93,7 +93,7 @@ class A2CConfig(A3CConfig):
memory. To enable, set this to a value less than the train batch size.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -107,11 +107,11 @@ class A2CConfig(A3CConfig):
class A2C(A3C):
@classmethod
@override(A3C)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return A2CConfig().to_dict()
@override(A3C)
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -130,8 +130,8 @@ class A2C(A3C):
"Otherwise, microbatches of desired size won't be achievable."
)
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
@override(Algorithm)
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# Create a microbatch variable for collecting gradients on microbatches'.
@ -146,11 +146,11 @@ class A2C(A3C):
@override(A3C)
def training_step(self) -> ResultDict:
# W/o microbatching: Identical to Trainer's default implementation.
# Only difference to a default Trainer being the value function loss term
# W/o microbatching: Identical to Algorithm's default implementation. The only
# difference to the default Algorithm is the value function loss term
# and its value computations alongside each action.
if self.config["microbatch_size"] is None:
return Trainer.training_step(self)
return Algorithm.training_step(self)
# In microbatch mode, we want to compute gradients on experience
# microbatches, average a number of these microbatches, and then

View file

@ -2,8 +2,8 @@ import logging
from typing import Any, Dict, List, Optional, Type, Union
from ray.actor import ActorHandle
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.parallel_requests import (
AsyncRequestsManager,
@ -23,15 +23,15 @@ from ray.rllib.utils.metrics import (
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
PartialTrainerConfigDict,
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
)
logger = logging.getLogger(__name__)
class A3CConfig(TrainerConfig):
"""Defines a configuration class from which a A3C Trainer can be built.
class A3CConfig(AlgorithmConfig):
"""Defines a configuration class from which a A3C Algorithm can be built.
Example:
>>> from ray import tune
@ -39,7 +39,7 @@ class A3CConfig(TrainerConfig):
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -60,9 +60,9 @@ class A3CConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a A3CConfig instance."""
super().__init__(trainer_class=trainer_class or A3C)
super().__init__(algo_class=algo_class or A3C)
# fmt: off
# __sphinx_doc_begin__
@ -78,7 +78,7 @@ class A3CConfig(TrainerConfig):
self.entropy_coeff_schedule = None
self.sample_async = True
# Override some of TrainerConfig's default values with PPO-specific values.
# Override some of AlgorithmConfig's default values with PPO-specific values.
self.rollout_fragment_length = 10
self.lr = 0.0001
# Min time (in seconds) per reporting.
@ -89,7 +89,7 @@ class A3CConfig(TrainerConfig):
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -125,7 +125,7 @@ class A3CConfig(TrainerConfig):
to async buffering of batches.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -152,21 +152,21 @@ class A3CConfig(TrainerConfig):
return self
class A3C(Trainer):
class A3C(Algorithm):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return A3CConfig().to_dict()
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
@override(Algorithm)
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
self._worker_manager = AsyncRequestsManager(
self.workers.remote_workers(), max_remote_requests_in_flight_per_worker=1
)
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -175,8 +175,8 @@ class A3C(Trainer):
if config["num_workers"] <= 0 and config["sample_async"]:
raise ValueError("`num_workers` for A3C must be >= 1!")
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy
@ -256,7 +256,7 @@ class A3C(Trainer):
return learner_info_builder.finalize()
@override(Trainer)
@override(Algorithm)
def on_worker_failures(
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
):

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -11,7 +11,7 @@ import ray
from ray.actor import ActorHandle
from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
from ray.rllib.agents.trainer import Trainer
from ray.rllib.algorithms.algorithm import Algorithm
import ray.rllib.algorithms.appo.appo as appo
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.parallel_requests import (
@ -36,10 +36,10 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
PolicyID,
PolicyState,
TrainerConfigDict,
AlgorithmConfigDict,
ResultDict,
)
from ray.tune.utils.placement_groups import PlacementGroupFactory
@ -47,7 +47,7 @@ from ray.util.timer import _Timer
class AlphaStarConfig(appo.APPOConfig):
"""Defines a configuration class from which an AlphaStar Trainer can be built.
"""Defines a configuration class from which an AlphaStar Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
@ -55,7 +55,7 @@ class AlphaStarConfig(appo.APPOConfig):
... .resources(num_gpus=4)\
... .rollouts(num_rollout_workers=64)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -78,9 +78,9 @@ class AlphaStarConfig(appo.APPOConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a AlphaStarConfig instance."""
super().__init__(trainer_class=trainer_class or AlphaStar)
super().__init__(algo_class=algo_class or AlphaStar)
# fmt: off
# __sphinx_doc_begin__
@ -141,7 +141,6 @@ class AlphaStarConfig(appo.APPOConfig):
# values.
self.vtrace_drop_last_ts = False
self.min_time_s_per_iteration = 2
self._disable_execution_plan_api = True
# __sphinx_doc_end__
# fmt: on
@ -194,14 +193,14 @@ class AlphaStarConfig(appo.APPOConfig):
`ray.rllib.algorithms.alpha_star.league_builder::AlphaStarLeagueBuilder`
(used by default by this algo) as an example.
max_num_policies_to_train: The maximum number of trainable policies for this
Trainer. Each trainable policy will exist as a independent remote actor,
co-located with a replay buffer. This is besides its existence inside
the RolloutWorkers for training and evaluation. Set to None for
Algorithm. Each trainable policy will exist as an independent remote
actor, co-located with a replay buffer. This is in addition to its
existence inside the RolloutWorkers for training and evaluation. Set to None for
automatically inferring this value from the number of trainable
policies found in the `multiagent` config.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
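A hedged sketch of setting this option through the config builder; the environment name is a placeholder for a registered multi-agent env:

from ray.rllib.algorithms.alpha_star import AlphaStarConfig

config = (
    AlphaStarConfig()
    .environment("my_multi_agent_env")  # placeholder env ID
    .training(max_num_policies_to_train=8)
)
print(config.max_num_policies_to_train)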
@ -244,7 +243,7 @@ class AlphaStar(appo.APPO):
)
@classmethod
@override(Trainer)
@override(Algorithm)
def default_resource_request(cls, config):
cf = dict(cls.get_default_config(), **config)
# Construct a dummy LeagueBuilder, such that it gets the opportunity to
@ -311,11 +310,11 @@ class AlphaStar(appo.APPO):
@classmethod
@override(appo.APPO)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return AlphaStarConfig().to_dict()
@override(appo.APPO)
def validate_config(self, config: TrainerConfigDict):
def validate_config(self, config: AlgorithmConfigDict):
# Create the LeagueBuilder object, allowing it to build the multiagent
# config as well.
self.league_builder = from_config(
@ -324,7 +323,7 @@ class AlphaStar(appo.APPO):
super().validate_config(config)
@override(appo.APPO)
def setup(self, config: PartialTrainerConfigDict):
def setup(self, config: PartialAlgorithmConfigDict):
# Call super's setup to validate config, create RolloutWorkers
# (train and eval), etc..
num_gpus_saved = config["num_gpus"]
@ -403,7 +402,7 @@ class AlphaStar(appo.APPO):
ray_wait_timeout_s=self.config["timeout_s_learner_manager"],
)
@override(Trainer)
@override(Algorithm)
def step(self) -> ResultDict:
# Perform a full step (including evaluation).
result = super().step()
@ -414,7 +413,7 @@ class AlphaStar(appo.APPO):
return result
@override(Trainer)
@override(Algorithm)
def training_step(self) -> ResultDict:
# Trigger asynchronous rollouts on all RolloutWorkers.
# - Rollout results are sent directly to correct replay buffer
@ -495,7 +494,7 @@ class AlphaStar(appo.APPO):
return train_infos
@override(Trainer)
@override(Algorithm)
def add_policy(
self,
policy_id: PolicyID,
@ -503,7 +502,7 @@ class AlphaStar(appo.APPO):
*,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
config: Optional[PartialAlgorithmConfigDict] = None,
policy_state: Optional[PolicyState] = None,
**kwargs,
) -> Policy:
@ -536,7 +535,7 @@ class AlphaStar(appo.APPO):
return new_policy
@override(Trainer)
@override(Algorithm)
def cleanup(self) -> None:
super().cleanup()
# Stop all policy- and replay actors.

View file

@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional, Type
import ray
from ray.actor import ActorHandle
from ray.rllib.agents.trainer import Trainer
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict
from ray.rllib.utils.typing import PolicyID, AlgorithmConfigDict
class DistributedLearners:
@ -30,7 +30,7 @@ class DistributedLearners:
"""Initializes a DistributedLearners instance.
Args:
config: The Trainer's config dict.
config: The Algorithm's config dict.
max_num_policies_to_train: Maximum number of policies that will ever be
trainable. For these policies, we'll have to create remote
policy actors, distributed across n "learner shards".
@ -161,7 +161,7 @@ class _Shard:
# Merge the policies config overrides with the main config.
# Also, adjust `num_gpus` (to indicate an individual policy's
# num_gpus, not the total number of GPUs).
cfg = Trainer.merge_trainer_configs(
cfg = Algorithm.merge_trainer_configs(
self.config,
dict(policy_spec.config, **{"num_gpus": self.num_gpus_per_policy}),
)
@ -207,7 +207,7 @@ class _Shard:
self,
policy_id: PolicyID,
policy_spec: PolicySpec,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
):
assert self.replay_actor is None
assert len(self.policy_actors) == 0

View file

@ -5,26 +5,26 @@ import numpy as np
import re
from typing import Any, DefaultDict, Dict
from ray.rllib.agents.trainer import Trainer
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.numpy import softmax
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict, ResultDict
from ray.rllib.utils.typing import PolicyID, AlgorithmConfigDict, ResultDict
logger = logging.getLogger(__name__)
@ExperimentalAPI
class LeagueBuilder(metaclass=ABCMeta):
def __init__(self, trainer: Trainer, trainer_config: TrainerConfigDict):
def __init__(self, trainer: Algorithm, trainer_config: AlgorithmConfigDict):
"""Initializes a LeagueBuilder instance.
Args:
trainer: The Trainer object by which this league builder is used.
Trainer calls `build_league()` after each training step.
trainer: The Algorithm object by which this league builder is used.
Algorithm calls `build_league()` after each training step.
trainer_config: The (not yet validated) config dict to be
used on the Trainer. Child classes of `LeagueBuilder`
used on the Algorithm. Child classes of `LeagueBuilder`
should preprocess this to add e.g. multiagent settings
to this config.
"""
@ -67,8 +67,8 @@ class NoLeagueBuilder(LeagueBuilder):
class AlphaStarLeagueBuilder(LeagueBuilder):
def __init__(
self,
trainer: Trainer,
trainer_config: TrainerConfigDict,
trainer: Algorithm,
trainer_config: AlgorithmConfigDict,
num_random_policies: int = 2,
num_learning_league_exploiters: int = 4,
num_learning_main_exploiters: int = 4,
@ -86,11 +86,11 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
M: Main self-play (main vs main).
Args:
trainer: The Trainer object by which this league builder is used.
Trainer calls `build_league()` after each training step to reconfigure
trainer: The Algorithm object by which this league builder is used.
Algorithm calls `build_league()` after each training step to reconfigure
the league structure (e.g. to add/remove policies).
trainer_config: The (not yet validated) config dict to be
used on the Trainer. Child classes of `LeagueBuilder`
used on the Algorithm. Child classes of `LeagueBuilder`
should preprocess this to add e.g. multiagent settings
to this config.
num_random_policies: The number of random policies to add to the

View file

@ -1,9 +1,9 @@
import logging
from typing import List, Optional, Type, Union
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.replay_ops import (
SimpleReplayBuffer,
@ -37,7 +37,7 @@ from ray.rllib.utils.metrics import (
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.algorithms.alpha_zero.alpha_zero_policy import AlphaZeroPolicy
@ -63,8 +63,8 @@ class AlphaZeroDefaultCallbacks(DefaultCallbacks):
episode.user_data["initial_state"] = state
class AlphaZeroConfig(TrainerConfig):
"""Defines a configuration class from which an AlphaZero Trainer can be built.
class AlphaZeroConfig(AlgorithmConfig):
"""Defines a configuration class from which an AlphaZero Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig
@ -72,7 +72,7 @@ class AlphaZeroConfig(TrainerConfig):
... .resources(num_gpus=0)\
... .rollouts(num_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -95,9 +95,9 @@ class AlphaZeroConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=trainer_class or AlphaZero)
super().__init__(algo_class=algo_class or AlphaZero)
# fmt: off
# __sphinx_doc_begin__
@ -134,7 +134,7 @@ class AlphaZeroConfig(TrainerConfig):
"num_init_rewards": 100,
}
# Override some of TrainerConfig's default values with AlphaZero-specific
# Override some of AlgorithmConfig's default values with AlphaZero-specific
# values.
self.framework_str = "torch"
self.callbacks_class = AlphaZeroDefaultCallbacks
@ -154,7 +154,7 @@ class AlphaZeroConfig(TrainerConfig):
self.buffer_size = DEPRECATED_VALUE
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -220,7 +220,7 @@ class AlphaZeroConfig(TrainerConfig):
from: https://arxiv.org/pdf/1807.01672.pdf
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -271,7 +271,7 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
model = ModelCatalog.get_model_v2(
obs_space, action_space, action_space.n, config["model"], "torch"
)
_, env_creator = Trainer._get_env_id_and_creator(config["env"], config)
_, env_creator = Algorithm._get_env_id_and_creator(config["env"], config)
if config["ranked_rewards"]["enable"]:
# if r2 is enabled, the env is wrapped to include a rewards buffer
# used to normalize rewards
@ -302,23 +302,23 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
)
class AlphaZero(Trainer):
class AlphaZero(Algorithm):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return AlphaZeroConfig().to_dict()
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
"""Checks and updates the config based on settings."""
# Call super's validation method.
super().validate_config(config)
validate_buffer_config(config)
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
return AlphaZeroPolicyWrapperClass
@override(Trainer)
@override(Algorithm)
def training_step(self) -> ResultDict:
"""TODO:
@ -374,9 +374,9 @@ class AlphaZero(Trainer):
return train_results
@staticmethod
@override(Trainer)
@override(Algorithm)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0

View file

@ -1,13 +1,13 @@
from typing import List, Optional
from ray.actor import ActorHandle
from ray.rllib.agents import Trainer
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.util.iter import LocalIterator
@ -44,9 +44,9 @@ class ApexDDPGConfig(DDPGConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes an ApexDDPGConfig instance."""
super().__init__(trainer_class=trainer_class or ApexDDPG)
super().__init__(algo_class=algo_class or ApexDDPG)
# fmt: off
# __sphinx_doc_begin__
@ -174,11 +174,11 @@ class ApexDDPGConfig(DDPGConfig):
class ApexDDPG(DDPG, ApexDQN):
@classmethod
@override(DDPG)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return ApexDDPGConfig().to_dict()
@override(DDPG)
def setup(self, config: PartialTrainerConfigDict):
def setup(self, config: PartialAlgorithmConfigDict):
return ApexDQN.setup(self, config)
@override(DDPG)
@ -186,7 +186,7 @@ class ApexDDPG(DDPG, ApexDQN):
"""Use APEX-DQN's training iteration function."""
return ApexDQN.training_step(self)
@override(Trainer)
@override(Algorithm)
def on_worker_failures(
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
):

View file

@ -2,7 +2,7 @@
Distributed Prioritized Experience Replay (Ape-X)
=================================================
This file defines a DQN trainer using the Ape-X architecture.
This file defines a DQN algorithm using the Ape-X architecture.
Ape-X uses a single GPU learner and many CPU workers for experience collection.
Experience collection can scale to hundreds of CPU workers due to the
@ -21,7 +21,7 @@ from typing import Dict, List, Type, Optional, Callable
import ray
from ray.actor import ActorHandle
from ray.rllib import Policy
from ray.rllib.agents import Trainer
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.rollout_worker import RolloutWorker
@ -48,16 +48,17 @@ from ray.rllib.utils.metrics import (
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
TrainerConfigDict,
AlgorithmConfigDict,
ResultDict,
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
)
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util.ml_utils.dict import merge_dicts
class ApexDQNConfig(DQNConfig):
"""Defines a configuration class from which an ApexDQN Trainer can be built.
"""Defines a configuration class from which an ApexDQN Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
@ -75,9 +76,9 @@ class ApexDQNConfig(DQNConfig):
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = config.build()
>>> algo = config.build()
>>> while True:
>>> trainer.train()
>>> algo.train()
Example:
>>> from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
@ -120,9 +121,9 @@ class ApexDQNConfig(DQNConfig):
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a ApexConfig instance."""
super().__init__(trainer_class=trainer_class or ApexDQN)
super().__init__(algo_class=algo_class or ApexDQN)
# fmt: off
# __sphinx_doc_begin__
@ -350,8 +351,8 @@ class ApexDQNConfig(DQNConfig):
class ApexDQN(DQN):
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
@override(Trainable)
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# Shortcut: If execution_plan, thread and buffer will be created in there.
@ -423,7 +424,7 @@ class ApexDQN(DQN):
@classmethod
@override(DQN)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return ApexDQNConfig().to_dict()
@override(DQN)
@ -641,7 +642,7 @@ class ApexDQN(DQN):
STEPS_TRAINED_COUNTER
]
@override(Trainer)
@override(Algorithm)
def on_worker_failures(
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
):
@ -654,7 +655,7 @@ class ApexDQN(DQN):
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
@override(Trainer)
@override(Algorithm)
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
result = super()._compile_iteration_results(
step_ctx=step_ctx, iteration_results=iteration_results
@ -679,7 +680,7 @@ class ApexDQN(DQN):
return result
@classmethod
@override(Trainer)
@override(Algorithm)
def default_resource_request(cls, config):
cf = dict(cls.get_default_config(), **config)

View file

@ -2,7 +2,7 @@
Asynchronous Proximal Policy Optimization (APPO)
================================================
This file defines the distributed Trainer class for the asynchronous version
This file defines the distributed Algorithm class for the asynchronous version
of proximal policy optimization (APPO).
See `appo_[tf|torch]_policy.py` for the definition of the policy loss.
@ -24,16 +24,16 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
logger = logging.getLogger(__name__)
class APPOConfig(ImpalaConfig):
"""Defines a configuration class from which an APPO Trainer can be built.
"""Defines a configuration class from which an APPO Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.appo import APPOConfig
@ -41,7 +41,7 @@ class APPOConfig(ImpalaConfig):
... .resources(num_gpus=1)\
... .rollouts(num_rollout_workers=16)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -64,9 +64,9 @@ class APPOConfig(ImpalaConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a APPOConfig instance."""
super().__init__(trainer_class=trainer_class or APPO)
super().__init__(algo_class=algo_class or APPO)
# fmt: off
# __sphinx_doc_begin__
@ -141,7 +141,7 @@ class APPOConfig(ImpalaConfig):
`kl_coeff` automatically).
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -238,12 +238,12 @@ class APPO(Impala):
@classmethod
@override(Impala)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return APPOConfig().to_dict()
@override(Impala)
def get_default_policy_class(
self, config: PartialTrainerConfigDict
self, config: PartialAlgorithmConfigDict
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy

View file

@ -10,7 +10,7 @@ import time
from typing import Optional
import ray
from ray.rllib.agents import Trainer, TrainerConfig
from ray.rllib.algorithms import Algorithm, AlgorithmConfig
from ray.rllib.algorithms.ars.ars_tf_policy import ARSTFPolicy
from ray.rllib.algorithms.es import optimizers, utils
from ray.rllib.algorithms.es.es_tf_policy import rollout
@ -26,7 +26,7 @@ from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.torch_utils import set_torch_seed
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
logger = logging.getLogger(__name__)
@ -43,8 +43,8 @@ Result = namedtuple(
)
class ARSConfig(TrainerConfig):
"""Defines a configuration class from which an ARS Trainer can be built.
class ARSConfig(AlgorithmConfig):
"""Defines a configuration class from which an ARS Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.ars import ARSConfig
@ -52,7 +52,7 @@ class ARSConfig(TrainerConfig):
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -77,7 +77,7 @@ class ARSConfig(TrainerConfig):
def __init__(self):
"""Initializes a ARSConfig instance."""
super().__init__(trainer_class=ARS)
super().__init__(algo_class=ARS)
# fmt: off
# __sphinx_doc_begin__
@ -93,10 +93,10 @@ class ARSConfig(TrainerConfig):
self.report_length = 10
self.offset = 0
# Override some of TrainerConfig's default values with ARS-specific values.
# Override some of AlgorithmConfig's default values with ARS-specific values.
self.num_workers = 2
self.observation_filter = "MeanStdFilter"
# ARS will use Trainer's evaluation WorkerSet (if evaluation_interval > 0).
# ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0).
# Therefore, we must be careful not to use more than 1 env per eval worker
# (would break ARSPolicy's compute_single_action method) and to not do
# obs-filtering.
@ -106,7 +106,7 @@ class ARSConfig(TrainerConfig):
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -139,7 +139,7 @@ class ARSConfig(TrainerConfig):
from humanoid) during rollouts.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -312,16 +312,16 @@ def get_policy_class(config):
return policy_cls
class ARS(Trainer):
class ARS(Algorithm):
"""Large-scale implementation of Augmented Random Search in Ray."""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return ARSConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -341,7 +341,7 @@ class ARS(Trainer):
"`NoFilter` for ARS!"
)
@override(Trainer)
@override(Algorithm)
def setup(self, config):
# Setup our config: Merge the user-supplied config (which could
# be a partial config dict with the class' default).
@ -387,7 +387,7 @@ class ARS(Trainer):
self.reward_list = []
self.tstart = time.time()
@override(Trainer)
@override(Algorithm)
def get_policy(self, policy=DEFAULT_POLICY_ID):
if policy != DEFAULT_POLICY_ID:
raise ValueError(
@ -396,7 +396,7 @@ class ARS(Trainer):
)
return self.policy
@override(Trainer)
@override(Algorithm)
def step(self):
config = self.config
@ -505,20 +505,20 @@ class ARS(Trainer):
return result
@override(Trainer)
@override(Algorithm)
def cleanup(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self.workers:
w.__ray_terminate__.remote()
@override(Trainer)
@override(Algorithm)
def compute_single_action(self, observation, *args, **kwargs):
action, _, _ = self.policy.compute_actions([observation], update=True)
if kwargs.get("full_fetch"):
return action[0], [], {}
return action[0]
@override(Trainer)
@override(Algorithm)
def _sync_weights_to_workers(self, *, worker_set=None, workers=None):
# Broadcast the new policy weights to all evaluation workers.
assert worker_set is not None

View file

@ -15,7 +15,7 @@ class TestARS(unittest.TestCase):
ray.shutdown()
def test_ars_compilation(self):
"""Test whether an ARSTrainer can be built on all frameworks."""
"""Test whether an ARSAlgorithm can be built on all frameworks."""
config = ars.ARSConfig()
# Keep it simple.

View file

@ -1,19 +1,19 @@
import logging
from typing import Type, Union
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.bandit.bandit_tf_policy import BanditTFPolicy
from ray.rllib.algorithms.bandit.bandit_torch_policy import BanditTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.deprecation import Deprecated
logger = logging.getLogger(__name__)
class BanditConfig(TrainerConfig):
class BanditConfig(AlgorithmConfig):
"""Defines a contextual bandit configuration class from which
a contexual bandit algorithm can be built. Note this config is shared
between BanditLinUCB and BanditLinTS. You likely
@ -21,18 +21,18 @@ class BanditConfig(TrainerConfig):
instead.
"""
def __init__(self, trainer_class: Union["BanditLinTS", "BanditLinUCB"] = None):
super().__init__(trainer_class=trainer_class)
def __init__(self, algo_class: Union["BanditLinTS", "BanditLinUCB"] = None):
super().__init__(algo_class=algo_class)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
# Override some of AlgorithmConfig's default values with bandit-specific values.
self.framework_str = "torch"
self.num_workers = 0
self.rollout_fragment_length = 1
self.train_batch_size = 1
# Make sure, a `train()` call performs at least 100 env sampling
# timesteps, before reporting results. Not setting this (default is 0)
# would significantly slow down the Bandit Trainer.
# would significantly slow down the Bandit Algorithm.
self.min_sample_timesteps_per_iteration = 100
# __sphinx_doc_end__
# fmt: on
@ -46,16 +46,16 @@ class BanditLinTSConfig(BanditConfig):
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
>>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env=WheelBanditEnv)
>>> trainer.train()
"""
def __init__(self):
super().__init__(trainer_class=BanditLinTS)
super().__init__(algo_class=BanditLinTS)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
# Override some of AlgorithmConfig's default values with bandit-specific values.
self.exploration_config = {"type": "ThompsonSampling"}
# __sphinx_doc_end__
# fmt: on
@ -69,31 +69,31 @@ class BanditLinUCBConfig(BanditConfig):
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
>>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env=WheelBanditEnv)
>>> trainer.train()
"""
def __init__(self):
super().__init__(trainer_class=BanditLinUCB)
super().__init__(algo_class=BanditLinUCB)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with bandit-specific values.
# Override some of AlgorithmConfig's default values with bandit-specific values.
self.exploration_config = {"type": "UpperConfidenceBound"}
# __sphinx_doc_end__
# fmt: on
class BanditLinTS(Trainer):
"""Bandit Trainer using ThompsonSampling exploration."""
class BanditLinTS(Algorithm):
"""Bandit Algorithm using ThompsonSampling exploration."""
@classmethod
@override(Trainer)
@override(Algorithm)
def get_default_config(cls) -> BanditLinTSConfig:
return BanditLinTSConfig().to_dict()
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
return BanditTorchPolicy
elif config["framework"] == "tf2":
@ -102,14 +102,14 @@ class BanditLinTS(Trainer):
raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
class BanditLinUCB(Trainer):
class BanditLinUCB(Algorithm):
@classmethod
@override(Trainer)
@override(Algorithm)
def get_default_config(cls) -> BanditLinUCBConfig:
return BanditLinUCBConfig().to_dict()
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
return BanditTorchPolicy
elif config["framework"] == "tf2":

View file

@ -20,7 +20,7 @@ from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.tf_utils import make_tf_callable
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@ -71,7 +71,7 @@ def validate_spaces(
policy: Policy,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
"""Validates the observation- and action spaces used for the Policy.

View file

@ -29,15 +29,15 @@ class TestBandits(unittest.TestCase):
):
for train_batch_size in [1, 10]:
config.training(train_batch_size=train_batch_size)
trainer = config.build()
algo = config.build()
results = None
for i in range(num_iterations):
results = trainer.train()
for _ in range(num_iterations):
results = algo.train()
check_train_results(results)
print(results)
# Force good learning behavior (this is a very simple env).
self.assertTrue(results["episode_reward_mean"] == 10.0)
trainer.stop()
algo.stop()
def test_bandit_lin_ucb_compilation(self):
"""Test whether BanditLinUCB can be built on all frameworks."""

View file

@ -1,7 +1,7 @@
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
class BCConfig(MARWILConfig):
@ -38,8 +38,8 @@ class BCConfig(MARWILConfig):
... )
"""
def __init__(self, trainer_class=None):
super().__init__(trainer_class=trainer_class or BC)
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class or BC)
# fmt: off
# __sphinx_doc_begin__
@ -62,11 +62,11 @@ class BC(MARWIL):
@classmethod
@override(MARWIL)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return BCConfig().to_dict()
@override(MARWIL)
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

View file

@ -0,0 +1,609 @@
import numpy as np
import os
import tracemalloc
from typing import Dict, Optional, Tuple, TYPE_CHECKING
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import (
is_overridden,
OverrideToImplementCustomLogic,
PublicAPI,
)
from ray.rllib.utils.deprecation import deprecation_warning, Deprecated
from ray.rllib.utils.exploration.random_encoder import (
_MovingMeanStd,
compute_states_entropy,
update_beta,
)
from ray.rllib.utils.typing import AgentID, EnvType, PolicyID
from ray.tune.callback import _CallbackMeta
# Import psutil after ray so the packaged version is used.
import psutil
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.evaluation import RolloutWorker
@PublicAPI
class DefaultCallbacks(metaclass=_CallbackMeta):
"""Abstract base class for RLlib callbacks (similar to Keras callbacks).
These callbacks can be used for custom metrics and custom postprocessing.
By default, all of these callbacks are no-ops. To configure custom training
callbacks, subclass DefaultCallbacks and then set
{"callbacks": YourCallbacksClass} in the algo config.
"""
def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
if legacy_callbacks_dict:
deprecation_warning(
"callbacks dict interface",
"a class extending rllib.algorithms.callbacks.DefaultCallbacks",
)
self.legacy_callbacks = legacy_callbacks_dict or {}
if is_overridden(self.on_trainer_init):
deprecation_warning(
old="on_trainer_init(trainer, **kwargs)",
new="on_algorithm_init(algorithm, **kwargs)",
error=True,
)
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
**kwargs,
) -> None:
"""Callback run when a new sub-environment has been created.
This method gets called after each sub-environment (usually a
gym.Env) has been created, validated (RLlib built-in validation
+ possible custom validation function implemented by overriding
`Algorithm.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
Args:
worker: Reference to the current rollout worker.
sub_environment: The sub-environment instance that has been
created. This is usually a gym.Env object.
env_context: The `EnvContext` object that has been passed to
the env's constructor.
kwargs: Forward compatibility placeholder.
"""
pass
def on_algorithm_init(
self,
*,
algorithm: "Algorithm",
**kwargs,
) -> None:
"""Callback run when a new algorithm instance has finished setup.
This method gets called at the end of Algorithm.setup() after all
the initialization is done, and before actually training starts.
Args:
algorithm: Reference to the Algorithm instance.
kwargs: Forward compatibility placeholder.
"""
pass
def on_episode_start(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
**kwargs,
) -> None:
"""Callback run on the rollout worker before each episode starts.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode: Episode object which contains the episode's
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_start"):
self.legacy_callbacks["on_episode_start"](
{
"env": base_env,
"policy": policies,
"episode": episode,
}
)
def on_episode_step(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: Episode,
**kwargs,
) -> None:
"""Runs on each episode step.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects.
In single agent mode there will only be a single
"default_policy".
episode: Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_step"):
self.legacy_callbacks["on_episode_step"](
{"env": base_env, "episode": episode}
)
def on_episode_end(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
**kwargs,
) -> None:
"""Runs when an episode is done.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy
objects. In single agent mode there will only be a single
"default_policy".
episode: Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_end"):
self.legacy_callbacks["on_episode_end"](
{
"env": base_env,
"policy": policies,
"episode": episode,
}
)
def on_postprocess_trajectory(
self,
*,
worker: "RolloutWorker",
episode: Episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: Dict[PolicyID, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
**kwargs,
) -> None:
"""Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy,
including looking at the trajectory data of other agents in multi-agent
settings.
Args:
worker: Reference to the current rollout worker.
episode: Episode object.
agent_id: Id of the current agent.
policy_id: Id of the current policy for the agent.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default_policy".
postprocessed_batch: The postprocessed sample batch
for this agent. You can mutate this object to apply your own
trajectory postprocessing.
original_batches: Mapping of agents to their unpostprocessed
trajectory data. You should not mutate this object.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_postprocess_traj"):
self.legacy_callbacks["on_postprocess_traj"](
{
"episode": episode,
"agent_id": agent_id,
"pre_batch": original_batches[agent_id],
"post_batch": postprocessed_batch,
"all_pre_batches": original_batches,
}
)
def on_sample_end(
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
) -> None:
"""Called at the end of RolloutWorker.sample().
Args:
worker: Reference to the current rollout worker.
samples: Batch to be returned. You can mutate this
object to modify the samples generated.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_sample_end"):
self.legacy_callbacks["on_sample_end"](
{
"worker": worker,
"samples": samples,
}
)
def on_learn_on_batch(
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
) -> None:
"""Called at the beginning of Policy.learn_on_batch().
Note: This is called before 0-padding via
`pad_batch_to_sequences_of_same_size`.
Also note, SampleBatch.INFOS column will not be available on
train_batch within this callback if framework is tf1, due to
the fact that tf1 static graph would mistake it as part of the
input dict if present.
It is available though, for tf2 and torch frameworks.
Args:
policy: Reference to the current Policy object.
train_batch: SampleBatch to be trained on. You can
mutate this object to modify the samples generated.
result: A results dict to add custom metrics to.
kwargs: Forward compatibility placeholder.
"""
pass
def on_train_result(
self,
*,
algorithm: Optional["Algorithm"] = None,
result: dict,
trainer=None,
**kwargs,
) -> None:
"""Called at the end of Trainable.train().
Args:
            algorithm: Current Algorithm instance.
            result: Dict of results returned from Algorithm.train() call.
You can mutate this object to add additional metrics.
kwargs: Forward compatibility placeholder.
"""
if trainer is not None:
algorithm = trainer
if self.legacy_callbacks.get("on_train_result"):
self.legacy_callbacks["on_train_result"](
{
"trainer": algorithm,
"result": result,
}
)
@OverrideToImplementCustomLogic
@Deprecated(error=True)
def on_trainer_init(self, *args, **kwargs):
raise DeprecationWarning
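

# Illustrative sketch (hedged): a minimal custom callback built on the
# DefaultCallbacks API documented above. `episode.user_data` holds temporary
# per-episode state, and anything written to `episode.custom_metrics` shows up
# in the training results. The class and metric names here are hypothetical.
class ExampleEpisodeStatsCallbacks(DefaultCallbacks):
    def on_episode_start(
        self, *, worker, base_env, policies, episode, env_index=None, **kwargs
    ):
        # Initialize temporary, per-episode state.
        episode.user_data["num_steps"] = 0

    def on_episode_step(
        self, *, worker, base_env, policies=None, episode, env_index=None, **kwargs
    ):
        # Update the temporary state on every env step.
        episode.user_data["num_steps"] += 1

    def on_episode_end(
        self, *, worker, base_env, policies, episode, env_index=None, **kwargs
    ):
        # Anything in `custom_metrics` is aggregated into the training results.
        episode.custom_metrics["episode_num_steps"] = episode.user_data["num_steps"]
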
class MemoryTrackingCallbacks(DefaultCallbacks):
"""MemoryTrackingCallbacks can be used to trace and track memory usage
in rollout workers.
    MemoryTrackingCallbacks uses tracemalloc and psutil to track Python
    allocations during rollouts, in training or evaluation.
    The tracking data is logged to an episode's custom_metrics and can
    therefore be viewed in TensorBoard (or in WandB, etc.).
    Add the MemoryTrackingCallbacks class to the Tune config,
    e.g. {..., "callbacks": MemoryTrackingCallbacks, ...}.
Note:
This class is meant for debugging and should not be used
in production code as tracemalloc incurs
a significant slowdown in execution speed.
"""
def __init__(self):
super().__init__()
# Will track the top 10 lines where memory is allocated
tracemalloc.start(10)
def on_episode_end(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics("lineno")
for stat in top_stats[:10]:
count = stat.count
size = stat.size
trace = str(stat.traceback)
episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
episode.custom_metrics[f"tracemalloc/{trace}/count"] = count
process = psutil.Process(os.getpid())
worker_rss = process.memory_info().rss
worker_data = process.memory_info().data
worker_vms = process.memory_info().vms
episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss
episode.custom_metrics["tracemalloc/worker/data"] = worker_data
episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
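

# Illustrative usage sketch (hedged): enable MemoryTrackingCallbacks by passing
# the class under the "callbacks" key of the algorithm/Tune config, as the
# docstring above suggests. The Tune call and env name are placeholders.
#
#     from ray import tune
#     tune.run(
#         "PPO",
#         config={
#             "env": "CartPole-v0",
#             "callbacks": MemoryTrackingCallbacks,
#         },
#     )
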
class MultiCallbacks(DefaultCallbacks):
"""MultiCallbacks allows multiple callbacks to be registered at
the same time in the config of the environment.
Example:
.. code-block:: python
'callbacks': MultiCallbacks([
MyCustomStatsCallbacks,
MyCustomVideoCallbacks,
MyCustomTraceCallbacks,
....
])
"""
IS_CALLBACK_CONTAINER = True
def __init__(self, callback_class_list):
super().__init__()
self._callback_class_list = callback_class_list
self._callback_list = []
def __call__(self, *args, **kwargs):
self._callback_list = [
callback_class() for callback_class in self._callback_class_list
]
return self
def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
for callback in self._callback_list:
callback.on_algorithm_init(algorithm=algorithm, **kwargs)
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_sub_environment_created(
worker=worker,
sub_environment=sub_environment,
env_context=env_context,
**kwargs,
)
def on_episode_start(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_episode_start(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
env_index=env_index,
**kwargs,
)
def on_episode_step(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_episode_step(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
env_index=env_index,
**kwargs,
)
def on_episode_end(
self,
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Episode,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_episode_end(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode,
env_index=env_index,
**kwargs,
)
def on_postprocess_trajectory(
self,
*,
worker: "RolloutWorker",
episode: Episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: Dict[PolicyID, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
**kwargs,
) -> None:
for callback in self._callback_list:
callback.on_postprocess_trajectory(
worker=worker,
episode=episode,
agent_id=agent_id,
policy_id=policy_id,
policies=policies,
postprocessed_batch=postprocessed_batch,
original_batches=original_batches,
**kwargs,
)
def on_sample_end(
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
) -> None:
for callback in self._callback_list:
callback.on_sample_end(worker=worker, samples=samples, **kwargs)
def on_learn_on_batch(
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
) -> None:
for callback in self._callback_list:
callback.on_learn_on_batch(
policy=policy, train_batch=train_batch, result=result, **kwargs
)
def on_train_result(
self, *, algorithm=None, result: dict, trainer=None, **kwargs
) -> None:
if trainer is not None:
algorithm = trainer
for callback in self._callback_list:
# TODO: Remove `trainer` arg at some point to fully deprecate the old term.
callback.on_train_result(
algorithm=algorithm, result=result, trainer=algorithm, **kwargs
)
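

# Illustrative usage sketch (hedged): a MultiCallbacks *instance* can be passed
# where a callbacks class is normally expected, because `__call__()` above
# constructs the registered callback classes and returns the container itself.
# The combined callback classes and the Tune call are placeholders.
#
#     from ray import tune
#     tune.run(
#         "PPO",
#         config={
#             "env": "CartPole-v0",
#             "callbacks": MultiCallbacks([
#                 MemoryTrackingCallbacks,
#                 ExampleEpisodeStatsCallbacks,  # hypothetical, see sketch above
#             ]),
#         },
#     )
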
# This Callback is used by the RE3 exploration strategy.
# See rllib/examples/re3_exploration.py for details.
class RE3UpdateCallbacks(DefaultCallbacks):
"""Update input callbacks to mutate batch with states entropy rewards."""
_step = 0
def __init__(
self,
*args,
embeds_dim: int = 128,
k_nn: int = 50,
beta: float = 0.1,
rho: float = 0.0001,
beta_schedule: str = "constant",
**kwargs,
):
self.embeds_dim = embeds_dim
self.k_nn = k_nn
self.beta = beta
self.rho = rho
self.beta_schedule = beta_schedule
self._rms = _MovingMeanStd()
super().__init__(*args, **kwargs)
def on_learn_on_batch(
self,
*,
policy: Policy,
train_batch: SampleBatch,
result: dict,
**kwargs,
):
super().on_learn_on_batch(
policy=policy, train_batch=train_batch, result=result, **kwargs
)
states_entropy = compute_states_entropy(
train_batch[SampleBatch.OBS_EMBEDS], self.embeds_dim, self.k_nn
)
states_entropy = update_beta(
self.beta_schedule, self.beta, self.rho, RE3UpdateCallbacks._step
) * np.reshape(
self._rms(states_entropy),
train_batch[SampleBatch.OBS_EMBEDS].shape[:-1],
)
train_batch[SampleBatch.REWARDS] = (
train_batch[SampleBatch.REWARDS] + states_entropy
)
if Postprocessing.ADVANTAGES in train_batch:
train_batch[Postprocessing.ADVANTAGES] = (
train_batch[Postprocessing.ADVANTAGES] + states_entropy
)
train_batch[Postprocessing.VALUE_TARGETS] = (
train_batch[Postprocessing.VALUE_TARGETS] + states_entropy
)
def on_train_result(
self, *, result: dict, algorithm=None, trainer=None, **kwargs
) -> None:
if trainer is not None:
algorithm = trainer
# TODO(gjoliver): Remove explicit _step tracking and pass
# trainer._iteration as a parameter to on_learn_on_batch() call.
RE3UpdateCallbacks._step = result["training_iteration"]
# TODO: Remove `trainer` arg at some point to fully deprecate the old term.
super().on_train_result(
algorithm=algorithm, result=result, trainer=algorithm, **kwargs
)
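

# Illustrative usage sketch (hedged): RE3UpdateCallbacks is meant to be mixed
# into an algorithm's existing callbacks so that the intrinsic state-entropy
# reward is added in `on_learn_on_batch()`. The exploration_config contents
# below are assumptions; see rllib/examples/re3_exploration.py for the
# authoritative setup.
#
#     class RE3Callbacks(RE3UpdateCallbacks, DefaultCallbacks):
#         pass
#
#     config = {
#         "env": "Pendulum-v1",
#         "callbacks": RE3Callbacks,
#         "exploration_config": {"type": "RE3"},  # assumed minimal form
#     }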

View file

@ -32,7 +32,7 @@ from ray.rllib.utils.metrics import (
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@ -52,8 +52,8 @@ class CQLConfig(SACConfig):
>>> trainer.train()
"""
def __init__(self, trainer_class=None):
super().__init__(trainer_class=trainer_class or CQL)
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class or CQL)
# fmt: off
# __sphinx_doc_begin__
@ -99,7 +99,7 @@ class CQLConfig(SACConfig):
min_q_weight: in Q weight multiplier.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -165,11 +165,11 @@ class CQL(SAC):
@classmethod
@override(SAC)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return CQLConfig().to_dict()
@override(SAC)
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
# First check, whether old `timesteps_per_iteration` is used. If so
# convert right away as for CQL, we must measure in training timesteps,
# never sampling timesteps (CQL does not sample).
@ -206,7 +206,7 @@ class CQL(SAC):
try_import_tfp(error=True)
@override(SAC)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
return CQLTorchPolicy
else:

View file

@ -35,7 +35,7 @@ from ray.rllib.utils.typing import (
LocalOptimizer,
ModelGradients,
TensorType,
TrainerConfigDict,
AlgorithmConfigDict,
)
tf1, tf, tfv = try_import_tf()
@ -314,7 +314,7 @@ def setup_early_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
"""Call mixin classes' constructors before Policy's initialization.

View file

@ -30,7 +30,7 @@ from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import LocalOptimizer, TensorType, TrainerConfigDict
from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict
from ray.rllib.utils.torch_utils import (
apply_grad_clipping,
convert_to_torch_tensor,
@ -340,7 +340,7 @@ def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]
def cql_optimizer_fn(
policy: Policy, config: TrainerConfigDict
policy: Policy, config: AlgorithmConfigDict
) -> Tuple[LocalOptimizer]:
policy.cur_iter = 0
opt_list = optimizer_fn(policy, config)
@ -365,7 +365,7 @@ def cql_setup_late_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
setup_late_mixins(policy, obs_space, action_space, config)
if config["lagrangian"]:

View file

@ -3,7 +3,7 @@ import numpy as np
from typing import Type, List, Optional
import tree
from ray.rllib.agents.trainer import Trainer, TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
train_one_step,
@ -18,17 +18,17 @@ from ray.rllib.utils.metrics import (
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
logger = logging.getLogger(__name__)
class CRRConfig(TrainerConfig):
def __init__(self, trainer_class=None):
super().__init__(trainer_class=trainer_class or CRR)
class CRRConfig(AlgorithmConfig):
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class or CRR)
# fmt: off
# __sphinx_doc_begin__
@ -142,13 +142,13 @@ class CRRConfig(TrainerConfig):
NUM_GRADIENT_UPDATES = "num_grad_updates"
class CRR(Trainer):
class CRR(Algorithm):
# TODO: we have a circular dependency for get
# default config. config -> Trainer -> config
# defining Config class in the same file for now as a workaround.
def setup(self, config: PartialTrainerConfigDict):
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# initial setup for handling the offline data in form of a replay buffer
# Add the entire dataset to Replay Buffer (global variable)
@ -194,12 +194,12 @@ class CRR(Trainer):
self._counters[NUM_TARGET_UPDATES] = 0
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return CRRConfig().to_dict()
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.crr.torch import CRRTorchPolicy
@ -207,7 +207,7 @@ class CRR(Trainer):
else:
raise ValueError("Non-torch frameworks are not supported yet!")
@override(Trainer)
@override(Algorithm)
def training_step(self) -> ResultDict:
total_transitions = len(self.local_replay_buffer)

View file

@ -10,7 +10,7 @@ from typing import (
Union,
)
from ray.rllib.agents import TrainerConfig
from ray.rllib.algorithms import AlgorithmConfig
from ray.rllib.algorithms.crr.torch import CRRModel
from ray.rllib.algorithms.ddpg.noop_model import TorchNoopModel
from ray.rllib.algorithms.sac.sac_torch_policy import TargetNetworkMixin
@ -384,6 +384,6 @@ if __name__ == "__main__":
obs_space = gym.spaces.Box(np.array((-1, -1)), np.array((1, 1)))
act_space = gym.spaces.Box(np.array((-1, -1)), np.array((1, 1)))
config = TrainerConfig().framework(framework="torch").to_dict()
config = AlgorithmConfig().framework(framework="torch").to_dict()
print(config["framework"])
CRRTorchPolicy(obs_space, act_space, config=config)

View file

@ -1,12 +1,12 @@
import logging
from typing import List, Optional, Type
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.simple_q.simple_q import SimpleQ, SimpleQConfig
from ray.rllib.algorithms.ddpg.ddpg_tf_policy import DDPGTFPolicy
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import Deprecated
@ -44,9 +44,9 @@ class DDPGConfig(SimpleQConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a DDPGConfig instance."""
super().__init__(trainer_class=trainer_class or DDPG)
super().__init__(algo_class=algo_class or DDPG)
# fmt: off
# __sphinx_doc_begin__
@ -129,7 +129,7 @@ class DDPGConfig(SimpleQConfig):
# Deprecated.
self.worker_side_prioritization = DEPRECATED_VALUE
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -256,12 +256,12 @@ class DDPGConfig(SimpleQConfig):
class DDPG(SimpleQ):
@classmethod
@override(SimpleQ)
# TODO make this return a TrainerConfig
def get_default_config(cls) -> TrainerConfigDict:
    # TODO make this return an AlgorithmConfig
def get_default_config(cls) -> AlgorithmConfigDict:
return DDPGConfig().to_dict()
@override(SimpleQ)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.ddpg.ddpg_torch_policy import DDPGTorchPolicy
@ -270,7 +270,7 @@ class DDPG(SimpleQ):
return DDPGTFPolicy
@override(SimpleQ)
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

View file

@ -29,7 +29,7 @@ from ray.rllib.utils.framework import get_variable, try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable
from ray.rllib.utils.typing import (
TrainerConfigDict,
AlgorithmConfigDict,
TensorType,
LocalOptimizer,
ModelGradients,
@ -45,7 +45,7 @@ def build_ddpg_models(
policy: Policy,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> ModelV2:
if policy.config["use_state_preprocessor"]:
default_model = None # catalog decides
@ -379,7 +379,7 @@ def setup_early_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
"""Call mixin classes' constructors before Policy's initialization.
@ -425,13 +425,13 @@ def setup_mid_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
class TargetNetworkMixin:
def __init__(self, config: TrainerConfigDict):
def __init__(self, config: AlgorithmConfigDict):
@make_tf_callable(self.get_session())
def update_target_fn(tau):
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
@ -466,7 +466,7 @@ def setup_late_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
TargetNetworkMixin.__init__(policy, config)
@ -475,7 +475,7 @@ def validate_spaces(
policy: Policy,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(

View file

@ -28,7 +28,7 @@ from ray.rllib.utils.torch_utils import (
l2_loss,
)
from ray.rllib.utils.typing import (
TrainerConfigDict,
AlgorithmConfigDict,
TensorType,
LocalOptimizer,
GradInfoDict,
@ -43,7 +43,7 @@ def build_ddpg_models_and_action_dist(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> Tuple[ModelV2, ActionDistribution]:
model = build_ddpg_models(policy, obs_space, action_space, config)
@ -202,7 +202,7 @@ def ddpg_actor_critic_loss(
def make_ddpg_optimizers(
policy: Policy, config: TrainerConfigDict
policy: Policy, config: AlgorithmConfigDict
) -> Tuple[LocalOptimizer]:
"""Create separate optimizers for actor & critic losses."""
@ -248,7 +248,7 @@ def before_init_fn(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
# Create global step for counting the number of update operations.
policy.global_step = 0
@ -285,7 +285,7 @@ def setup_late_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
TargetNetworkMixin.__init__(policy)

View file

@ -47,21 +47,20 @@ class TestDDPG(unittest.TestCase):
# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
""""""
trainer = config.build(env="Pendulum-v1")
algo = config.build(env="Pendulum-v1")
for i in range(num_iterations):
results = trainer.train()
results = algo.train()
check_train_results(results)
print(results)
check_compute_single_action(trainer)
check_compute_single_action(algo)
# Ensure apply_gradient_fn is being called and updating global_step
pol = trainer.get_policy()
pol = algo.get_policy()
if config.framework_str == "tf":
a = pol.get_session().run(pol.global_step)
else:
a = pol.global_step
check(a, 500)
trainer.stop()
algo.stop()
def test_ddpg_exploration_and_with_random_prerun(self):
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""
@ -74,21 +73,21 @@ class TestDDPG(unittest.TestCase):
config = ddpg.DDPGConfig().rollouts(num_rollout_workers=0)
config.seed = 42
# Default OUNoise setup.
trainer = config.build(env="Pendulum-v1")
algo = config.build(env="Pendulum-v1")
# Setting explore=False should always return the same action.
a_ = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, 1)
a_ = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, 1)
for i in range(50):
a = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, i + 2)
a = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, i + 2)
check(a, a_)
# explore=None (default: explore) should return different actions.
actions = []
for i in range(50):
actions.append(trainer.compute_single_action(obs))
check(trainer.get_policy().global_timestep, i + 52)
actions.append(algo.compute_single_action(obs))
check(algo.get_policy().global_timestep, i + 52)
check(np.std(actions), 0.0, false=True)
trainer.stop()
algo.stop()
# Check randomness at beginning.
config.exploration_config.update(
@ -102,30 +101,30 @@ class TestDDPG(unittest.TestCase):
}
)
trainer = ddpg.DDPG(config=config, env="Pendulum-v1")
algo = ddpg.DDPG(config=config, env="Pendulum-v1")
# ts=0 (get a deterministic action as per explore=False).
deterministic_action = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, 1)
deterministic_action = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, 1)
# ts=1-49 (in random window).
random_a = []
for i in range(1, 50):
random_a.append(trainer.compute_single_action(obs, explore=True))
check(trainer.get_policy().global_timestep, i + 1)
random_a.append(algo.compute_single_action(obs, explore=True))
check(algo.get_policy().global_timestep, i + 1)
check(random_a[-1], deterministic_action, false=True)
self.assertTrue(np.std(random_a) > 0.5)
# ts > 50 (a=deterministic_action + scale * N[0,1])
for i in range(50):
a = trainer.compute_single_action(obs, explore=True)
check(trainer.get_policy().global_timestep, i + 51)
a = algo.compute_single_action(obs, explore=True)
check(algo.get_policy().global_timestep, i + 51)
check(a, deterministic_action, rtol=0.1)
# ts >> 50 (BUT: explore=False -> expect deterministic action).
for i in range(50):
a = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, i + 101)
a = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, i + 101)
check(a, deterministic_action)
trainer.stop()
algo.stop()
def test_ddpg_loss_function(self):
"""Tests DDPG loss function results across all frameworks."""
@ -147,7 +146,7 @@ class TestDDPG(unittest.TestCase):
# Use very simple nets.
config.actor_hiddens = [10]
config.critic_hiddens = [10]
# Make sure, timing differences do not affect trainer.train().
# Make sure, timing differences do not affect Algorithm.train().
config.min_time_s_per_iteration = 0
config.min_sample_timesteps_per_iteration = 100
@ -222,9 +221,9 @@ class TestDDPG(unittest.TestCase):
for fw, sess in framework_iterator(
config, frameworks=("tf", "torch"), session=True
):
# Generate Trainer and get its default Policy object.
trainer = config.build(env=env)
policy = trainer.get_policy()
# Generate Algorithm and get its default Policy object.
algo = config.build(env=env)
policy = algo.get_policy()
p_sess = None
if sess:
p_sess = policy.get_session()
@ -367,9 +366,9 @@ class TestDDPG(unittest.TestCase):
tf_inputs.append(in_)
# Set a fake-batch to use
# (instead of sampling from replay buffer).
buf = trainer.local_replay_buffer
buf = algo.local_replay_buffer
patch_buffer_with_fake_sampling_method(buf, in_)
trainer.train()
algo.train()
updated_weights = policy.get_weights()
# Net must have changed.
if tf_updated_weights:
@ -388,9 +387,9 @@ class TestDDPG(unittest.TestCase):
in_ = tf_inputs[update_iteration]
# Set a fake-batch to use
# (instead of sampling from replay buffer).
buf = trainer.local_replay_buffer
buf = algo.local_replay_buffer
patch_buffer_with_fake_sampling_method(buf, in_)
trainer.train()
algo.train()
# Compare updated model and target weights.
for tf_key in tf_weights.keys():
tf_var = tf_weights[tf_key]
@ -407,7 +406,7 @@ class TestDDPG(unittest.TestCase):
else:
check(tf_var, torch_var, atol=0.1)
trainer.stop()
algo.stop()
def _get_batch_helper(self, obs_size, actions, batch_size):
return SampleBatch(

View file

@ -42,9 +42,9 @@ from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import (
EnvType,
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
from ray.tune.logger import Logger
@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
class DDPPOConfig(PPOConfig):
"""Defines a configuration class from which a DDPPO Trainer can be built.
"""Defines a configuration class from which a DDPPO Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.ddppo import DDPPOConfig
@ -60,7 +60,7 @@ class DDPPOConfig(PPOConfig):
... .resources(num_gpus=1)\
... .rollouts(num_workers=10)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
        >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -83,9 +83,9 @@ class DDPPOConfig(PPOConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a DDPPOConfig instance."""
super().__init__(trainer_class=trainer_class or DDPPO)
super().__init__(algo_class=algo_class or DDPPO)
# fmt: off
# __sphinx_doc_begin__
@ -93,7 +93,7 @@ class DDPPOConfig(PPOConfig):
self.keep_local_weights_in_sync = True
self.torch_distributed_backend = "gloo"
# Override some of PPO/Trainer's default values with DDPPO-specific values.
# Override some of PPO/Algorithm's default values with DDPPO-specific values.
# During the sampling phase, each rollout worker will collect a batch
# `rollout_fragment_length * num_envs_per_worker` steps in size.
self.rollout_fragment_length = 100
@ -144,7 +144,7 @@ class DDPPOConfig(PPOConfig):
distributed.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -160,7 +160,7 @@ class DDPPOConfig(PPOConfig):
class DDPPO(PPO):
def __init__(
self,
config: Optional[PartialTrainerConfigDict] = None,
config: Optional[PartialAlgorithmConfigDict] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,
@ -196,12 +196,12 @@ class DDPPO(PPO):
@classmethod
@override(PPO)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return DDPPOConfig().to_dict()
@override(PPO)
def validate_config(self, config):
"""Validates the Trainer's config dict.
"""Validates the Algorithm's config dict.
Args:
config: The Trainer's config to check.
@ -251,7 +251,7 @@ class DDPPO(PPO):
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")
@override(PPO)
def setup(self, config: PartialTrainerConfigDict):
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# Initialize torch process group for

View file

@ -2,7 +2,7 @@
Deep Q-Networks (DQN, Rainbow, Parametric DQN)
==============================================
This file defines the distributed Trainer class for the Deep Q-Networks
This file defines the distributed Algorithm class for the Deep Q-Networks
algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
Detailed documentation:
@ -32,7 +32,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED,
@ -53,7 +53,7 @@ logger = logging.getLogger(__name__)
class DQNConfig(SimpleQConfig):
"""Defines a configuration class from which a DQN Trainer can be built.
"""Defines a configuration class from which a DQN Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.dqn.dqn import DQNConfig
@ -115,9 +115,9 @@ class DQNConfig(SimpleQConfig):
>>> .exploration(exploration_config=explore_config)
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a DQNConfig instance."""
super().__init__(trainer_class=trainer_class or DQN)
super().__init__(algo_class=algo_class or DQN)
# DQN specific config settings.
# fmt: off
@ -248,7 +248,7 @@ class DQNConfig(SimpleQConfig):
zero, there is still a chance of drawing the sample.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -281,7 +281,7 @@ class DQNConfig(SimpleQConfig):
return self
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
def calculate_rr_weights(config: AlgorithmConfigDict) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
if not config["training_intensity"]:
return [1, 1]
@ -311,11 +311,11 @@ def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
class DQN(SimpleQ):
@classmethod
@override(SimpleQ)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return DEFAULT_CONFIG
@override(SimpleQ)
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -325,7 +325,7 @@ class DQN(SimpleQ):
@override(SimpleQ)
def get_default_policy_class(
self, config: TrainerConfigDict
self, config: AlgorithmConfigDict
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
return DQNTorchPolicy

View file

@ -27,7 +27,7 @@ from ray.rllib.utils.tf_utils import (
minimize_and_clip,
reduce_mean_ignore_inf,
)
from ray.rllib.utils.typing import ModelGradients, TensorType, TrainerConfigDict
from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
tf1, tf, tfv = try_import_tf()
@ -152,7 +152,7 @@ def build_q_model(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> ModelV2:
"""Build q_model and target_model for DQN
@ -160,7 +160,7 @@ def build_q_model(
policy: The Policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict):
config (AlgorithmConfigDict):
Returns:
ModelV2: The Model for the Policy to use.
@ -328,7 +328,7 @@ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> Tensor
def adam_optimizer(
policy: Policy, config: TrainerConfigDict
policy: Policy, config: AlgorithmConfigDict
) -> "tf.keras.optimizers.Optimizer":
if policy.config["framework"] in ["tf2", "tfe"]:
return tf.keras.optimizers.Adam(
@ -372,7 +372,7 @@ def setup_late_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)

View file

@ -35,7 +35,7 @@ from ray.rllib.utils.torch_utils import (
reduce_mean_ignore_inf,
softmax_cross_entropy_with_logits,
)
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
torch, nn = try_import_torch()
F = None
@ -145,7 +145,7 @@ def build_q_model_and_distribution(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> Tuple[ModelV2, TorchDistributionWrapper]:
"""Build q_model and target_model for DQN
@ -153,7 +153,7 @@ def build_q_model_and_distribution(
policy: The policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict):
config (AlgorithmConfigDict):
Returns:
(q_model, TorchCategorical)
@ -354,7 +354,7 @@ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> Tensor
def adam_optimizer(
policy: Policy, config: TrainerConfigDict
policy: Policy, config: AlgorithmConfigDict
) -> "torch.optim.Optimizer":
# By this time, the models have been moved to the GPU - if any - and we
@ -384,7 +384,7 @@ def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
def setup_early_mixins(
policy: Policy, obs_space, action_space, config: TrainerConfigDict
policy: Policy, obs_space, action_space, config: AlgorithmConfigDict
) -> None:
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
@ -393,7 +393,7 @@ def before_loss_init(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
config: AlgorithmConfigDict,
) -> None:
ComputeTDErrorMixin.__init__(policy)
TargetNetworkMixin.__init__(policy)

View file

@ -3,9 +3,9 @@ import numpy as np
import random
from typing import Optional
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
@ -18,17 +18,17 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
SampleBatchType,
TrainerConfigDict,
AlgorithmConfigDict,
ResultDict,
)
logger = logging.getLogger(__name__)
class DreamerConfig(TrainerConfig):
"""Defines a configuration class from which a Dreamer Trainer can be built.
class DreamerConfig(AlgorithmConfig):
"""Defines a configuration class from which a Dreamer Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.dreamer import DreamerConfig
@ -36,7 +36,7 @@ class DreamerConfig(TrainerConfig):
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
        >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -61,7 +61,7 @@ class DreamerConfig(TrainerConfig):
def __init__(self):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=Dreamer)
super().__init__(algo_class=Dreamer)
# fmt: off
# __sphinx_doc_begin__
@ -92,7 +92,7 @@ class DreamerConfig(TrainerConfig):
"action_init_std": 5.0,
}
# Override some of TrainerConfig's default values with PPO-specific values.
# Override some of AlgorithmConfig's default values with PPO-specific values.
# .rollouts()
self.num_workers = 0
self.num_envs_per_worker = 1
@ -112,7 +112,7 @@ class DreamerConfig(TrainerConfig):
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -293,14 +293,14 @@ class DreamerIteration:
return _postprocess_gif(gif=gif)
class Dreamer(Trainer):
class Dreamer(Algorithm):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return DreamerConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -323,12 +323,12 @@ class Dreamer(Trainer):
if config["action_repeat"] > 1:
config["horizon"] = config["horizon"] / config["action_repeat"]
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict):
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict):
return DreamerTorchPolicy
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
@override(Algorithm)
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
# `training_iteration` implementation: Setup buffer in `setup`, not
# in `execution_plan` (deprecated).
@ -344,7 +344,7 @@ class Dreamer(Trainer):
self.local_replay_buffer.add(samples)
@staticmethod
@override(Trainer)
@override(Algorithm)
def execution_plan(workers, config, **kwargs):
assert (
len(kwargs) == 0
@ -376,7 +376,7 @@ class Dreamer(Trainer):
)
return rollouts
@override(Trainer)
@override(Algorithm)
def training_step(self) -> ResultDict:
local_worker = self.workers.local_worker()

View file

@ -8,7 +8,7 @@ from ray.rllib.utils.test_utils import framework_iterator
class TestDreamer(unittest.TestCase):
"""Sanity tests for DreamerTrainer."""
"""Sanity tests for Dreamer."""
def setUp(self):
ray.init()
@ -17,7 +17,7 @@ class TestDreamer(unittest.TestCase):
ray.shutdown()
def test_dreamer_compilation(self):
"""Test whether an DreamerTrainer can be built with all frameworks."""
"""Test whether an Dreamer can be built with all frameworks."""
config = dreamer.DreamerConfig()
config.environment(
env=RandomEnv,

View file

@ -9,7 +9,7 @@ import time
from typing import Optional
import ray
from ray.rllib.agents import Trainer, TrainerConfig
from ray.rllib.algorithms import Algorithm, AlgorithmConfig
from ray.rllib.algorithms.es import optimizers, utils
from ray.rllib.algorithms.es.es_tf_policy import ESTFPolicy, rollout
from ray.rllib.env.env_context import EnvContext
@ -24,7 +24,7 @@ from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.torch_utils import set_torch_seed
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
logger = logging.getLogger(__name__)
@ -41,8 +41,8 @@ Result = namedtuple(
)
class ESConfig(TrainerConfig):
"""Defines a configuration class from which an ES Trainer can be built.
class ESConfig(AlgorithmConfig):
"""Defines a configuration class from which an ES Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.es import ESConfig
@ -50,7 +50,7 @@ class ESConfig(TrainerConfig):
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
        >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -75,7 +75,7 @@ class ESConfig(TrainerConfig):
def __init__(self):
"""Initializes a ESConfig instance."""
super().__init__(trainer_class=ES)
super().__init__(algo_class=ES)
# fmt: off
# __sphinx_doc_begin__
@ -91,11 +91,11 @@ class ESConfig(TrainerConfig):
self.noise_size = 250000000
self.report_length = 10
# Override some of TrainerConfig's default values with ES-specific values.
# Override some of AlgorithmConfig's default values with ES-specific values.
self.train_batch_size = 10000
self.num_workers = 10
self.observation_filter = "MeanStdFilter"
# ARS will use Trainer's evaluation WorkerSet (if evaluation_interval > 0).
# ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0).
# Therefore, we must be careful not to use more than 1 env per eval worker
# (would break ARSPolicy's compute_single_action method) and to not do
# obs-filtering.
@ -105,7 +105,7 @@ class ESConfig(TrainerConfig):
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -137,7 +137,7 @@ class ESConfig(TrainerConfig):
report_length: How many of the last rewards we average over.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -319,16 +319,16 @@ def get_policy_class(config):
return policy_cls
class ES(Trainer):
class ES(Algorithm):
"""Large-scale implementation of Evolution Strategies in Ray."""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return ESConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -348,7 +348,7 @@ class ES(Trainer):
"`NoFilter` for ES!"
)
@override(Trainer)
@override(Algorithm)
def setup(self, config):
# Setup our config: Merge the user-supplied config (which could
# be a partial config dict with the class' default).
@ -393,7 +393,7 @@ class ES(Trainer):
self.reward_list = []
self.tstart = time.time()
@override(Trainer)
@override(Algorithm)
def get_policy(self, policy=DEFAULT_POLICY_ID):
if policy != DEFAULT_POLICY_ID:
raise ValueError(
@ -402,7 +402,7 @@ class ES(Trainer):
)
return self.policy
@override(Trainer)
@override(Algorithm)
def step(self):
config = self.config
@ -503,7 +503,7 @@ class ES(Trainer):
return result
@override(Trainer)
@override(Algorithm)
def compute_single_action(self, observation, *args, **kwargs):
action, _, _ = self.policy.compute_actions([observation], update=False)
if kwargs.get("full_fetch"):
@ -514,7 +514,7 @@ class ES(Trainer):
def compute_action(self, observation, *args, **kwargs):
return self.compute_single_action(observation, *args, **kwargs)
@override(Trainer)
@override(Algorithm)
def _sync_weights_to_workers(self, *, worker_set=None, workers=None):
# Broadcast the new policy weights to all evaluation workers.
assert worker_set is not None
@ -522,7 +522,7 @@ class ES(Trainer):
weights = ray.put(self.policy.get_flat_weights())
worker_set.foreach_policy(lambda p, pid: p.set_flat_weights(ray.get(weights)))
@override(Trainer)
@override(Algorithm)
def cleanup(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self.workers:

View file

@ -7,7 +7,7 @@ from ray.rllib.utils.test_utils import check_compute_single_action, framework_it
class TestES(unittest.TestCase):
def test_es_compilation(self):
"""Test whether an ESTrainer can be built on all frameworks."""
"""Test whether an ESAlgorithm can be built on all frameworks."""
ray.init(num_cpus=4)
config = es.ESConfig()
# Keep it simple.

View file

@ -8,7 +8,8 @@ from typing import Optional, Type, List, Dict, Union, Callable, Any
import ray
from ray.actor import ActorHandle
from ray.rllib import SampleBatch
from ray.rllib.agents.trainer import Trainer, TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
@ -38,9 +39,9 @@ from ray.rllib.utils.metrics import (
# from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
SampleBatchType,
T,
)
@ -55,7 +56,7 @@ from ray.types import ObjectRef
logger = logging.getLogger(__name__)
class ImpalaConfig(TrainerConfig):
class ImpalaConfig(AlgorithmConfig):
"""Defines a configuration class from which an Impala can be built.
Example:
@ -64,7 +65,7 @@ class ImpalaConfig(TrainerConfig):
... .resources(num_gpus=4)\
... .rollouts(num_rollout_workers=64)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
        >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -87,9 +88,9 @@ class ImpalaConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a ImpalaConfig instance."""
super().__init__(trainer_class=trainer_class or Impala)
super().__init__(algo_class=algo_class or Impala)
# fmt: off
# __sphinx_doc_begin__
@ -127,7 +128,7 @@ class ImpalaConfig(TrainerConfig):
self._lr_vf = 0.0005
self.after_train_step = None
# Override some of TrainerConfig's default values with ARS-specific values.
        # Override some of AlgorithmConfig's default values with IMPALA-specific values.
self.rollout_fragment_length = 50
self.train_batch_size = 500
self.num_workers = 2
@ -140,7 +141,7 @@ class ImpalaConfig(TrainerConfig):
# Deprecated value.
self.num_data_loader_buffers = DEPRECATED_VALUE
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -266,7 +267,7 @@ class ImpalaConfig(TrainerConfig):
in flight, or enable compression in your experiment of timesteps.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -433,8 +434,8 @@ class BroadcastUpdateLearnerWeights:
self.workers.local_worker().set_global_vars(_get_global_vars())
class Impala(Trainer):
"""Importance weighted actor/learner architecture (IMPALA) Trainer
class Impala(Algorithm):
"""Importance weighted actor/learner architecture (IMPALA) Algorithm
== Overview of data flow in IMPALA ==
1. Policy evaluation in parallel across `num_workers` actors produces
@ -448,13 +449,13 @@ class Impala(Trainer):
"""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return ImpalaConfig().to_dict()
@override(Trainer)
@override(Algorithm)
def get_default_policy_class(
self, config: PartialTrainerConfigDict
self, config: PartialAlgorithmConfigDict
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
if config["vtrace"]:
@ -488,7 +489,7 @@ class Impala(Trainer):
return A3CTFPolicy
@override(Trainer)
@override(Algorithm)
def validate_config(self, config):
# Call the super class' validation method first.
super().validate_config(config)
@ -536,8 +537,8 @@ class Impala(Trainer):
)
config["_tf_policy_handles_more_than_one_loss"] = True
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
@override(Algorithm)
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)
if self.config["_disable_execution_plan_api"]:
@ -607,7 +608,7 @@ class Impala(Trainer):
self._learner_thread.start()
self.workers_that_need_updates = set()
@override(Trainer)
@override(Algorithm)
def training_step(self) -> ResultDict:
unprocessed_sample_batches = self.get_samples_from_workers()
@ -629,7 +630,7 @@ class Impala(Trainer):
return train_results
@staticmethod
@override(Trainer)
@override(Algorithm)
def execution_plan(workers, config, **kwargs):
assert (
len(kwargs) == 0
@ -687,7 +688,7 @@ class Impala(Trainer):
)
@classmethod
@override(Trainer)
@override(Algorithm)
def default_resource_request(cls, config):
cf = dict(cls.get_default_config(), **config)
@ -887,7 +888,7 @@ class Impala(Trainer):
# Update global vars of the local worker.
self.workers.local_worker().set_global_vars(global_vars)
@override(Trainer)
@override(Algorithm)
def on_worker_failures(
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
):
@ -900,7 +901,7 @@ class Impala(Trainer):
self._sampling_actor_manager.remove_workers(removed_workers)
self._sampling_actor_manager.add_workers(new_workers)
@override(Trainer)
@override(Algorithm)
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
result = super()._compile_iteration_results(
step_ctx=step_ctx, iteration_results=iteration_results
@ -915,7 +916,7 @@ class Impala(Trainer):
class AggregatorWorker:
"""A worker for doing tree aggregation of collected episodes"""
def __init__(self, config: TrainerConfigDict):
def __init__(self, config: AlgorithmConfigDict):
self.config = config
self._mixin_buffer = MixInMultiAgentReplayBuffer(
capacity=(

View file

@ -82,7 +82,7 @@ class VTraceLoss:
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
config: Algorithm config dict.
"""
# Compute vtrace on the CPU for better perf.

View file

@ -80,7 +80,7 @@ class VTraceLoss:
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
config: Algorithm config dict.
"""
import ray.rllib.algorithms.impala.vtrace_torch as vtrace

View file

@ -12,21 +12,21 @@ and the README for how to run with the multi-agent particle envs.
import logging
from typing import List, Optional, Type
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.dqn.dqn import DQN
from ray.rllib.algorithms.maddpg.maddpg_tf_policy import MADDPGTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import Deprecated, override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class MADDPGConfig(TrainerConfig):
"""Defines a configuration class from which a MADDPG Trainer can be built.
class MADDPGConfig(AlgorithmConfig):
"""Defines a configuration class from which a MADDPG Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig
@ -44,9 +44,9 @@ class MADDPGConfig(TrainerConfig):
>>> .resources(num_gpus=0)\
>>> .rollouts(num_rollout_workers=4)\
>>> .environment("CartPole-v1")
>>> trainer = config.build()
>>> algo = config.build()
>>> while True:
>>> trainer.train()
>>> algo.train()
Example:
>>> from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig
@ -61,9 +61,9 @@ class MADDPGConfig(TrainerConfig):
>>> )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a DQNConfig instance."""
super().__init__(trainer_class=trainer_class or MADDPG)
super().__init__(algo_class=algo_class or MADDPG)
# fmt: off
# __sphinx_doc_begin__
@ -97,7 +97,7 @@ class MADDPGConfig(TrainerConfig):
self.actor_feature_reg = 0.001
self.grad_norm_clipping = 0.5
# Changes to Trainer's default:
# Changes to Algorithm's default:
self.rollout_fragment_length = 100
self.train_batch_size = 1024
self.num_workers = 1
@ -105,7 +105,7 @@ class MADDPGConfig(TrainerConfig):
# fmt: on
# __sphinx_doc_end__
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -201,7 +201,7 @@ class MADDPGConfig(TrainerConfig):
value.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
@ -277,11 +277,11 @@ def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
class MADDPG(DQN):
@classmethod
@override(DQN)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return MADDPGConfig().to_dict()
@override(DQN)
def validate_config(self, config: TrainerConfigDict) -> None:
def validate_config(self, config: AlgorithmConfigDict) -> None:
"""Adds the `before_learn_on_batch` hook to the config.
This hook is called explicitly prior to TrainOneStep() in the execution
@ -299,7 +299,7 @@ class MADDPG(DQN):
config["before_learn_on_batch"] = f
@override(DQN)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
return MADDPGTFPolicy

View file

@ -46,12 +46,12 @@ class TestMADDPG(unittest.TestCase):
# Only working for tf right now.
for _ in framework_iterator(config, frameworks="tf"):
trainer = config.build()
algo = config.build()
for i in range(num_iterations):
results = trainer.train()
results = algo.train()
check_train_results(results)
print(results)
trainer.stop()
algo.stop()
if __name__ == "__main__":

View file

@ -2,8 +2,8 @@ import logging
import numpy as np
from typing import Optional, Type
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
@ -20,20 +20,20 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.util.iter import from_actors, LocalIterator
logger = logging.getLogger(__name__)
class MAMLConfig(TrainerConfig):
"""Defines a configuration class from which a MAML Trainer can be built.
class MAMLConfig(AlgorithmConfig):
"""Defines a configuration class from which a MAML Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.maml import MAMLConfig
>>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
        >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -56,9 +56,9 @@ class MAMLConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a PGConfig instance."""
super().__init__(trainer_class=trainer_class or MAML)
super().__init__(algo_class=algo_class or MAML)
# fmt: off
# __sphinx_doc_begin__
@ -77,7 +77,7 @@ class MAMLConfig(TrainerConfig):
self.inner_lr = 0.1
self.use_meta_env = True
# Override some of TrainerConfig's default values with MAML-specific values.
# Override some of AlgorithmConfig's default values with MAML-specific values.
self.rollout_fragment_length = 200
self.create_env_on_local_worker = True
self.lr = 1e-3
@ -136,7 +136,7 @@ class MAMLConfig(TrainerConfig):
use_meta_env: Use Meta Env Template.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -252,14 +252,14 @@ def inner_adaptation(workers, samples):
e.learn_on_batch.remote(samples[i])
class MAML(Trainer):
class MAML(Algorithm):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return MAMLConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -281,8 +281,8 @@ class MAML(Trainer):
"(local) worker! Set `create_env_on_driver` to True."
)
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
@ -297,9 +297,9 @@ class MAML(Trainer):
return MAMLEagerTFPolicy
@staticmethod
@override(Trainer)
@override(Algorithm)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0

View file

@ -1,7 +1,7 @@
from typing import Optional, Type
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
@ -21,13 +21,13 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
class MARWILConfig(TrainerConfig):
"""Defines a configuration class from which a MARWIL Trainer can be built.
class MARWILConfig(AlgorithmConfig):
"""Defines a configuration class from which a MARWIL Algorithm can be built.
Example:
@ -36,7 +36,7 @@ class MARWILConfig(TrainerConfig):
>>> config = MARWILConfig().training(beta=1.0, lr=0.00001, gamma=0.99)\
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build()
>>> trainer.train()
@ -61,9 +61,9 @@ class MARWILConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a MARWILConfig instance."""
super().__init__(trainer_class=trainer_class or MARWIL)
super().__init__(algo_class=algo_class or MARWIL)
# fmt: off
# __sphinx_doc_begin__
@ -91,7 +91,7 @@ class MARWILConfig(TrainerConfig):
self.vf_coeff = 1.0
self.grad_clip = None
# Override some of TrainerConfig's default values with MARWIL-specific values.
# Override some of AlgorithmConfig's default values with MARWIL-specific values.
# You should override input_ to point to an offline dataset
# (see algorithm.py and algorithm_config.py).
@ -114,7 +114,7 @@ class MARWILConfig(TrainerConfig):
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -181,7 +181,7 @@ class MARWILConfig(TrainerConfig):
grad_clip: If specified, clip the global norm of gradients by this amount.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -206,14 +206,14 @@ class MARWILConfig(TrainerConfig):
return self
class MARWIL(Trainer):
class MARWIL(Algorithm):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return MARWILConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -231,8 +231,8 @@ class MARWIL(Trainer):
"calculate accum., discounted returns)!"
)
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.marwil.marwil_torch_policy import (
MARWILTorchPolicy,
@ -250,7 +250,7 @@ class MARWIL(Trainer):
return MARWILTF2Policy
@override(Trainer)
@override(Algorithm)
def training_step(self) -> ResultDict:
# Collect SampleBatches from sample workers.
batch = synchronous_parallel_sample(worker_set=self.workers)
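For the offline-learning path that MARWIL's docstring describes, an end-to-end sketch looks like the following (the JSON path is the repo's example file; the beta, lr, and gamma values mirror the docstring and are illustrative, not tuned):

from ray.rllib.algorithms.marwil import MARWILConfig

# Train from a recorded offline dataset instead of sampling from a live env;
# the env id here mainly supplies observation/action spaces.
config = (
    MARWILConfig()
    .training(beta=1.0, lr=0.00001, gamma=0.99)
    .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
)
algo = config.build(env="CartPole-v0")
print(algo.train())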

View file

@ -18,7 +18,7 @@ torch, _ = try_import_torch()
class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
"""PyTorch policy class used with MarwilTrainer."""
"""PyTorch policy class used with Marwil."""
def __init__(self, observation_space, action_space, config):
config = dict(

View file

@ -31,7 +31,7 @@ class TestMARWIL(unittest.TestCase):
ray.shutdown()
def test_marwil_compilation_and_learning_from_offline_file(self):
"""Test whether a MARWILTrainer can be built with all frameworks.
"""Test whether a MARWILAlgorithm can be built with all frameworks.
Learns from a historic-data file.
To generate this data, first run:
@ -83,7 +83,7 @@ class TestMARWIL(unittest.TestCase):
if not learnt:
raise ValueError(
"MARWILTrainer did not reach {} reward from expert "
"MARWILAlgorithm did not reach {} reward from expert "
"offline data!".format(min_reward)
)

View file

@ -5,8 +5,8 @@ from typing import List, Optional, Type
import ray
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.wrappers.model_vector_env import model_vector_env
from ray.rllib.evaluation.metrics import (
@ -29,14 +29,14 @@ from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import EnvType, TrainerConfigDict
from ray.rllib.utils.typing import EnvType, AlgorithmConfigDict
from ray.util.iter import from_actors, LocalIterator
logger = logging.getLogger(__name__)
class MBMPOConfig(TrainerConfig):
"""Defines a configuration class from which an MBMPO Trainer can be built.
class MBMPOConfig(AlgorithmConfig):
"""Defines a configuration class from which an MBMPO Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.mbmpo import MBMPOConfig
@ -44,7 +44,7 @@ class MBMPOConfig(TrainerConfig):
... .resources(num_gpus=4)\
... .rollouts(num_rollout_workers=64)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -67,9 +67,9 @@ class MBMPOConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a MBMPOConfig instance."""
super().__init__(trainer_class=trainer_class or MBMPO)
super().__init__(algo_class=algo_class or MBMPO)
# fmt: off
# __sphinx_doc_begin__
@ -127,7 +127,7 @@ class MBMPOConfig(TrainerConfig):
# How many iterations through MAML per MBMPO iteration.
self.num_maml_steps = 10
# Override some of TrainerConfig's default values with MBMPO-specific
# Override some of AlgorithmConfig's default values with MBMPO-specific
# values.
self.batch_mode = "complete_episodes"
# Size of batches collected from each worker.
@ -149,7 +149,7 @@ class MBMPOConfig(TrainerConfig):
self.vf_share_layers = DEPRECATED_VALUE
self._disable_execution_plan_api = False
@override(TrainerConfig)
@override(AlgorithmConfig)
def training(
self,
*,
@ -198,7 +198,7 @@ class MBMPOConfig(TrainerConfig):
num_maml_steps: How many iterations through MAML per MBMPO iteration.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -360,7 +360,7 @@ def inner_adaptation(workers: WorkerSet, samples: List[SampleBatch]):
"""Performs one gradient descend step on each remote worker.
Args:
workers: The WorkerSet of the Trainer.
workers: The WorkerSet of the Algorithm.
samples (List[SampleBatch]): The list of SampleBatches to perform
a training step on (one for each remote worker).
"""
@ -422,7 +422,7 @@ def sync_stats(workers: WorkerSet) -> None:
e.foreach_policy.remote(set_func, normalizations=normalization_dict)
def post_process_samples(samples, config: TrainerConfigDict):
def post_process_samples(samples, config: AlgorithmConfigDict):
# Instead of using NN for value function, we use regression
split_lst = []
for sample in samples:
@ -446,10 +446,10 @@ def post_process_samples(samples, config: TrainerConfigDict):
return samples, split_lst
class MBMPO(Trainer):
"""Model-Based Meta Policy Optimization (MB-MPO) Trainer.
class MBMPO(Algorithm):
"""Model-Based Meta Policy Optimization (MB-MPO) Algorithm.
This file defines the distributed Trainer class for model-based meta
This file defines the distributed Algorithm class for model-based meta
policy optimization.
See `mbmpo_[tf|torch]_policy.py` for the definition of the policy loss.
@ -458,12 +458,12 @@ class MBMPO(Trainer):
"""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return DEFAULT_CONFIG
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
@ -491,16 +491,16 @@ class MBMPO(Trainer):
"(local) worker! Set `create_env_on_driver` to True."
)
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
return MBMPOTorchPolicy
@staticmethod
@override(Trainer)
@override(Algorithm)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
@ -576,7 +576,7 @@ class MBMPO(Trainer):
return train_op
@staticmethod
@override(Trainer)
@override(Algorithm)
def validate_env(env: EnvType, env_context: EnvContext) -> None:
"""Validates the local_worker's env object (after creation).

158
rllib/algorithms/mock.py Normal file
View file

@ -0,0 +1,158 @@
import os
import pickle
import numpy as np
from ray.tune import result as tune_result
from ray.rllib.algorithms.algorithm import Algorithm, with_common_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AlgorithmConfigDict
class _MockTrainer(Algorithm):
"""Mock trainer for use in tests"""
@classmethod
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return with_common_config(
{
"mock_error": False,
"persistent_error": False,
"test_variable": 1,
"num_workers": 0,
"user_checkpoint_freq": 0,
"framework": "tf",
}
)
@classmethod
def default_resource_request(cls, config):
return None
@override(Algorithm)
def setup(self, config):
# Setup our config: Merge the user-supplied config (which could
# be a partial config dict) with the class' defaults.
self.config = self.merge_trainer_configs(
self.get_default_config(), config, self._allow_unknown_configs
)
self.config["env"] = self._env_id
self.validate_config(self.config)
self.callbacks = self.config["callbacks"]()
# Add needed properties.
self.info = None
self.restored = False
@override(Algorithm)
def step(self):
if (
self.config["mock_error"]
and self.iteration == 1
and (self.config["persistent_error"] or not self.restored)
):
raise Exception("mock error")
result = dict(
episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}
)
if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
if self.iteration % self.config["user_checkpoint_freq"] == 0:
result.update({tune_result.SHOULD_CHECKPOINT: True})
return result
@override(Algorithm)
def save_checkpoint(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
with open(path, "wb") as f:
pickle.dump(self.info, f)
return path
@override(Algorithm)
def load_checkpoint(self, checkpoint_path):
with open(checkpoint_path, "rb") as f:
info = pickle.load(f)
self.info = info
self.restored = True
@staticmethod
@override(Algorithm)
def _get_env_id_and_creator(env_specifier, config):
# No env to register.
return None, None
def set_info(self, info):
self.info = info
return info
def get_info(self, sess=None):
return self.info
class _SigmoidFakeData(_MockTrainer):
"""Trainer that returns sigmoid learning curves.
This can be helpful for evaluating early stopping algorithms."""
@classmethod
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return with_common_config(
{
"width": 100,
"height": 100,
"offset": 0,
"iter_time": 10,
"iter_timesteps": 1,
"num_workers": 0,
}
)
def step(self):
i = max(0, self.iteration - self.config["offset"])
v = np.tanh(float(i) / self.config["width"])
v *= self.config["height"]
return dict(
episode_reward_mean=v,
episode_len_mean=v,
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={},
)
class _ParameterTuningTrainer(_MockTrainer):
@classmethod
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return with_common_config(
{
"reward_amt": 10,
"dummy_param": 10,
"dummy_param2": 15,
"iter_time": 10,
"iter_timesteps": 1,
"num_workers": 0,
}
)
def step(self):
return dict(
episode_reward_mean=self.config["reward_amt"] * self.iteration,
episode_len_mean=self.config["reward_amt"],
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={},
)
def _algorithm_import_failed(trace):
"""Returns dummy Algorithm class for if PyTorch etc. is not installed."""
class _TrainerImportFailed(Algorithm):
_name = "TrainerImportFailed"
def setup(self, config):
raise ImportError(trace)
return _TrainerImportFailed
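The only branching logic of note in _MockTrainer.step() is the user-checkpoint trigger. Below is a standalone, RLlib-independent sketch of that condition (the helper name is made up for illustration), which can help when reasoning about tests built on the mock:

def should_request_checkpoint(iteration: int, user_checkpoint_freq: int) -> bool:
    # Mirrors the condition in _MockTrainer.step(): only flag a checkpoint on
    # non-zero iterations that are exact multiples of the configured frequency.
    return (
        user_checkpoint_freq > 0
        and iteration > 0
        and iteration % user_checkpoint_freq == 0
    )

assert [i for i in range(6) if should_request_checkpoint(i, 2)] == [2, 4]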

View file

@ -1,21 +1,21 @@
from typing import Type
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.typing import AlgorithmConfigDict
class PGConfig(TrainerConfig):
"""Defines a configuration class from which a PG Trainer can be built.
class PGConfig(AlgorithmConfig):
"""Defines a configuration class from which a PG Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.pg import PGConfig
>>> config = PGConfig().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -41,11 +41,11 @@ class PGConfig(TrainerConfig):
def __init__(self):
"""Initializes a PGConfig instance."""
super().__init__(trainer_class=PG)
super().__init__(algo_class=PG)
# fmt: off
# __sphinx_doc_begin__
# Override some of TrainerConfig's default values with PG-specific values.
# Override some of AlgorithmConfig's default values with PG-specific values.
self.num_workers = 0
self.lr = 0.0004
self._disable_preprocessor_api = True
@ -53,7 +53,7 @@ class PGConfig(TrainerConfig):
# fmt: on
class PG(Trainer):
class PG(Algorithm):
"""Policy Gradient (PG) Trainer.
Defines the distributed Trainer class for policy gradients.
@ -69,11 +69,11 @@ class PG(Trainer):
"""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return PGConfig().to_dict()
@override(Trainer)
@override(Algorithm)
def get_default_policy_class(self, config) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy
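PG is configured the same way; a minimal sketch matching its docstring (lr and num_gpus values are illustrative):

from ray.rllib.algorithms.pg import PGConfig

config = PGConfig().training(lr=0.01).resources(num_gpus=1)
algo = config.build(env="CartPole-v1")
print(algo.train())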

View file

@ -2,7 +2,7 @@
Proximal Policy Optimization (PPO)
==================================
This file defines the distributed Trainer class for proximal policy
This file defines the distributed Algorithm class for proximal policy
optimization.
See `ppo_[tf|torch]_policy.py` for the definition of the policy loss.
@ -14,8 +14,8 @@ from typing import List, Optional, Type, Union
import math
from ray.util.debug import log_once
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.execution.rollout_ops import (
standardize_fields,
)
@ -27,9 +27,13 @@ from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
from ray.rllib.utils.typing import AlgorithmConfigDict, ResultDict
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
@ -40,8 +44,8 @@ from ray.rllib.utils.metrics import (
logger = logging.getLogger(__name__)
class PPOConfig(TrainerConfig):
"""Defines a configuration class from which a PPO Trainer can be built.
class PPOConfig(AlgorithmConfig):
"""Defines a configuration class from which a PPO Algorithm can be built.
Example:
>>> from ray.rllib.algorithms.ppo import PPOConfig
@ -49,7 +53,7 @@ class PPOConfig(TrainerConfig):
... .resources(num_gpus=0)\
... .rollouts(num_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -72,9 +76,9 @@ class PPOConfig(TrainerConfig):
... )
"""
def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a PPOConfig instance."""
super().__init__(trainer_class=trainer_class or PPO)
super().__init__(algo_class=algo_class or PPO)
# fmt: off
# __sphinx_doc_begin__
@ -95,7 +99,7 @@ class PPOConfig(TrainerConfig):
self.grad_clip = None
self.kl_target = 0.01
# Override some of TrainerConfig's default values with PPO-specific values.
# Override some of AlgorithmConfig's default values with PPO-specific values.
self.rollout_fragment_length = 200
self.train_batch_size = 4000
self.lr = 5e-5
@ -103,7 +107,10 @@ class PPOConfig(TrainerConfig):
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
# Deprecated keys.
self.vf_share_layers = DEPRECATED_VALUE
@override(AlgorithmConfig)
def training(
self,
*,
@ -122,6 +129,8 @@ class PPOConfig(TrainerConfig):
vf_clip_param: Optional[float] = None,
grad_clip: Optional[float] = None,
kl_target: Optional[float] = None,
# Deprecated.
vf_share_layers=None,
**kwargs,
) -> "PPOConfig":
"""Sets the training related configuration.
@ -155,7 +164,7 @@ class PPOConfig(TrainerConfig):
kl_target: Target value for KL divergence.
Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
@ -191,6 +200,14 @@ class PPOConfig(TrainerConfig):
if kl_target is not None:
self.kl_target = kl_target
if vf_share_layers is not None:
self.model["vf_share_layers"] = vf_share_layers
deprecation_warning(
old="ppo.DEFAULT_CONFIG['vf_share_layers']",
new="PPOConfig().training(model={'vf_share_layers': ...})",
error=False,
)
return self
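The net effect of this hunk is that the deprecated top-level vf_share_layers key now routes into the model config and emits a deprecation warning. A sketch of the new-style call the warning points to (env id illustrative):

from ray.rllib.algorithms.ppo import PPOConfig

# New style: vf_share_layers lives in the model dict, not at the top level.
config = PPOConfig().training(model={"vf_share_layers": True})
algo = config.build(env="CartPole-v1")
print(algo.train())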
@ -268,20 +285,20 @@ def warn_about_bad_reward_scales(config, result):
return result
class PPO(Trainer):
# TODO: Change the return value of this method to return a TrainerConfig object
class PPO(Algorithm):
# TODO: Change the return value of this method to return an AlgorithmConfig object
# instead.
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict:
return PPOConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
"""Validates the Trainer's config dict.
@override(Algorithm)
def validate_config(self, config: AlgorithmConfigDict) -> None:
"""Validates the Algorithm's config dict.
Args:
config: The Trainer's config to check.
config: The Algorithm's config to check.
Raises:
ValueError: In case something is wrong with the config.
@ -364,8 +381,8 @@ class PPO(Trainer):
"simple_optimizer=True if this doesn't work for you."
)
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
@override(Algorithm)
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy

Some files were not shown because too many files have changed in this diff.