[RLlib] Trainer to Algorithm renaming. (#25539)
Parent: 0c527b4502
Commit: 130b7eeaba
240 changed files with 6667 additions and 6124 deletions
|
@ -2,7 +2,7 @@
|
|||
|
||||
.. include:: /_includes/rllib/we_are_hiring.rst
|
||||
|
||||
.. TODO: We need trainers, environments, algorithms, policies, models here. Likely in that order.
|
||||
.. TODO: We need algorithms, environments, policies, models here. Likely in that order.
|
||||
Execution plans are not a "core" concept for users. Sample batches should probably also be left out.
|
||||
|
||||
.. _rllib-core-concepts:
|
||||
|
@ -10,31 +10,59 @@
|
|||
Key Concepts
|
||||
============
|
||||
|
||||
On this page, we'll cover the key concepts to help you understand how RLlib works and how to use it.
|
||||
In RLlib you use `trainers` to train `algorithms`.
|
||||
These algorithms use `policies` to select actions for your agents.
|
||||
Given a policy, `evaluation` of a policy produces `sample batches` of experiences.
|
||||
You can also customize the `execution plans` of your RL experiments.
|
||||
On this page, we'll cover the key concepts to help you understand how RLlib works and
|
||||
how to use it. In RLlib you use ``algorithms`` to learn in problem environments.
|
||||
These algorithms use ``policies`` to select actions for your agents.
|
||||
Given a policy, ``evaluation`` of a policy produces ``sample batches`` of experiences.
|
||||
You can also customize the ``training_iteration``\s of your RL experiments.
|
||||
|
||||
Trainers
|
||||
--------
|
||||
Algorithms
|
||||
----------
|
||||
|
||||
Trainers bring all RLlib components together, making algorithms accessible via RLlib's Python API and its command line interface (CLI).
|
||||
They manage algorithm configuration, setup of the rollout workers and optimizer, and collection of training metrics.
|
||||
Trainers also implement the :ref:`Tune Trainable API <tune-60-seconds>` for easy experiment management.
|
||||
Algorithms bring all RLlib components together, making learning of different tasks
|
||||
accessible via RLlib's Python API and its command line interface (CLI).
|
||||
Each ``Algorithm`` class is managed by its respective ``AlgorithmConfig``; for example, to
|
||||
configure a ``PPO`` instance, you should use the ``PPOConfig`` class.
|
||||
An ``Algorithm`` sets up its rollout workers and optimizers, and collects training metrics.
|
||||
``Algorithms`` also implement the :ref:`Tune Trainable API <tune-60-seconds>` for
|
||||
easy experiment management.
|
||||
|
||||
You have three ways to interact with a trainer. You can use the basic Python API or the command line to train it, or you
|
||||
You have three ways to interact with an algorithm. You can use the basic Python API or the command line to train it, or you
|
||||
can use Ray Tune to tune hyperparameters of your reinforcement learning algorithm.
|
||||
The following example shows three equivalent ways of interacting with the ``PPO`` Trainer,
|
||||
The following example shows three equivalent ways of interacting with ``PPO``,
|
||||
which implements the proximal policy optimization algorithm in RLlib.
|
||||
|
||||
.. tabbed:: Basic RLlib Trainer
|
||||
.. tabbed:: Basic RLlib Algorithm
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = PPO(env="CartPole-v0", config={"train_batch_size": 4000})
|
||||
# Configure.
|
||||
from ray.rllib.algorithms import PPOConfig
|
||||
config = PPOConfig().environment("CartPole-v0").training(train_batch_size=4000)
|
||||
|
||||
# Build.
|
||||
algo = config.build()
|
||||
|
||||
# Train.
|
||||
while True:
|
||||
print(trainer.train())
|
||||
print(algo.train())
|
||||
|
||||
|
||||
.. tabbed:: RLlib Algorithms and Tune
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray import tune
|
||||
|
||||
# Configure.
|
||||
from ray.rllib.algorithms import PPOConfig
|
||||
config = PPOConfig().environment("CartPole-v0").training(train_batch_size=4000)
|
||||
|
||||
# Train via Ray Tune.
|
||||
# Note that Ray Tune does not yet support AlgorithmConfig objects, hence
|
||||
# we need to convert back to old-style config dicts.
|
||||
tune.run("PPO", config=config.to_dict())
|
||||
|
||||
|
||||
.. tabbed:: RLlib Command Line
|
||||
|
||||
|
@ -42,17 +70,9 @@ which implements the proximal policy optimization algorithm in RLlib.
|
|||
|
||||
rllib train --run=PPO --env=CartPole-v0 --config='{"train_batch_size": 4000}'
|
||||
|
||||
.. tabbed:: RLlib Tune Trainer
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray import tune
|
||||
tune.run(PPO, config={"env": "CartPole-v0", "train_batch_size": 4000})
|
||||
|
||||
|
||||
|
||||
RLlib `Trainer classes <rllib-concepts.html#trainers>`__ coordinate the distributed workflow of running rollouts and optimizing policies.
|
||||
Trainer classes leverage parallel iterators to implement the desired computation pattern.
|
||||
RLlib `Algorithm classes <rllib-concepts.html#algorithms>`__ coordinate the distributed workflow of running rollouts and optimizing policies.
|
||||
Algorithm classes leverage parallel iterators to implement the desired computation pattern.
|
||||
The following figure shows *synchronous sampling*, the simplest of `these patterns <rllib-algorithms.html>`__:
|
||||
|
||||
.. figure:: images/a2c-arch.svg
|
||||
|
@ -180,13 +200,13 @@ of a sequence of repeating steps, or *dataflow*, of:
|
|||
2. ``ConcatBatches``: The experiences are concatenated into one batch for training.
|
||||
3. ``TrainOneStep``: Take a gradient step with respect to the policy loss, and update the worker weights.
|
||||
|
||||
In code, this dataflow can be expressed as the following execution plan, which is a static method that can be overridden in your custom Trainer sub-classes to define new algorithms.
|
||||
In code, this dataflow can be expressed as the following execution plan, which is a static method that can be overridden in your custom Algorithm sub-classes to define new algorithms.
|
||||
It takes in a ``WorkerSet`` and config, and returns an iterator over training results:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@staticmethod
|
||||
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
|
||||
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
|
||||
# type: LocalIterator[SampleBatchType]
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
|
@ -204,7 +224,7 @@ As you can see, each step returns an *iterator* over objects (if you're unfamili
|
|||
The reason it is a ``LocalIterator`` is that, though it is based on a parallel computation, the iterator has been turned into one that can be consumed locally in sequence by the program.
|
||||
A couple other points to note:
|
||||
|
||||
- The reason the plan returns an iterator over training results, is that ``trainer.train()`` is pulling results from this iterator to return as the result of the train call.
|
||||
- The reason the plan returns an iterator over training results is that ``algorithm.train()`` is pulling results from this iterator to return as the result of the train call.
|
||||
- The rollout workers have been already created ahead of time in the ``WorkerSet``, so the execution plan function is only defining a sequence of operations over the results of the rollouts.
|
||||
|
||||
These iterators represent the infinite stream of data items that can be produced from the dataflow.
|
||||
|
@ -236,7 +256,8 @@ You'll see output like this on the console:
|
|||
(pid=6555) I saw <class 'ray.rllib.policy.sample_batch.SampleBatch'>
|
||||
(pid=6555) I saw <class 'ray.rllib.policy.sample_batch.SampleBatch'>
|
||||
|
||||
It is important to understand that the iterators of an execution plan are evaluated *lazily*. This means that no computation happens until the `trainer <#trainers>`__ tries to read the next item from the iterator (i.e., get the next training result for a ``Trainer.train()`` call).
|
||||
It is important to understand that the iterators of an execution plan are evaluated *lazily*. This means that no computation happens until the `algorithm <#algorithms>`__ tries to read the next item from the iterator
|
||||
(i.e., get the next training result for an ``Algorithm.train()`` call).
|
||||
|
||||
Execution Plan Concepts
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -282,7 +303,7 @@ Examples
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
|
||||
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
|
||||
# type: LocalIterator[(ModelGradients, int)]
|
||||
grads = AsyncGradients(workers)
|
||||
|
||||
|
@ -300,7 +321,7 @@ Examples
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
|
||||
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
|
||||
# Construct a replay buffer.
|
||||
replay_buffer = LocalReplayBuffer(...)
|
||||
|
||||
|
|
|
@ -24,8 +24,8 @@ prep.transform(env.reset()).shape
|
|||
import numpy as np
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
trainer = PPO(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
|
||||
policy = trainer.get_policy()
|
||||
algo = PPO(env="CartPole-v0", config={"framework": "tf2", "num_workers": 0})
|
||||
policy = algo.get_policy()
|
||||
# <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>
|
||||
|
||||
# Run a forward pass to get model output logits. Note that complex observations
|
||||
|
@ -82,8 +82,8 @@ _____________________________________________________________________
|
|||
import numpy as np
|
||||
from ray.rllib.algorithms.dqn import DQN
|
||||
|
||||
trainer = DQN(env="CartPole-v0", config={"framework": "tf2"})
|
||||
model = trainer.get_policy().model
|
||||
algo = DQN(env="CartPole-v0", config={"framework": "tf2"})
|
||||
model = algo.get_policy().model
|
||||
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>
|
||||
|
||||
# List of all model variables
|
||||
|
|
|
@ -67,7 +67,7 @@ To be able to run our Atari examples, you should also install:
|
|||
|
||||
After these quick pip installs, you can start coding against RLlib.
|
||||
|
||||
Here is an example of running a PPO Trainer on the "`Taxi domain <https://www.gymlibrary.ml/environments/toy_text/taxi/>`_"
|
||||
Here is an example of running PPO on the "`Taxi domain <https://www.gymlibrary.ml/environments/toy_text/taxi/>`_"
|
||||
for a few training iterations, then performing a single evaluation loop
|
||||
(with rendering enabled):
|
||||
|
||||
|
|
56
doc/source/rllib/package_ref/algorithm.rst
Normal file
|
@ -0,0 +1,56 @@
|
|||
.. _algorithm-reference-docs:
|
||||
|
||||
Algorithm API
|
||||
=============
|
||||
|
||||
The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` class is the highest-level API in RLlib.
|
||||
It allows you to train and evaluate policies, save an experiment's progress and restore from
|
||||
a prior saved experiment when continuing an RL run.
|
||||
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` is a sub-class
|
||||
of :py:class:`~ray.tune.Trainable`
|
||||
and thus fully supports distributed hyperparameter tuning for RL.
|
||||
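As a minimal sketch of that workflow (assuming the ``PPOConfig`` builder shown in the Key Concepts docs, and that ``save()``, ``restore()``, and ``evaluate()`` follow the Tune ``Trainable`` API mentioned above):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    algo = PPOConfig().environment("CartPole-v0").build()

    # Train a few iterations and checkpoint the experiment's progress.
    for _ in range(3):
        print(algo.train())
    checkpoint_path = algo.save()

    # Continue a prior run: restore the saved state, then evaluate the policy.
    algo.restore(checkpoint_path)
    print(algo.evaluate())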
|
||||
.. https://docs.google.com/drawings/d/1J0nfBMZ8cBff34e-nSPJZMM1jKOuUL11zFJm6CmWtJU/edit
|
||||
.. figure:: ../images/trainer_class_overview.svg
|
||||
:align: left
|
||||
|
||||
**A typical RLlib Algorithm object:** The components sitting inside an Algorithm are
|
||||
normally N :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`s
|
||||
(each consisting of one local :py:class:`~ray.rllib.evaluation.RolloutWorker`
|
||||
and zero or more \@ray.remote
|
||||
:py:class:`~ray.rllib.evaluation.RolloutWorker`s),
|
||||
a set of :py:class:`~ray.rllib.policy.Policy`(ies)
|
||||
and their NN models per worker, and an (already vectorized)
|
||||
RLlib :py:class:`~ray.rllib.env.base_env.BaseEnv` per worker.
|
||||
|
||||
|
||||
Building Custom Algorithm Classes
|
||||
---------------------------------
|
||||
|
||||
.. warning::
|
||||
As of Ray >= 1.9, it is no longer recommended to use the `build_trainer()` utility
|
||||
function for creating custom Algorithm sub-classes.
|
||||
Instead, follow the simple guidelines here for directly sub-classing from
|
||||
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm`.
|
||||
|
||||
In order to create a custom Algorithm, sub-class the
|
||||
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` class
|
||||
and override one or more of its methods. Those are in particular:
|
||||
|
||||
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.get_default_config`
|
||||
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.validate_config`
|
||||
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.get_default_policy_class`
|
||||
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.setup`
|
||||
* :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.training_iteration`
|
||||
|
||||
`See here for an example on how to override Algorithm <https://github.com/ray-project/ray/blob/master/rllib/algorithms/pg/pg.py>`_.
|
||||
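Following that pattern, a minimal sketch of such a sub-class might look like this (``MyPolicy`` and ``MY_DEFAULT_CONFIG`` are placeholders; a real implementation would supply its own config dict and policy class):

.. code-block:: python

    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.utils.annotations import override


    class MyAlgo(Algorithm):

        @classmethod
        @override(Algorithm)
        def get_default_config(cls):
            # Placeholder: return your algorithm's default config dict here.
            return MY_DEFAULT_CONFIG

        @override(Algorithm)
        def get_default_policy_class(self, config):
            # Placeholder: return your Policy class here.
            return MyPolicy

        @override(Algorithm)
        def training_iteration(self):
            # Defer to the default sample-and-train step; override this to
            # define a custom training workflow and return a results dict.
            return super().training_iteration()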
|
||||
|
||||
Algorithm base class (ray.rllib.algorithms.algorithm.Algorithm)
|
||||
---------------------------------------------------------------
|
||||
|
||||
.. autoclass:: ray.rllib.algorithms.algorithm.Algorithm
|
||||
:special-members: __init__
|
||||
:members:
|
||||
|
||||
.. static-members: get_default_config, execution_plan
|
|
@ -6,7 +6,7 @@ Evaluation and Environment Rollout
|
|||
Data ingest via either environment rollouts or other data-generating methods
|
||||
(e.g. reading from offline files) is done in RLlib by :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s,
|
||||
which sit inside a :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`
|
||||
(together with other parallel ``RolloutWorkers``) in the RLlib :py:class:`~ray.rllib.agents.trainer.Trainer`
|
||||
(together with other parallel ``RolloutWorkers``) in the RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
|
||||
(under the ``self.workers`` property):
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ which sit inside a :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`
|
|||
:width: 600
|
||||
:align: left
|
||||
|
||||
**A typical RLlib WorkerSet setup inside an RLlib Trainer:** Each :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` contains
|
||||
**A typical RLlib WorkerSet setup inside an RLlib Algorithm:** Each :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` contains
|
||||
exactly one local :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` object and n ray remote
|
||||
:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` (ray actors).
|
||||
The workers contain a policy map (with one or more policies), and - in case a simulator
|
||||
|
@ -23,7 +23,7 @@ which sit inside a :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`
|
|||
(containing m sub-environments) and a :py:class:`~ray.rllib.evaluation.sampler.SamplerInput` (either synchronous or asynchronous) which controls
|
||||
the environment data collection loop.
|
||||
In the online (environment is available) as well as the offline case (no environment),
|
||||
:py:class:`~ray.rllib.agents.trainer.Trainer` uses the :py:meth:`~ray.rllib.evaluation.rollout_worker.RolloutWorker.sample` method to
|
||||
:py:class:`~ray.rllib.algorithms.algorithm.Algorithm` uses the :py:meth:`~ray.rllib.evaluation.rollout_worker.RolloutWorker.sample` method to
|
||||
get :py:class:`~ray.rllib.policy.sample_batch.SampleBatch` objects for training.
|
||||
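For illustration, a rough sketch of that data path (assuming a built ``PPO`` instance and that the ``WorkerSet`` exposes its local worker via ``local_worker()``):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    algo = PPOConfig().environment("CartPole-v0").build()

    # The WorkerSet lives under `algo.workers`; pull one rollout from the
    # local RolloutWorker (remote workers are sampled the same way via Ray).
    batch = algo.workers.local_worker().sample()
    print(batch.count)  # number of env steps collected in this SampleBatch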
|
||||
|
||||
|
|
|
@ -9,11 +9,11 @@ The Policies are used to calculate actions for the next environment steps, losse
|
|||
model updates, and other functionalities covered by RLlib's :py:class:`~ray.rllib.policy.policy.Policy` API.
|
||||
A mapping function is used by episode objects to map AgentIDs produced by the environment to one of the PolicyIDs.
|
||||
|
||||
It is possible to add and remove policies to/from the :py:class:`~ray.rllib.agents.trainer.Trainer`'s workers at any given time
|
||||
It is possible to add and remove policies to/from the :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`'s workers at any given time
|
||||
(even within an ongoing episode) as well as to change the policy mapping function.
|
||||
See the Trainer's methods: :py:meth:`~ray.rllib.agents.trainer.Trainer.add_policy`,
|
||||
:py:meth:`~ray.rllib.agents.trainer.Trainer.remove_policy`, and
|
||||
:py:meth:`~ray.rllib.agents.trainer.Trainer.change_policy_mapping_fn` for more details.
|
||||
See the Algorithm's methods: :py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.add_policy`,
|
||||
:py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.remove_policy`, and
|
||||
:py:meth:`~ray.rllib.algorithms.algorithm.Algorithm.change_policy_mapping_fn` for more details.
|
||||
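A hedged sketch of how this could look at runtime (``"extra_policy"`` and the mapping function are made up for illustration; keyword names follow the ``add_policy``/``remove_policy`` methods referenced above):

.. code-block:: python

    # `algo` is an already built Algorithm (e.g. via `PPOConfig().build()`).
    algo.add_policy(
        policy_id="extra_policy",
        policy_cls=type(algo.get_policy()),  # reuse the default policy class
        # Optionally start routing some agents to the new policy right away:
        policy_mapping_fn=lambda agent_id, episode, worker, **kw: "extra_policy",
    )

    # ... train for a while, even within an ongoing episode ...

    algo.remove_policy("extra_policy")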
|
||||
.. autoclass:: ray.rllib.policy.policy_map.PolicyMap
|
||||
:members:
|
||||
|
|
|
@ -4,13 +4,13 @@ RolloutWorker
|
|||
=============
|
||||
|
||||
RolloutWorkers are used as ``@ray.remote`` actors to collect and return samples
|
||||
from environments or offline files in parallel. An RLlib :py:class:`~ray.rllib.agents.trainer.Trainer` usually has
|
||||
from environments or offline files in parallel. An RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` usually has
|
||||
``num_workers`` :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s plus a
|
||||
single "local" :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` (not ``@ray.remote``) in
|
||||
its :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` under ``self.workers``.
|
||||
|
||||
Depending on its evaluation config settings, an additional :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet` with
|
||||
:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s for evaluation may be present in the :py:class:`~ray.rllib.agents.trainer.Trainer`
|
||||
:py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`s for evaluation may be present in the :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
|
||||
under ``self.evaluation_workers``.
|
||||
|
||||
.. autoclass:: ray.rllib.evaluation.rollout_worker.RolloutWorker
|
||||
|
|
|
@ -20,7 +20,7 @@ If you think there is anything missing, please open an issue on `Github`_.
|
|||
:maxdepth: 2
|
||||
|
||||
env.rst
|
||||
trainer.rst
|
||||
algorithm.rst
|
||||
policy.rst
|
||||
models.rst
|
||||
evaluation.rst
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
.. _trainer-reference-docs:
|
||||
|
||||
Trainer API
|
||||
===========
|
||||
|
||||
The :py:class:`~ray.rllib.agents.trainer.Trainer` class is the highest-level API in RLlib.
|
||||
It allows you to train and evaluate policies, save an experiment's progress and restore from
|
||||
a prior saved experiment when continuing an RL run.
|
||||
:py:class:`~ray.rllib.agents.trainer.Trainer` is a sub-class
|
||||
of :py:class:`~ray.tune.Trainable`
|
||||
and thus fully supports distributed hyperparameter tuning for RL.
|
||||
|
||||
.. https://docs.google.com/drawings/d/1J0nfBMZ8cBff34e-nSPJZMM1jKOuUL11zFJm6CmWtJU/edit
|
||||
.. figure:: ../images/trainer_class_overview.svg
|
||||
:align: left
|
||||
|
||||
**A typical RLlib Trainer object:** The components sitting inside a Trainer are
|
||||
normally N :py:class:`~ray.rllib.evaluation.worker_set.WorkerSet`s
|
||||
(each consisting of one local :py:class:`~ray.rllib.evaluation.RolloutWorker`
|
||||
and zero or more \@ray.remote
|
||||
:py:class:`~ray.rllib.evaluation.RolloutWorker`s),
|
||||
a set of :py:class:`~ray.rllib.policy.Policy`(ies)
|
||||
and their NN models per worker, and a (already vectorized)
|
||||
RLlib :py:class:`~ray.rllib.env.base_env.BaseEnv` per worker.
|
||||
|
||||
|
||||
Building Custom Trainer Classes
|
||||
-------------------------------
|
||||
|
||||
.. warning::
|
||||
As of Ray >= 1.9, it is no longer recommended to use the `build_trainer()` utility
|
||||
function for creating custom Trainer sub-classes.
|
||||
Instead, follow the simple guidelines here for directly sub-classing from
|
||||
:py:class:`~ray.rllib.agents.trainer.Trainer`.
|
||||
|
||||
In order to create a custom Trainer, sub-class the
|
||||
:py:class:`~ray.rllib.agents.trainer.Trainer` class
|
||||
and override one or more of its methods. Those are in particular:
|
||||
|
||||
* :py:meth:`~ray.rllib.agents.trainer.Trainer.get_default_config`
|
||||
* :py:meth:`~ray.rllib.agents.trainer.Trainer.validate_config`
|
||||
* :py:meth:`~ray.rllib.agents.trainer.Trainer.get_default_policy_class`
|
||||
* :py:meth:`~ray.rllib.agents.trainer.Trainer.setup`
|
||||
* :py:meth:`~ray.rllib.agents.trainer.Trainer.step_attempt`
|
||||
* :py:meth:`~ray.rllib.agents.trainer.Trainer.execution_plan`
|
||||
|
||||
`See here for an example on how to override Trainer <https://github.com/ray-project/ray/blob/master/rllib/algorithms/pg/pg.py>`_.
|
||||
|
||||
|
||||
Trainer base class (ray.rllib.agents.trainer.Trainer)
|
||||
-----------------------------------------------------
|
||||
|
||||
.. autoclass:: ray.rllib.agents.trainer.Trainer
|
||||
:special-members: __init__
|
||||
:members:
|
||||
|
||||
.. static-members: get_default_config, execution_plan
|
|
@ -207,7 +207,7 @@ Decentralized Distributed Proximal Policy Optimization (DD-PPO)
|
|||
|pytorch|
|
||||
`[paper] <https://arxiv.org/abs/1911.00357>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ddppo/ddppo.py>`__
|
||||
Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centralized in the trainer process. Instead, gradients are computed remotely on each rollout worker and all-reduced at each mini-batch using `torch distributed <https://pytorch.org/docs/stable/distributed.html>`__. This allows each worker's GPU to be used both for sampling and for training.
|
||||
Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centrally in the algorithm process. Instead, gradients are computed remotely on each rollout worker and all-reduced at each mini-batch using `torch distributed <https://pytorch.org/docs/stable/distributed.html>`__. This allows each worker's GPU to be used both for sampling and for training.
|
||||
|
||||
.. tip::
|
||||
|
||||
|
@ -895,7 +895,7 @@ Tuned examples:
|
|||
|
||||
**Activating Curiosity**
|
||||
The curiosity plugin can be easily activated by specifying it as the Exploration class to-be-used
|
||||
in the main Trainer config. Most of its parameters usually do not have to be specified
|
||||
in the main Algorithm config. Most of its parameters usually do not have to be specified
|
||||
as the module uses the values from the paper by default. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -933,7 +933,7 @@ In such environments, agents have to navigate (and change the underlying state o
|
|||
For example, the task could be to find a key in some room, pick it up, find a matching door (matching the color of the key), and eventually unlock this door with the key to reach a goal state,
|
||||
all the while not seeing any rewards.
|
||||
Such problems are impossible to solve with standard RL exploration methods like epsilon-greedy or stochastic sampling.
|
||||
The Curiosity module - when configured as the Exploration class to use via the Trainer's config (see above on how to do this) - automatically adds three simple models to the Policy's ``self.model``:
|
||||
The Curiosity module - when configured as the Exploration class to use via the Algorithm's config (see above on how to do this) - automatically adds three simple models to the Policy's ``self.model``:
|
||||
a) a latent space learning ("feature") model, taking an environment observation and outputting a latent vector, which represents this observation and
|
||||
b) a "forward" model, predicting the next latent vector, given the current observation vector and an action to take next.
|
||||
c) a so-called "inverse" net, only used to train the "feature" net. The inverse net tries to predict the action taken between two latent vectors (obs and next obs).
|
||||
|
@ -963,7 +963,7 @@ Examples:
|
|||
|
||||
**Activating RE3**
|
||||
The RE3 plugin can be easily activated by specifying it as the Exploration class to-be-used
|
||||
in the main Trainer config and inheriting the `RE3UpdateCallbacks` as shown in this `example <https://github.com/ray-project/ray/blob/c9c3f0745a9291a4de0872bdfa69e4ffdfac3657/rllib/utils/exploration/tests/test_random_encoder.py#L35>`__. Most of its parameters usually do not have to be specified as the module uses the values from the paper by default. For example:
|
||||
in the main Algorithm config and inheriting the `RE3UpdateCallbacks` as shown in this `example <https://github.com/ray-project/ray/blob/c9c3f0745a9291a4de0872bdfa69e4ffdfac3657/rllib/utils/exploration/tests/test_random_encoder.py#L35>`__. Most of its parameters usually do not have to be specified as the module uses the values from the paper by default. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
|
@ -146,20 +146,20 @@ In the above snippet, ``actions`` is a Tensor placeholder of shape ``[batch_size
|
|||
name="MyTFPolicy",
|
||||
loss_fn=policy_gradient_loss)
|
||||
|
||||
We can create a `Trainer <#trainers>`__ and try running this policy on a toy env with two parallel rollout workers:
|
||||
We can create an `Algorithm <#algorithms>`__ and try running this policy on a toy env with two parallel rollout workers:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
|
||||
class MyTrainer(Trainer):
|
||||
class MyAlgo(Algorithm):
|
||||
def get_default_policy_class(self, config):
|
||||
return MyTFPolicy
|
||||
|
||||
ray.init()
|
||||
tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})
|
||||
tune.run(MyAlgo, config={"env": "CartPole-v0", "num_workers": 2})
|
||||
|
||||
|
||||
If you run the above snippet `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_tf_policy.py>`__, you'll probably notice that CartPole doesn't learn so well:
|
||||
|
@ -202,7 +202,7 @@ Let's modify our policy loss to include rewards summed over time. To enable this
|
|||
loss_fn=policy_gradient_loss,
|
||||
postprocess_fn=postprocess_advantages)
|
||||
|
||||
The ``postprocess_advantages()`` function above uses calls RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the trainer with this improved policy, you'll find that it quickly achieves the max reward of 200.
|
||||
The ``postprocess_advantages()`` function above calls RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the algorithm with this improved policy, you'll find that it quickly achieves the max reward of 200.
|
||||
|
||||
You might be wondering how RLlib makes the advantages placeholder automatically available as ``train_batch[Postprocessing.ADVANTAGES]``. When building your policy, RLlib will create a "dummy" trajectory batch where all observations, actions, rewards, etc. are zeros. It then calls your ``postprocess_fn``, and generates TF placeholders based on the numpy shapes of the postprocessed batch. RLlib tracks which placeholders that ``loss_fn`` and ``stats_fn`` access, and then feeds the corresponding sample data into those placeholders during loss optimization. You can also access these placeholders via ``policy.get_placeholder(<name>)`` after loss initialization.
|
||||
|
||||
|
@ -210,38 +210,37 @@ You might be wondering how RLlib makes the advantages placeholder automatically
|
|||
|
||||
In the above section you saw how to compose a simple policy gradient algorithm with RLlib.
|
||||
In this example, we'll dive into how PPO is defined within RLlib and how you can modify it.
|
||||
First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__:
|
||||
First, check out the `PPO definition <https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo.py>`__:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class PPO(Trainer):
|
||||
class PPO(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
...
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config):
|
||||
return PPOTFPolicy
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
def execution_plan(workers, config, **kwargs):
|
||||
@override(Algorithm)
|
||||
def training_iteration(self):
|
||||
...
|
||||
|
||||
Besides some boilerplate for defining the PPO configuration and some warnings, the most important method to take note of is the ``execution_plan``.
|
||||
|
||||
The trainer's `execution plan <#execution-plans>`__ defines the distributed training workflow.
|
||||
Depending on the ``simple_optimizer`` trainer config,
|
||||
The algorithm's `execution plan <#execution-plans>`__ defines the distributed training workflow.
|
||||
Depending on the ``simple_optimizer`` config key,
|
||||
PPO can switch between a simple synchronous plan, or a multi-GPU plan that implements minibatch SGD (the default):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
|
||||
def execution_plan(workers: WorkerSet, config: AlgorithmConfigDict):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
|
||||
# Collect large batches of relevant experiences & standardize.
|
||||
|
@ -316,7 +315,7 @@ Now let's look at each PPO policy definition:
|
|||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin])
|
||||
|
||||
``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric which is used by the trainer to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics:
|
||||
``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric which is used by the algorithm to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -464,7 +463,8 @@ Here's an example of using eager ops embedded
|
|||
Building Policies in PyTorch
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_policy.py>`__:
|
||||
Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining an algorithm given a Torch policy is exactly the same).
|
||||
Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_torch_policy.py>`__:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
|
@ -55,13 +55,13 @@ Contributing Algorithms
|
|||
These are the guidelines for merging new algorithms into RLlib:
|
||||
|
||||
* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__):
|
||||
- must subclass Trainer and implement the ``step()`` method
|
||||
- must subclass Algorithm and implement the ``step()`` method
|
||||
- must include a lightweight test (`example <https://github.com/ray-project/ray/blob/6bb110393008c9800177490688c6ed38b2da52a9/test/jenkins_tests/run_multi_node_tests.sh#L45>`__) to ensure the algorithm runs
|
||||
- should include tuned hyperparameter examples and documentation
|
||||
- should offer functionality not present in existing algorithms
|
||||
|
||||
* Fully integrated algorithms (`rllib/agents <https://github.com/ray-project/ray/tree/master/rllib/agents>`__) have the following additional requirements:
|
||||
- must fully implement the Trainer API
|
||||
- must fully implement the Algorithm API
|
||||
- must offer substantial new functionality not possible to add to other algorithms
|
||||
- should support custom models and preprocessors
|
||||
- should use RLlib abstractions and support distributed execution
|
||||
|
@ -70,14 +70,14 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a
|
|||
|
||||
How to add an algorithm to ``contrib``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Trainer <https://github.com/ray-project/ray/commits/master/rllib/agents/trainer.py>`__ and implement the ``_init`` and ``step`` methods:
|
||||
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Algorithm <https://github.com/ray-project/ray/commits/master/rllib/algorithms/algorithm.py>`__ and implement the ``_init`` and ``step`` methods:
|
||||
|
||||
.. literalinclude:: ../../../rllib/contrib/random_agent/random_agent.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
Second, register the trainer with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/rllib/contrib/registry.py>`__.
|
||||
Second, register the algorithm with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/rllib/contrib/registry.py>`__.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -110,7 +110,7 @@ Finding Memory Leaks In Workers
|
|||
|
||||
Keeping the memory usage of long running workers stable can be challenging. The ``MemoryTrackingCallbacks`` class can be used to track memory usage of workers.
|
||||
|
||||
.. autoclass:: ray.rllib.agents.callbacks.MemoryTrackingCallbacks
|
||||
.. autoclass:: ray.rllib.algorithms.callbacks.MemoryTrackingCallbacks
|
||||
|
||||
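For example, a sketch of enabling it through the ``callbacks`` config key (the import path mirrors the reference above):

.. code-block:: python

    from ray import tune
    from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks

    config = {
        "env": "CartPole-v0",
        # Track per-worker memory usage and report it as custom metrics.
        "callbacks": MemoryTrackingCallbacks,
    }
    tune.run("PPO", config=config)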
The 20 objects with the highest memory usage in the workers will be added as custom metrics. These can then be monitored using tensorboard or other metrics integrations like Weights and Biases:
|
||||
|
||||
|
|
|
@ -16,12 +16,13 @@ RLlib works with several different types of environments, including `OpenAI Gym
|
|||
Configuring Environments
|
||||
------------------------
|
||||
|
||||
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://www.gymlibrary.ml/>`__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:
|
||||
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://www.gymlibrary.ml/>`__.
|
||||
Custom env classes passed directly to the algorithm must take a single ``env_config`` parameter in their constructor:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import gym, ray
|
||||
from ray.rllib.agents import ppo
|
||||
from ray.rllib.algorithms import ppo
|
||||
|
||||
class MyEnv(gym.Env):
|
||||
def __init__(self, env_config):
|
||||
|
@ -33,12 +34,12 @@ You can pass either a string name or a Python class to specify an environment. B
|
|||
return <obs>, <reward: float>, <done: bool>, <info: dict>
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPO(env=MyEnv, config={
|
||||
algo = ppo.PPO(env=MyEnv, config={
|
||||
"env_config": {}, # config to pass to env class
|
||||
})
|
||||
|
||||
while True:
|
||||
print(trainer.train())
|
||||
print(algo.train())
|
||||
|
||||
You can also register a custom env creator function with a string name. This function must take a single ``env_config`` (dict) parameter and return an env instance:
|
||||
|
||||
|
@ -50,7 +51,7 @@ You can also register a custom env creator function with a string name. This fun
|
|||
return MyEnv(...) # return an env instance
|
||||
|
||||
register_env("my_env", env_creator)
|
||||
trainer = ppo.PPO(env="my_env")
|
||||
algo = ppo.PPO(env="my_env")
|
||||
|
||||
For a full runnable code example using the custom environment API, see `custom_env.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__.
|
||||
|
||||
|
@ -58,7 +59,10 @@ For a full runnable code example using the custom environment API, see `custom_e
|
|||
|
||||
The gym registry is not compatible with Ray. Instead, always use the registration flows documented above to ensure Ray workers can access the environment.
|
||||
|
||||
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
|
||||
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object.
|
||||
This is a dict containing options passed in through your algorithm.
|
||||
You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``).
|
||||
This can be useful if you want to train over an ensemble of different environments, for example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -225,12 +229,12 @@ If all the agents will be using the same algorithm class to train, then you can
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = pg.PGAgent(env="my_multiagent_env", config={
|
||||
algo = pg.PGAgent(env="my_multiagent_env", config={
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
# Use the PolicySpec namedtuple to specify an individual policy:
|
||||
"car1": PolicySpec(
|
||||
policy_class=None, # infer automatically from Trainer
|
||||
policy_class=None, # infer automatically from Algorithm
|
||||
observation_space=None, # infer automatically from env
|
||||
action_space=None, # infer automatically from env
|
||||
config={"gamma": 0.85}, # use main config plus <- this override here
|
||||
|
@ -238,7 +242,7 @@ If all the agents will be using the same algorithm class to train, then you can
|
|||
|
||||
# Deprecated way: Tuple specifying class, obs-/action-spaces,
|
||||
# config-overrides for each policy as a tuple.
|
||||
# If class is None -> Uses Trainer's default policy class.
|
||||
# If class is None -> Uses Algorithm's default policy class.
|
||||
"car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}),
|
||||
|
||||
# New way: Use PolicySpec() with keywords: `policy_class`,
|
||||
|
@ -257,7 +261,7 @@ If all the agents will be using the same algorithm class to train, then you can
|
|||
})
|
||||
|
||||
while True:
|
||||
print(trainer.train())
|
||||
print(algo.train())
|
||||
|
||||
To exclude some policies in your ``multiagent.policies`` dictionary, you can use the ``multiagent.policies_to_train`` setting.
|
||||
For example, you may want to have one or more random (non learning) policies interact with your learning ones:
|
||||
|
@ -275,7 +279,7 @@ For example, you may want to have one or more random (non learning) policies int
|
|||
# (start player) and sometimes player2 (player to move 2nd).
|
||||
return "learning_policy" if episode.episode_id % 2 == agent_idx else "random_policy"
|
||||
|
||||
trainer = pg.PGAgent(env="two_player_game", config={
|
||||
algo = pg.PGAgent(env="two_player_game", config={
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"learning_policy": PolicySpec(), # <- use default class & infer obs-/act-spaces from env.
|
||||
|
@ -505,11 +509,11 @@ External Application Clients
|
|||
|
||||
For applications that are running entirely outside the Ray cluster (i.e., cannot be packaged into a Python environment of any form), RLlib provides the ``PolicyServerInput`` application connector, which can be connected to over the network using ``PolicyClient`` instances.
|
||||
|
||||
You can configure any Trainer to launch a policy server with the following config:
|
||||
You can configure any Algorithm to launch a policy server with the following config:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
trainer_config = {
|
||||
config = {
|
||||
# An environment class is still required, but it doesn't need to be runnable.
|
||||
# You only need to define its action and observation space attributes.
|
||||
# See examples/serving/unity3d_server.py for an example using a RandomMultiAgentEnv stub.
|
||||
|
@ -518,13 +522,16 @@ You can configure any Trainer to launch a policy server with the following confi
|
|||
"input": (
|
||||
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
|
||||
),
|
||||
# Use the existing trainer process to run the server.
|
||||
# Use the existing algorithm process to run the server.
|
||||
"num_workers": 0,
|
||||
# Disable OPE, since the rollouts are coming from online clients.
|
||||
"off_policy_estimation_methods": {},
|
||||
}
|
||||
|
||||
Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server.
|
||||
Clients can then connect in either *local* or *remote* inference mode.
|
||||
In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time.
|
||||
This allows actions to be computed by the client without requiring a network round trip each time.
|
||||
In remote inference mode, each computed action requires a network call to the server.
|
||||
|
||||
Example:
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ Environments and Adapters
|
|||
- `Registering a custom env and model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__:
|
||||
Example of defining and registering a gym env and model for use with RLlib.
|
||||
- `Local Unity3D multi-agent environment example <https://github.com/ray-project/ray/tree/master/rllib/examples/unity3d_env_local.py>`__:
|
||||
Example of how to setup an RLlib Trainer against a locally running Unity3D editor instance to
|
||||
Example of how to set up an RLlib Algorithm against a locally running Unity3D editor instance to
|
||||
learn any Unity3D game (including support for multi-agent).
|
||||
Use this example to try things out and watch the game and the learning progress live in the editor.
|
||||
Providing a compiled game, this example could also run in distributed fashion with `num_workers > 0`.
|
||||
|
@ -50,7 +50,7 @@ Environments and Adapters
|
|||
- `DMLab Watermaze example <https://github.com/ray-project/ray/blob/master/rllib/examples/dmlab_watermaze.py>`__:
|
||||
Example for how to use a DMLab environment (Watermaze).
|
||||
- `RecSym environment example (for recommender systems) using the SlateQ algorithm <https://github.com/ray-project/ray/blob/master/rllib/examples/recommender_system_with_recsim_and_slateq.py>`__:
|
||||
Script showing how to train a SlateQTrainer on a RecSym environment.
|
||||
Script showing how to train SlateQ on a RecSym environment.
|
||||
- `SUMO (Simulation of Urban MObility) environment example <https://github.com/ray-project/ray/blob/master/rllib/examples/sumo_env_local.py>`__:
|
||||
Example demonstrating how to use the SUMO simulator in connection with RLlib.
|
||||
- `VizDoom example script using RLlib's auto-attention wrapper <https://github.com/ray-project/ray/blob/master/rllib/examples/vizdoom_with_attention_net.py>`__:
|
||||
|
@ -107,7 +107,7 @@ Training Workflows
|
|||
- `Using rollout workers directly for control over the whole training workflow <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__:
|
||||
Example of how to use RLlib's lower-level building blocks to implement a fully customized training workflow.
|
||||
- `Custom execution plan function handling two different Policies (DQN and PPO) at the same time <https://github.com/ray-project/ray/blob/master/rllib/examples/two_trainer_workflow.py>`__:
|
||||
Example of how to use the exec. plan of a Trainer to trin two different policies in parallel (also using multi-agent API).
|
||||
Example of how to use the exec. plan of an Algorithm to train two different policies in parallel (also using multi-agent API).
|
||||
- `Custom tune experiment <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_experiment.py>`__:
|
||||
How to run a custom Ray Tune experiment with RLlib with custom training- and evaluation phases.
|
||||
|
||||
|
@ -164,8 +164,8 @@ Multi-Agent and Hierarchical
|
|||
Example of running a custom hand-coded policy alongside trainable policies.
|
||||
- `Weight sharing between policies <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_cartpole.py>`__:
|
||||
Example of how to define weight-sharing layers between two different policies.
|
||||
- `Multiple trainers <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_two_trainers.py>`__:
|
||||
Example of alternating training between two DQN and PPO trainers.
|
||||
- `Multiple algorithms <https://github.com/ray-project/ray/blob/master/rllib/examples/multi_agent_two_trainers.py>`__:
|
||||
Example of alternating training between DQN and PPO.
|
||||
- `Hierarchical training <https://github.com/ray-project/ray/blob/master/rllib/examples/hierarchical_training.py>`__:
|
||||
Example of hierarchical training using the multi-agent API.
|
||||
- `Iterated Prisoner's Dilemma environment example <https://github.com/ray-project/ray/blob/master/rllib/examples/iterated_prisoners_dilemma_env.py>`__:
|
||||
|
|
|
@ -43,7 +43,7 @@ observation space. Thereby, the following simple rules apply:
|
|||
observations: ``dict_or_tuple_obs = restore_original_dimensions(input_dict["obs"], self.obs_space, "tf|torch")``
|
||||
|
||||
For Atari observation spaces, RLlib defaults to using the `DeepMind preprocessors <https://github.com/ray-project/ray/blob/master/rllib/env/wrappers/atari_wrappers.py>`__
|
||||
(``preprocessor_pref=deepmind``). However, if the Trainer's config key ``preprocessor_pref`` is set to "rllib",
|
||||
(``preprocessor_pref=deepmind``). However, if the Algorithm's config key ``preprocessor_pref`` is set to "rllib",
|
||||
the following mappings apply for Atari-type observation spaces:
|
||||
|
||||
- Images of shape ``(210, 160, 3)`` are downscaled to ``dim x dim``, where
|
||||
|
@ -75,7 +75,7 @@ and some special options for Atari environments:
|
|||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
The dict above (or an overriding sub-set) is handed to the Trainer via the ``model`` key within
|
||||
The dict above (or an overriding sub-set) is handed to the Algorithm via the ``model`` key within
|
||||
the main config dict like so:
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -90,7 +90,7 @@ the main config dict like so:
|
|||
"fcnet_activation": "relu",
|
||||
},
|
||||
|
||||
# ... other Trainer config keys, e.g. "lr" ...
|
||||
# ... other Algorithm config keys, e.g. "lr" ...
|
||||
"lr": 0.00001,
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ based on simple heuristics:
|
|||
- A fully connected network (`TF <https://github.com/ray-project/ray/blob/master/rllib/models/tf/fcnet.py>`__ or `Torch <https://github.com/ray-project/ray/blob/master/rllib/models/torch/fcnet.py>`__)
|
||||
for everything else.
|
||||
|
||||
These default model types can further be configured via the ``model`` config key inside your Trainer config (as discussed above).
|
||||
These default model types can further be configured via the ``model`` config key inside your Algorithm config (as discussed above).
|
||||
Available settings are `listed above <#default-model-config-settings>`__ and also documented in the `model catalog file <https://github.com/ray-project/ray/blob/master/rllib/models/catalog.py>`__.
|
||||
|
||||
Note that for the vision network case, you'll probably have to configure ``conv_filters``, if your environment observations
|
||||
|
@ -227,7 +227,7 @@ Once implemented, your TF model can then be registered and used in place of a bu
|
|||
ModelCatalog.register_custom_model("my_tf_model", MyModelClass)
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPO(env="CartPole-v0", config={
|
||||
algo = ppo.PPO(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_model": "my_tf_model",
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
|
@ -270,7 +270,7 @@ Once implemented, your PyTorch model can then be registered and used in place of
|
|||
import torch.nn as nn
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents import ppo
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
|
||||
|
@ -282,7 +282,7 @@ Once implemented, your PyTorch model can then be registered and used in place of
|
|||
ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPO(env="CartPole-v0", config={
|
||||
algo = ppo.PPO(env="CartPole-v0", config={
|
||||
"framework": "torch",
|
||||
"model": {
|
||||
"custom_model": "my_torch_model",
|
||||
|
@ -368,7 +368,7 @@ Custom Model APIs (on Top of Default- or Custom Models)
|
|||
```````````````````````````````````````````````````````
|
||||
|
||||
So far we talked about a) the default models that are built into RLlib and are being provided
|
||||
automatically if you don't specify anything in your Trainer's config and b) custom Models through
|
||||
automatically if you don't specify anything in your Algorithm's config and b) custom Models through
|
||||
which you can define any arbitrary forward passes.
|
||||
|
||||
Another typical situation in which you would have to customize a model would be to
|
||||
|
@ -508,7 +508,7 @@ Similar to custom models and preprocessors, you can also specify a custom action
|
|||
ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPO(env="CartPole-v0", config={
|
||||
algo = ppo.PPO(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_action_dist": "my_dist",
|
||||
},
|
||||
|
|
|
@ -83,13 +83,13 @@ This example plot shows the Q-value metric in addition to importance sampling (I
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = DQN(...)
|
||||
algo = DQN(...)
|
||||
... # train policy offline
|
||||
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
|
||||
estimator = WeightedImportanceSamplingEstimator(trainer.get_policy(), gamma=0.99)
|
||||
estimator = WeightedImportanceSamplingEstimator(algo.get_policy(), gamma=0.99)
|
||||
reader = JsonReader("/path/to/data")
|
||||
for _ in range(1000):
|
||||
batch = reader.next()
|
||||
|
@ -246,11 +246,11 @@ Input API
|
|||
You can configure experience input for an agent using the following options:
|
||||
|
||||
.. tip::
|
||||
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.agents.trainer_config.TrainerConfig`
|
||||
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
|
||||
objects, which have the advantage of being type safe, allowing users to set different config settings within
|
||||
meaningful sub-categories (e.g. ``my_config.offline_data(input_=[xyz])``), and offer the ability to
|
||||
construct a Trainer instance from these config objects (via their ``.build()`` method).
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
construct an Algorithm instance from these config objects (via their ``.build()`` method).
|
||||
So far, this is only supported for some Algorithm classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
but we are rolling this out right now across all RLlib.
|
||||
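As a small sketch of that config-object style (the JSON path below is just a placeholder):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v0")
        # Read experiences from offline JSON files instead of sampling an env
        # (placeholder path).
        .offline_data(input_=["/tmp/my-offline-data"])
    )
    algo = config.build()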
|
||||
|
||||
|
@ -343,11 +343,11 @@ Output API
|
|||
You can configure experience output for an agent using the following options:
|
||||
|
||||
.. tip::
|
||||
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.agents.trainer_config.TrainerConfig`
|
||||
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
|
||||
objects, which have the advantage of being type safe, allowing users to set different config settings within
|
||||
meaningful sub-categories (e.g. ``my_config.offline_data(input_=[xyz])``), and offer the ability to
|
||||
construct a Trainer instance from these config objects (via their ``.build()`` method).
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
construct an Algorithm instance from these config objects (via their ``.build()`` method).
|
||||
So far, this is only supported for some Algorithm classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
but we are rolling this out right now across all RLlib.
|
||||
|
||||
.. code-block:: python
|
||||
|
|
|
@ -29,7 +29,7 @@ This is done using a dict that maps strings (column names) to `ViewRequirement`
|
|||
|
||||
|
||||
The exact behavior for a single such rollout and the number of environment transitions therein
|
||||
are determined by the following Trainer config keys:
|
||||
are determined by the following Algorithm config keys:
|
||||
|
||||
**batch_mode [truncate_episodes|complete_episodes]**:
|
||||
*truncate_episodes (default value)*:
|
||||
|
@ -65,7 +65,7 @@ of each episode (arrow heads). This way, RLlib makes sure that the
|
|||
|
||||
|
||||
**multiagent.count_steps_by [env_steps|agent_steps]**:
|
||||
Within the Trainer's ``multiagent`` config dict, you can set the unit, by which RLlib will count a) rollout fragment lengths as well as b) the size of the final train_batch (see below). The two supported values are:
|
||||
Within the Algorithm's ``multiagent`` config dict, you can set the unit, by which RLlib will count a) rollout fragment lengths as well as b) the size of the final train_batch (see below). The two supported values are:
|
||||
|
||||
*env_steps (default)*:
|
||||
Each call to ``[Env].step()`` is counted as one. It does not
|
||||
|
@ -109,7 +109,7 @@ RLlib's default ``SampleCollector`` class is the ``SimpleListCollector``, which
|
|||
to lists, then builds SampleBatches from these and sends them to the downstream processing functions.
|
||||
It thereby tries to avoid collecting duplicate data separately (OBS and NEXT_OBS use the same underlying list).
|
||||
If you want to implement your own collection logic and data structures, you can sub-class ``SampleCollector``
|
||||
and specify that new class under the Trainer's "sample_collector" config key.
|
||||
and specify that new class under the Algorithm's "sample_collector" config key.
|
||||
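A rough sketch of what that could look like (the base-class import path and the no-op sub-class are assumptions for illustration only):

.. code-block:: python

    from ray.rllib.evaluation.collectors.simple_list_collector import (
        SimpleListCollector,
    )


    class MySampleCollector(SimpleListCollector):
        # Override collection methods here to change how per-agent data
        # is buffered and turned into SampleBatches.
        pass


    config = {
        "env": "CartPole-v0",
        "sample_collector": MySampleCollector,
    }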
|
||||
Let's now look at how the Policy's Model lets the RolloutWorker and its SampleCollector
|
||||
know what data in the ongoing episode/trajectory to use for the different required method calls
|
||||
|
|
|
@ -8,13 +8,13 @@ Training APIs
|
|||
Getting Started
|
||||
---------------
|
||||
|
||||
At a high level, RLlib provides an ``Trainer`` class which
|
||||
holds a policy for environment interaction. Through the trainer interface, the policy can
|
||||
be trained, checkpointed, or an action computed. In multi-agent training, the trainer manages the querying and optimization of multiple policies at once.
|
||||
At a high level, RLlib provides an ``Algorithm`` class which
|
||||
holds a policy for environment interaction. Through the algorithm's interface, the policy can
|
||||
be trained, checkpointed, or an action computed. In multi-agent training, the algorithm manages the querying and optimization of multiple policies at once.
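As a brief sketch of this interface (using ``PPO`` and the old ``gym`` reset API of this Ray version; a longer, complete example follows further below):

.. code-block:: python

    import gym
    from ray.rllib.algorithms.ppo import PPO

    algo = PPO(env="CartPole-v0")
    algo.train()                              # Train the policy for one iteration.
    checkpoint = algo.save()                  # Checkpoint it.
    obs = gym.make("CartPole-v0").reset()
    action = algo.compute_single_action(obs)  # Compute a single action.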
|
||||
|
||||
.. image:: images/rllib-api.svg
|
||||
|
||||
You can train a simple DQN trainer with the following commands:
|
||||
You can train DQN with the following commands:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
@ -75,8 +75,8 @@ Specifying Parameters
|
|||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of
|
||||
`common hyperparameters <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__
|
||||
(soon to be replaced by `TrainerConfig objects <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer_config.py>`__).
|
||||
`common hyperparameters <https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm.py>`__
|
||||
(soon to be replaced by `AlgorithmConfig objects <https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm_config.py>`__).
|
||||
|
||||
See the `algorithms documentation <rllib-algorithms.html>`__ for more information.
|
||||
|
||||
|
@ -90,10 +90,10 @@ Specifying Resources
|
|||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
You can control the degree of parallelism used by setting the ``num_workers``
|
||||
hyperparameter for most algorithms. The Trainer will construct that many
|
||||
hyperparameter for most algorithms. The Algorithm will construct that many
|
||||
"remote worker" instances (`see RolloutWorker class <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__)
|
||||
that are constructed as ray.remote actors, plus exactly one "local worker", a ``RolloutWorker`` object that is not a
|
||||
ray actor, but lives directly inside the Trainer.
|
||||
ray actor, but lives directly inside the Algorithm.
|
||||
For most algorithms, learning updates are performed on the local worker and sample collection from
|
||||
one or more environments is performed by the remote workers (in parallel).
|
||||
For example, setting ``num_workers=0`` will only create the local worker, in which case both
|
||||
|
@ -106,7 +106,7 @@ to that worker via the ``num_gpus`` setting.
|
|||
Similarly, the resource allocation to remote workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``.
|
||||
|
||||
The number of GPUs can be fractional quantities (e.g. 0.5) to allocate only a fraction
|
||||
of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting
|
||||
of a GPU. For example, with DQN you can pack five algorithms onto one GPU by setting
|
||||
``num_gpus: 0.2``. Check out `this fractional GPU example here <https://github.com/ray-project/ray/blob/master/rllib/examples/fractional_gpus.py>`__
|
||||
as well that also demonstrates how environments (running on the remote workers) that
|
||||
require a GPU can benefit from the ``num_gpus_per_worker`` setting.
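For instance, a config stub along these lines (all values purely illustrative) combines the resource settings discussed above:

.. code-block:: python

    config = {
        "num_workers": 4,            # Four remote RolloutWorkers plus one local worker.
        "num_gpus": 0.2,             # Fractional GPU for the local (learner) worker.
        "num_cpus_per_worker": 1,    # CPUs allocated to each remote worker.
        "num_gpus_per_worker": 0.2,  # E.g. if the env itself needs GPU access.
    }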
|
||||
|
@ -150,8 +150,8 @@ Here are some rules of thumb for scaling training with RLlib.
|
|||
If you are using many workers (``num_workers >> 10``) and observe worker failures that would normally interrupt your RLlib training runs, consider using
|
||||
the config settings ``ignore_worker_failures=True`` or ``recreate_failed_workers=True``:
|
||||
|
||||
``ignore_worker_failures=True`` allows your Trainer to not crash due to a single worker error, but to continue for as long as there is at least one functional worker remaining.
|
||||
``recreate_failed_workers=True`` will have your Trainer attempt to replace/recreate any failed worker(s) with a new one.
|
||||
``ignore_worker_failures=True`` allows your Algorithm not to crash on a single worker error, but to continue running for as long as there is at least one functional worker remaining.
|
||||
``recreate_failed_workers=True`` will have your Algorithm attempt to replace/recreate any failed worker(s) with a new one.
|
||||
|
||||
Both these settings will make your training runs much more stable and more robust against occasional OOM or other similar "once in a while" errors.
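A sketch of such a fault-tolerant setup (the worker count is illustrative):

.. code-block:: python

    config = {
        "num_workers": 50,
        # Keep training as long as at least one worker is still functional.
        "ignore_worker_failures": True,
        # Try to replace/recreate any crashed worker(s) with new ones.
        "recreate_failed_workers": True,
    }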
|
||||
|
||||
|
@ -160,11 +160,11 @@ Common Parameters
|
|||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. tip::
|
||||
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.agents.trainer_config.TrainerConfig`
|
||||
Plain python config dicts will soon be replaced by :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`
|
||||
objects, which have the advantage of being type safe, allowing users to set different config settings within
|
||||
meaningful sub-categories (e.g. ``my_config.training(lr=0.0003)``), and offer the ability to
|
||||
construct a Trainer instance from these config objects (via their ``build()`` method).
|
||||
So far, this is only supported for some Trainer classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
construct an Algorithm instance from these config objects (via their ``build()`` method).
|
||||
So far, this is only supported for some Algorithm classes, such as :py:class:`~ray.rllib.algorithms.ppo.ppo.PPO`,
|
||||
but we are rolling this out right now across all RLlib.
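As a sketch of where this is heading (using ``PPOConfig``, which already supports the new API; the hyperparameter values are purely illustrative):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPOConfig

    # Build a type-safe config object, then an Algorithm instance from it.
    config = PPOConfig().training(lr=0.0003).resources(num_gpus=0)
    algo = config.build(env="CartPole-v0")
    print(algo.train())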
|
||||
|
||||
The following is a list of the common algorithm hyper-parameters:
|
||||
|
@ -173,7 +173,7 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
|
||||
# === Settings for Rollout Worker processes ===
|
||||
# Number of rollout worker actors to create for parallel sampling. Setting
|
||||
# this to 0 will force rollouts to be done in the trainer actor.
|
||||
# this to 0 will force rollouts to be done in the algorithm's actor.
|
||||
"num_workers": 2,
|
||||
# Number of environments to evaluate vector-wise per worker. This enables
|
||||
# model inference batching, which can improve performance for inference
|
||||
|
@ -214,7 +214,7 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
# terminates or a configured horizon (hard or soft) is hit.
|
||||
"batch_mode": "truncate_episodes",
|
||||
|
||||
# === Settings for the Trainer process ===
|
||||
# === Settings for the Algorithm process ===
|
||||
# Discount factor of the MDP.
|
||||
"gamma": 0.99,
|
||||
# The default learning rate.
|
||||
|
@ -257,7 +257,7 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
# a PyBullet env, a ViZDoomGym env, or a fully qualified classpath to an
|
||||
# Env class, e.g. "ray.rllib.examples.env.random_env.RandomEnv".
|
||||
"env": None,
|
||||
# The observation- and action spaces for the Policies of this Trainer.
|
||||
# The observation- and action spaces for the Policies of this Algorithm.
|
||||
# Use None for automatically inferring these from the given env.
|
||||
"observation_space": None,
|
||||
"action_space": None,
|
||||
|
@ -399,10 +399,10 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
# The unit, with which to count the evaluation duration. Either "episodes"
|
||||
# (default) or "timesteps".
|
||||
"evaluation_duration_unit": "episodes",
|
||||
# Whether to run evaluation in parallel to a Trainer.train() call
|
||||
# Whether to run evaluation in parallel to an Algorithm.train() call
|
||||
# using threading. Default=False.
|
||||
# E.g. evaluation_interval=2 -> For every other training iteration,
|
||||
# the Trainer.train() and Trainer.evaluate() calls run in parallel.
|
||||
# the Algorithm.train() and Algorithm.evaluate() calls run in parallel.
|
||||
# Note: This is experimental. Possible pitfalls could be race conditions
|
||||
# for weight synching at the beginning of the evaluation loop.
|
||||
"evaluation_parallel_to_training": False,
|
||||
|
@ -420,16 +420,16 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
},
|
||||
|
||||
# Number of parallel workers to use for evaluation. Note that this is set
|
||||
# to zero by default, which means evaluation will be run in the trainer
|
||||
# to zero by default, which means evaluation will be run in the algorithm
|
||||
# process (only if evaluation_interval is not None). If you increase this,
|
||||
# it will increase the Ray resource usage of the trainer since evaluation
|
||||
# it will increase the Ray resource usage of the algorithm since evaluation
|
||||
# workers are created separately from rollout workers (used to sample data
|
||||
# for training).
|
||||
"evaluation_num_workers": 0,
|
||||
# Customize the evaluation method. This must be a function of signature
|
||||
# (trainer: Trainer, eval_workers: WorkerSet) -> metrics: dict. See the
|
||||
# Trainer.evaluate() method to see the default implementation.
|
||||
# The Trainer guarantees all eval workers have the latest policy state
|
||||
# (algorithm: Algorithm, eval_workers: WorkerSet) -> metrics: dict. See the
|
||||
# Algorithm.evaluate() method to see the default implementation.
|
||||
# The Algorithm guarantees all eval workers have the latest policy state
|
||||
# before this function is called.
|
||||
"custom_eval_function": None,
|
||||
# Make sure the latest available evaluation results are always attached to
|
||||
|
@ -504,15 +504,15 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
# each worker, so that identically configured trials will have identical
|
||||
# results. This makes experiments reproducible.
|
||||
"seed": None,
|
||||
# Any extra python env vars to set in the trainer process, e.g.,
|
||||
# Any extra python env vars to set in the algorithm process, e.g.,
|
||||
# {"OMP_NUM_THREADS": "16"}
|
||||
"extra_python_environs_for_driver": {},
|
||||
# Any extra python env vars to set in the worker processes.
|
||||
"extra_python_environs_for_worker": {},
|
||||
|
||||
# === Resource Settings ===
|
||||
# Number of GPUs to allocate to the trainer process. Note that not all
|
||||
# algorithms can take advantage of trainer GPUs. Support for multi-GPU
|
||||
# Number of GPUs to allocate to the algorithm process. Note that not all
|
||||
# algorithms can take advantage of GPUs. Support for multi-GPU
|
||||
# is currently only available for tf-[PPO/IMPALA/DQN/PG].
|
||||
# This can be fractional (e.g., 0.3 GPUs).
|
||||
"num_gpus": 0,
|
||||
|
@ -528,13 +528,13 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
"num_gpus_per_worker": 0,
|
||||
# Any custom Ray resources to allocate per worker.
|
||||
"custom_resources_per_worker": {},
|
||||
# Number of CPUs to allocate for the trainer. Note: this only takes effect
|
||||
# when running in Tune. Otherwise, the trainer runs in the main program.
|
||||
# Number of CPUs to allocate for the algorithm. Note: this only takes effect
|
||||
# when running in Tune. Otherwise, the algorithm runs in the main program.
|
||||
"num_cpus_for_driver": 1,
|
||||
# The strategy for the placement group factory returned by
|
||||
# `Trainer.default_resource_request()`. A PlacementGroup defines, which
|
||||
# `Algorithm.default_resource_request()`. A PlacementGroup defines which
|
||||
# devices (resources) should always be co-located on the same node.
|
||||
# For example, a Trainer with 2 rollout workers, running with
|
||||
# For example, an Algorithm with 2 rollout workers, running with
|
||||
# num_gpus=1 will request a placement group with the bundles:
|
||||
# [{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle is
|
||||
# for the driver and the other 2 bundles are for the two workers.
|
||||
|
@ -655,7 +655,7 @@ The following is a list of the common algorithm hyper-parameters:
|
|||
|
||||
# === API deprecations/simplifications/changes ===
|
||||
# If True, the execution plan API will not be used. Instead,
|
||||
# a Trainer's `training_step()` method will be called on each
|
||||
# an Algorithm's `training_step()` method will be called on each
|
||||
# training iteration.
|
||||
"_disable_execution_plan_api": True,
|
||||
|
||||
|
@ -716,17 +716,17 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
|||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
config["num_gpus"] = 0
|
||||
config["num_workers"] = 1
|
||||
trainer = ppo.PPO(config=config, env="CartPole-v0")
|
||||
algo = ppo.PPO(config=config, env="CartPole-v0")
|
||||
|
||||
# Can optionally call trainer.restore(path) to load a checkpoint.
|
||||
# Can optionally call algo.restore(path) to load a checkpoint.
|
||||
|
||||
for i in range(1000):
|
||||
# Perform one iteration of training the policy with PPO
|
||||
result = trainer.train()
|
||||
result = algo.train()
|
||||
print(pretty_print(result))
|
||||
|
||||
if i % 100 == 0:
|
||||
checkpoint = trainer.save()
|
||||
checkpoint = algo.save()
|
||||
print("checkpoint saved at", checkpoint)
|
||||
|
||||
# Also, in case you have trained a model outside of ray/RLlib and have created
|
||||
|
@ -734,9 +734,9 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
|||
# my_keras_model_trained_outside_rllib.save_weights("model.h5")
|
||||
# (see: https://keras.io/models/about-keras-models/)
|
||||
|
||||
# ... you can load the h5-weights into your Trainer's Policy's ModelV2
|
||||
# ... you can load the h5-weights into your Algorithm's Policy's ModelV2
|
||||
# (tf or torch) by doing:
|
||||
trainer.import_model("my_weights.h5")
|
||||
algo.import_model("my_weights.h5")
|
||||
# NOTE: In order for this to work, your (custom) model needs to implement
|
||||
# the `import_from_h5` method.
|
||||
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
|
||||
|
@ -744,9 +744,9 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
|||
|
||||
.. note::
|
||||
|
||||
It's recommended that you run RLlib trainers with :doc:`Tune <../tune/index>`, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.
|
||||
It's recommended that you run RLlib algorithms with :doc:`Tune <../tune/index>`, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.
|
||||
|
||||
All RLlib trainers are compatible with the :ref:`Tune API <tune-60-seconds>`. This enables them to be easily used in experiments with :doc:`Tune <../tune/index>`. For example, the following code performs a simple hyperparam sweep of PPO:
|
||||
All RLlib algorithms are compatible with the :ref:`Tune API <tune-60-seconds>`. This enables them to be easily used in experiments with :doc:`Tune <../tune/index>`. For example, the following code performs a simple hyperparam sweep of PPO:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -818,7 +818,7 @@ Loading and restoring a trained agent from a checkpoint is simple:
|
|||
Computing Actions
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
The simplest way to programmatically compute actions from a trained agent is to use ``trainer.compute_action()``.
|
||||
The simplest way to programmatically compute actions from a trained agent is to use ``Algorithm.compute_action()``.
|
||||
This method preprocesses and filters the observation before passing it to the agent policy.
|
||||
Here is a simple example of testing a trained agent for one episode:
|
||||
|
||||
|
@ -836,12 +836,12 @@ Here is a simple example of testing a trained agent for one episode:
|
|||
obs, reward, done, info = env.step(action)
|
||||
episode_reward += reward
|
||||
|
||||
For more advanced usage, you can access the ``workers`` and policies held by the trainer
|
||||
For more advanced usage, you can access the ``workers`` and policies held by the algorithm
|
||||
directly as ``compute_action()`` does:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Trainer(Trainable):
|
||||
class Algorithm(Trainable):
|
||||
|
||||
@PublicAPI
|
||||
def compute_action(self,
|
||||
|
@ -857,7 +857,7 @@ directly as ``compute_action()`` does:
|
|||
Note that you can also access the policy object through
|
||||
self.get_policy(policy_id) and call compute_actions() on it directly.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
observation (obj): observation from the environment.
|
||||
state (list): RNN hidden state, if any. If state is not None,
|
||||
then all of compute_single_action(...) is returned
|
||||
|
@ -905,23 +905,31 @@ directly as ``compute_action()`` does:
|
|||
|
||||
Accessing Policy State
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *rollout workers* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.workers.foreach_worker()`` or ``trainer.workers.foreach_worker_with_index()``. These functions take a lambda function that is applied with the worker as an arg. You can also return values from these functions and those will be returned as a list.
|
||||
It is common to need to access an algorithm's internal state, e.g., to set or get internal weights.
|
||||
In RLlib, algorithm state is replicated across multiple *rollout workers* (Ray actors) in the cluster.
|
||||
However, you can easily get and update this state between calls to ``train()`` via ``Algorithm.workers.foreach_worker()`` or ``Algorithm.workers.foreach_worker_with_index()``.
|
||||
These functions take a lambda function that is applied with the worker as an arg.
|
||||
You can also return values from these functions and those will be returned as a list.
|
||||
|
||||
You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.workers.local_worker()``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.workers.local_worker().policy_map["default_policy"].get_weights()``:
|
||||
You can also access just the "master" copy of the algorithm state through ``Algorithm.get_policy()`` or
|
||||
``Algorithm.workers.local_worker()``, but note that updates here may not be immediately reflected in
|
||||
remote replicas if you have configured ``num_workers > 0``.
|
||||
For example, to access the weights of a local TF policy, you can run ``Algorithm.get_policy().get_weights()``.
|
||||
This is also equivalent to ``Algorithm.workers.local_worker().policy_map["default_policy"].get_weights()``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Get weights of the default local policy
|
||||
trainer.get_policy().get_weights()
|
||||
algo.get_policy().get_weights()
|
||||
|
||||
# Same as above
|
||||
trainer.workers.local_worker().policy_map["default_policy"].get_weights()
|
||||
algo.workers.local_worker().policy_map["default_policy"].get_weights()
|
||||
|
||||
# Get list of weights of each worker, including remote replicas
|
||||
trainer.workers.foreach_worker(lambda ev: ev.get_policy().get_weights())
|
||||
algo.workers.foreach_worker(lambda ev: ev.get_policy().get_weights())
|
||||
|
||||
# Same as above
|
||||
trainer.workers.foreach_worker_with_index(lambda ev, i: ev.get_policy().get_weights())
|
||||
algo.workers.foreach_worker_with_index(lambda ev, i: ev.get_policy().get_weights())
|
||||
|
||||
Accessing Model State
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -967,7 +975,9 @@ Advanced Python APIs
|
|||
Custom Training Workflows
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per training iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports :ref:`custom trainable functions <trainable-docs>` that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.
|
||||
In the `basic training example <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your algorithm once per training iteration and report the new training results.
|
||||
Sometimes, it is desirable to have full control over training, but still run inside Tune.
|
||||
Tune supports :ref:`custom trainable functions <trainable-docs>` that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_train_fn.py>`__.
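A minimal sketch of such a function trainable wrapping an RLlib algorithm (assuming the reporter-based function API used elsewhere in these docs; the stopping logic is illustrative):

.. code-block:: python

    from ray import tune
    from ray.rllib.algorithms.ppo import PPO

    def my_train_fn(config, reporter):
        algo = PPO(config=config, env="CartPole-v0")
        for _ in range(10):
            result = algo.train()
            reporter(**result)  # Hand the results back to Tune after each iteration.
        algo.stop()

    tune.run(my_train_fn, config={"num_workers": 0})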
|
||||
|
||||
For even finer-grained control over training, you can use RLlib's lower-level `building blocks <rllib-concepts.html>`__ directly to implement `fully customized training workflows <https://github.com/ray-project/ray/blob/master/rllib/examples/rollout_worker_custom_workflow.py>`__.
|
||||
|
||||
|
@ -1003,7 +1013,7 @@ You can provide callbacks to be called at points during policy evaluation. These
|
|||
|
||||
User-defined state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__.
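As a brief sketch (the metric name is purely illustrative):

.. code-block:: python

    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    class MyCallbacks(DefaultCallbacks):
        def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
            # Report the episode length as a custom metric.
            episode.custom_metrics["my_episode_len"] = episode.length

    config = {"callbacks": MyCallbacks}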
|
||||
|
||||
.. autoclass:: ray.rllib.agents.callbacks.DefaultCallbacks
|
||||
.. autoclass:: ray.rllib.algorithms.callbacks.DefaultCallbacks
|
||||
:members:
|
||||
|
||||
|
||||
|
@ -1012,7 +1022,7 @@ Chaining Callbacks
|
|||
|
||||
Use the ``MultiCallbacks`` class to chain multiple callbacks together.
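For example (``MyStatsCallbacks`` and ``MyVideoCallbacks`` are hypothetical ``DefaultCallbacks`` sub-classes, stubbed out here):

.. code-block:: python

    from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks

    class MyStatsCallbacks(DefaultCallbacks):
        pass  # Hypothetical callback #1.

    class MyVideoCallbacks(DefaultCallbacks):
        pass  # Hypothetical callback #2.

    config = {
        "callbacks": MultiCallbacks([MyStatsCallbacks, MyVideoCallbacks]),
    }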
|
||||
|
||||
.. autoclass:: ray.rllib.agents.callbacks.MultiCallbacks
|
||||
.. autoclass:: ray.rllib.algorithms.callbacks.MultiCallbacks
|
||||
|
||||
|
||||
Visualizing Custom Metrics
|
||||
|
@ -1032,19 +1042,19 @@ exploration behavior, including the decisions (how and whether) to sample
|
|||
actions from distributions (stochastically or deterministically).
|
||||
The setup can be done via using built-in Exploration classes
|
||||
(see `this package <https://github.com/ray-project/ray/blob/master/rllib/utils/exploration/>`__),
|
||||
which are specified (and further configured) inside ``Trainer.config["exploration_config"]``.
|
||||
which are specified (and further configured) inside ``Algorithm.config["exploration_config"]``.
|
||||
Besides using one of the available classes, one can sub-class any of
|
||||
these built-ins, add custom behavior to it, and use that new class in
|
||||
the config instead.
|
||||
|
||||
Every policy has-an Exploration object, which is created from the Trainer’s
|
||||
Every policy has an Exploration object, which is created from the Algorithm’s
|
||||
``config[“exploration_config”]`` dict, which specifies the class to use via the
|
||||
special “type” key, as well as constructor arguments via all other keys,
|
||||
e.g.:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# in Trainer.config:
|
||||
# in Algorithm.config:
|
||||
"exploration_config": {
|
||||
"type": "StochasticSampling", # <- Special `type` key provides class information
|
||||
"[c'tor arg]" : "[value]", # <- Add any needed constructor args here.
|
||||
|
@ -1070,19 +1080,19 @@ b) log-likelihood:
|
|||
:start-after: __sphinx_doc_begin_get_exploration_action__
|
||||
:end-before: __sphinx_doc_end_get_exploration_action__
|
||||
|
||||
On the highest level, the ``Trainer.compute_action`` and ``Policy.compute_action(s)``
|
||||
On the highest level, the ``Algorithm.compute_actions`` and ``Policy.compute_actions``
|
||||
methods have a boolean ``explore`` switch, which is passed into
|
||||
``Exploration.get_exploration_action``. If ``explore=None``, the value of
|
||||
``Trainer.config[“explore”]`` is used, which thus serves as a main switch for
|
||||
``Algorithm.config[“explore”]`` is used, which thus serves as a main switch for
|
||||
exploratory behavior, allowing e.g. turning off any exploration easily for
|
||||
evaluation purposes (see :ref:`CustomEvaluation`).
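A brief sketch of this per-call override, using ``compute_single_action`` (the PPO/CartPole setup is just for illustration):

.. code-block:: python

    import gym
    from ray.rllib.algorithms.ppo import PPO

    algo = PPO(env="CartPole-v0")
    obs = gym.make("CartPole-v0").reset()

    # Sample stochastically from the action distribution.
    a_explore = algo.compute_single_action(obs, explore=True)
    # Act greedily/deterministically.
    a_greedy = algo.compute_single_action(obs, explore=False)
    # explore=None (default): fall back to config["explore"].
    a_default = algo.compute_single_action(obs)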
|
||||
|
||||
The following are example excerpts from different Trainers' configs
|
||||
(see rllib/agents/trainer.py) to setup different exploration behaviors:
|
||||
The following are example excerpts from different Algorithms' configs
|
||||
(see rllib/algorithms/algorithm.py) to set up different exploration behaviors:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# All of the following configs go into Trainer.config.
|
||||
# All of the following configs go into Algorithm.config.
|
||||
|
||||
# 1) Switching *off* exploration by default.
|
||||
# Behavior: Calling `compute_action(s)` without explicitly setting its `explore`
|
||||
|
@ -1119,10 +1129,10 @@ The following are example excerpts from different Trainers' configs
|
|||
"temperature": 1.0,
|
||||
},
|
||||
|
||||
# c) All policy-gradient algos and SAC: see rllib/agents/trainer.py
|
||||
# c) All policy-gradient algos and SAC: see rllib/algorithms/algorithm.py
|
||||
# Behavior: The algo samples stochastically from the
|
||||
# model-parameterized distribution. This is the global Trainer default
|
||||
# setting defined in trainer.py and used by all PG-type algos (plus SAC).
|
||||
# model-parameterized distribution. This is the global Algorithm default
|
||||
# setting defined in algorithm.py and used by all PG-type algos (plus SAC).
|
||||
"explore": True,
|
||||
"exploration_config": {
|
||||
"type": "StochasticSampling",
|
||||
|
@ -1137,13 +1147,13 @@ Customized Evaluation During Training
|
|||
|
||||
RLlib will report online training rewards; however, in some cases you may want to compute
|
||||
rewards with different settings (e.g., with exploration turned off, or on a specific set
|
||||
of environment configurations). You can activate evaluating policies during training (``Trainer.train()``) by setting
|
||||
the ``evaluation_interval`` to an int value (> 0) indicating every how many ``Trainer.train()``
|
||||
of environment configurations). You can activate evaluating policies during training (``Algorithm.train()``) by setting
|
||||
the ``evaluation_interval`` to an int value (> 0), indicating after how many ``Algorithm.train()``
|
||||
calls an "evaluation step" is run:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Run one evaluation step on every 3rd `Trainer.train()` call.
|
||||
# Run one evaluation step on every 3rd `Algorithm.train()` call.
|
||||
{
|
||||
"evaluation_interval": 3,
|
||||
}
|
||||
|
@ -1186,7 +1196,7 @@ roughly as long as the train step:
|
|||
.. code-block:: python
|
||||
|
||||
# Run eval and train at the same time via threading and make sure they roughly
|
||||
# take the same time, such that the next `Trainer.train()` call can execute
|
||||
# take the same time, such that the next `Algorithm.train()` call can execute
|
||||
# immediately and not have to wait for a still ongoing (e.g. very long episode)
|
||||
# evaluation step:
|
||||
{
|
||||
|
@ -1204,8 +1214,8 @@ do:
|
|||
.. code-block:: python
|
||||
|
||||
# Switching off exploration behavior for evaluation workers
|
||||
# (see rllib/agents/trainer.py). Use any keys in this sub-dict that are
|
||||
# also supported in the main Trainer config.
|
||||
# (see rllib/algorithms/algorithm.py). Use any keys in this sub-dict that are
|
||||
# also supported in the main Algorithm config.
|
||||
"evaluation_config": {
|
||||
"explore": False
|
||||
}
|
||||
|
@ -1223,8 +1233,8 @@ run as much in parallel as possible. For example, if your ``evaluation_duration=
|
|||
only has to run 1 episode in each eval step.
|
||||
|
||||
In case you would like to entirely customize the evaluation step, set ``custom_eval_function`` in your
|
||||
config to a callable taking the Trainer object and a WorkerSet object (the evaluation WorkerSet)
|
||||
and returning a metrics dict. See `trainer.py <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__
|
||||
config to a callable taking the Algorithm object and a WorkerSet object (the evaluation WorkerSet)
|
||||
and returning a metrics dict. See `algorithm.py <https://github.com/ray-project/ray/blob/master/rllib/algorithms/algorithm.py>`__
|
||||
for further documentation.
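A sketch of such a custom evaluation function with the required signature (the sampling loop and the returned metric are purely illustrative):

.. code-block:: python

    import ray

    def my_eval_fn(algorithm, eval_workers):
        # Collect one batch of experiences on each remote evaluation worker ...
        ray.get([w.sample.remote() for w in eval_workers.remote_workers()])
        # ... then return whatever metrics dict you care about.
        return {"my_eval_metric": 1.0}

    config = {
        "evaluation_num_workers": 2,
        "custom_eval_function": my_eval_fn,
    }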
|
||||
|
||||
There is an end to end example of how to set up custom online evaluation in `custom_eval.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_eval.py>`__. Note that if you only want to eval your policy at the end of training, you can set ``evaluation_interval: N``, where ``N`` is the number of training iterations before stopping.
|
||||
|
@ -1237,12 +1247,12 @@ Below are some examples of how the custom evaluation metrics are reported nested
|
|||
Sample output for `python custom_eval.py`
|
||||
------------------------------------------------------------------------
|
||||
|
||||
INFO trainer.py:623 -- Evaluating current policy for 10 episodes.
|
||||
INFO trainer.py:650 -- Running round 0 of parallel evaluation (2/10 episodes)
|
||||
INFO trainer.py:650 -- Running round 1 of parallel evaluation (4/10 episodes)
|
||||
INFO trainer.py:650 -- Running round 2 of parallel evaluation (6/10 episodes)
|
||||
INFO trainer.py:650 -- Running round 3 of parallel evaluation (8/10 episodes)
|
||||
INFO trainer.py:650 -- Running round 4 of parallel evaluation (10/10 episodes)
|
||||
INFO algorithm.py:623 -- Evaluating current policy for 10 episodes.
|
||||
INFO algorithm.py:650 -- Running round 0 of parallel evaluation (2/10 episodes)
|
||||
INFO algorithm.py:650 -- Running round 1 of parallel evaluation (4/10 episodes)
|
||||
INFO algorithm.py:650 -- Running round 2 of parallel evaluation (6/10 episodes)
|
||||
INFO algorithm.py:650 -- Running round 3 of parallel evaluation (8/10 episodes)
|
||||
INFO algorithm.py:650 -- Running round 4 of parallel evaluation (10/10 episodes)
|
||||
|
||||
Result for PG_SimpleCorridor_2c6b27dc:
|
||||
...
|
||||
|
@ -1326,17 +1336,17 @@ which receives the last training results and returns a new task for the env to b
|
|||
new_task = current_task + 1
|
||||
return new_task
|
||||
|
||||
# Setup your Trainer's config like so:
|
||||
# Setup your Algorithm's config like so:
|
||||
config = {
|
||||
"env": MyEnv,
|
||||
"env_task_fn": curriculum_fn,
|
||||
}
|
||||
# Train using `tune.run` or `Trainer.train()` and the above config stub.
|
||||
# Train using `tune.run` or `Algorithm.train()` and the above config stub.
|
||||
# ...
|
||||
|
||||
There are two more ways to use the RLlib's other APIs to implement `curriculum learning <https://bair.berkeley.edu/blog/2017/12/20/reverse-curriculum/>`__.
|
||||
|
||||
Use the Trainer API and update the environment between calls to ``train()``. This example shows the trainer being run inside a Tune function.
|
||||
Use the Algorithm API and update the environment between calls to ``train()``. This example shows the algorithm being run inside a Tune function.
|
||||
This is basically the same as what the built-in `env_task_fn` API described above already does under the hood, but allows you to do even more
|
||||
customizations to your training loop.
|
||||
|
||||
|
@ -1347,9 +1357,9 @@ customizations to your training loop.
|
|||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
def train(config, reporter):
|
||||
trainer = PPO(config=config, env=YourEnv)
|
||||
algo = PPO(config=config, env=YourEnv)
|
||||
while True:
|
||||
result = trainer.train()
|
||||
result = algo.train()
|
||||
reporter(**result)
|
||||
if result["episode_reward_mean"] > 200:
|
||||
task = 2
|
||||
|
@ -1357,7 +1367,7 @@ customizations to your training loop.
|
|||
task = 1
|
||||
else:
|
||||
task = 0
|
||||
trainer.workers.foreach_worker(
|
||||
algo.workers.foreach_worker(
|
||||
lambda ev: ev.foreach_env(
|
||||
lambda env: env.set_task(task)))
|
||||
|
||||
|
@ -1382,17 +1392,17 @@ You could also use RLlib's callbacks API to update the environment on new traini
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
|
||||
def on_train_result(info):
|
||||
result = info["result"]
|
||||
class MyCallbacks(DefaultCallbacks):
|
||||
def on_train_result(self, algorithm, result, **kwargs):
|
||||
if result["episode_reward_mean"] > 200:
|
||||
task = 2
|
||||
elif result["episode_reward_mean"] > 100:
|
||||
task = 1
|
||||
else:
|
||||
task = 0
|
||||
trainer = info["trainer"]
|
||||
trainer.workers.foreach_worker(
|
||||
algorithm.workers.foreach_worker(
|
||||
lambda ev: ev.foreach_env(
|
||||
lambda env: env.set_task(task)))
|
||||
|
||||
|
@ -1401,9 +1411,7 @@ You could also use RLlib's callbacks API to update the environment on new traini
|
|||
"PPO",
|
||||
config={
|
||||
"env": YourEnv,
|
||||
"callbacks": {
|
||||
"on_train_result": on_train_result,
|
||||
},
|
||||
"callbacks": MyCallbacks,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -1444,8 +1452,8 @@ However, eager can be slower than graph mode unless tracing is enabled.
|
|||
Using PyTorch
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
Trainers that have an implemented TorchPolicy, will allow you to run
|
||||
`rllib train` using the command line ``--torch`` flag.
|
||||
Algorithms that have an implemented TorchPolicy will allow you to run
|
||||
`rllib train` using the command line ``--framework=torch`` flag.
|
||||
Algorithms that do not have a torch version yet will complain with an error in
|
||||
this case.
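The same can also be done through the Python API by setting the ``framework`` config key (a short sketch, using PPO as an example):

.. code-block:: python

    from ray.rllib.algorithms.ppo import PPO

    # Run PPO with its PyTorch policy implementation.
    algo = PPO(env="CartPole-v0", config={"framework": "torch"})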
|
||||
|
||||
|
@ -1467,7 +1475,10 @@ You can use the `data output API <rllib-offline.html>`__ to save episode traces
|
|||
Log Verbosity
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
You can control the trainer log level via the ``"log_level"`` flag. Valid values are "DEBUG", "INFO", "WARN" (default), and "ERROR". This can be used to increase or decrease the verbosity of internal logging. You can also use the ``-v`` and ``-vv`` flags. For example, the following two commands are about equivalent:
|
||||
You can control the log level via the ``"log_level"`` flag. Valid values are "DEBUG",
|
||||
"INFO", "WARN" (default), and "ERROR". This can be used to increase or decrease the
|
||||
verbosity of internal logging. You can also use the ``-v`` and ``-vv`` flags.
|
||||
For example, the following two commands are about equivalent:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
|
|
@ -11,12 +11,12 @@ from ray.air.preprocessor import Preprocessor
|
|||
from ray.air.checkpoint import Checkpoint
|
||||
from ray.train.rl import RLTrainer
|
||||
|
||||
from ray.rllib.agents import Trainer
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.tune.utils.trainable import TrainableUtil
|
||||
|
||||
|
||||
class _DummyTrainer(Trainer):
|
||||
class _DummyAlgo(Algorithm):
|
||||
train_exec_impl = None
|
||||
|
||||
def setup(self, config):
|
||||
|
@ -59,7 +59,7 @@ def create_checkpoint(
|
|||
preprocessor: Optional[Preprocessor] = None, config: Optional[dict] = None
|
||||
) -> Checkpoint:
|
||||
rl_trainer = RLTrainer(
|
||||
algorithm=_DummyTrainer,
|
||||
algorithm=_DummyAlgo,
|
||||
config=config or {},
|
||||
preprocessor=preprocessor,
|
||||
)
|
||||
|
|
|
@ -10,9 +10,9 @@ from ray.air._internal.checkpointing import (
|
|||
load_preprocessor_from_dir,
|
||||
save_preprocessor_to_dir,
|
||||
)
|
||||
from ray.rllib.agents.trainer import Trainer as RLlibTrainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm as RLlibAlgo
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.typing import PartialTrainerConfigDict, EnvType
|
||||
from ray.rllib.utils.typing import PartialAlgorithmConfigDict, EnvType
|
||||
from ray.tune import Trainable, PlacementGroupFactory
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
|
@ -86,7 +86,7 @@ class RLTrainer(BaseTrainer):
|
|||
import ray
|
||||
from ray.air.config import RunConfig
|
||||
from ray.train.rl import RLTrainer
|
||||
from ray.rllib.agents.marwil.bc import BCTrainer
|
||||
from ray.rllib.algorithms.bc.bc import BC
|
||||
|
||||
dataset = ray.data.read_json(
|
||||
"/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
|
||||
|
@ -114,7 +114,7 @@ class RLTrainer(BaseTrainer):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: Union[str, Type[RLlibTrainer]],
|
||||
algorithm: Union[str, Type[RLlibAlgo]],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
scaling_config: Optional[ScalingConfig] = None,
|
||||
run_config: Optional[RunConfig] = None,
|
||||
|
@ -137,8 +137,7 @@ class RLTrainer(BaseTrainer):
|
|||
super(RLTrainer, self)._validate_attributes()
|
||||
|
||||
if not isinstance(self._algorithm, str) and not (
|
||||
inspect.isclass(self._algorithm)
|
||||
and issubclass(self._algorithm, RLlibTrainer)
|
||||
inspect.isclass(self._algorithm) and issubclass(self._algorithm, RLlibAlgo)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`algorithm` should be either a string or a RLlib trainer class, "
|
||||
|
@ -201,7 +200,7 @@ class RLTrainer(BaseTrainer):
|
|||
class AIRRLTrainer(rllib_trainer):
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PartialTrainerConfigDict] = None,
|
||||
config: Optional[PartialAlgorithmConfigDict] = None,
|
||||
env: Optional[Union[str, EnvType]] = None,
|
||||
logger_creator: Optional[Callable[[], Logger]] = None,
|
||||
remote_checkpoint_dir: Optional[str] = None,
|
||||
|
@ -241,7 +240,7 @@ class RLTrainer(BaseTrainer):
|
|||
|
||||
@classmethod
|
||||
def default_resource_request(
|
||||
cls, config: PartialTrainerConfigDict
|
||||
cls, config: PartialAlgorithmConfigDict
|
||||
) -> Union[Resources, PlacementGroupFactory]:
|
||||
resolved_config = merge_dicts(base_config, config)
|
||||
param_dict["config"] = resolved_config
|
||||
|
|
|
@ -56,7 +56,13 @@ class _CallbackMeta(ABCMeta):
|
|||
|
||||
@classmethod
|
||||
def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool:
|
||||
return (attr_name.startswith("on_") or attr_name == "setup") and callable(attr)
|
||||
return (
|
||||
(
|
||||
attr_name.startswith("on_")
|
||||
and not attr_name.startswith("on_trainer_init")
|
||||
)
|
||||
or attr_name == "setup"
|
||||
) and callable(attr)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
|
|
|
@ -8,7 +8,7 @@ import uuid
|
|||
|
||||
import ray._private.utils
|
||||
|
||||
from ray.rllib.agents.mock import _MockTrainer
|
||||
from ray.rllib.algorithms.mock import _MockTrainer
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.sync_client import get_sync_client
|
||||
|
|
|
@ -7,7 +7,7 @@ import time
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents import DefaultCallbacks
|
||||
from ray.rllib.algorithms.callbacks import DefaultCallbacks
|
||||
from ray.rllib.algorithms.ppo import PPO
|
||||
|
||||
|
||||
|
@ -34,7 +34,7 @@ def fn_trainable(config, checkpoint_dir=None):
|
|||
|
||||
|
||||
class RLlibCallback(DefaultCallbacks):
|
||||
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
|
||||
def on_train_result(self, *, algorithm, result: dict, **kwargs) -> None:
|
||||
result["internal_iter"] = result["training_iteration"]
|
||||
|
||||
|
||||
|
|
13
rllib/BUILD
13
rllib/BUILD
|
@ -1796,6 +1796,13 @@ py_test(
|
|||
# for `tests/test_all_stuff.py`.
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
py_test(
|
||||
name = "tests/backward_compat/test_backward_compat",
|
||||
tags = ["team:rllib", "tests_dir", "tests_dir_B"],
|
||||
size = "medium",
|
||||
srcs = ["tests/backward_compat/test_backward_compat.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_algorithm_imports",
|
||||
tags = ["team:rllib", "tests_dir", "tests_dir_C"],
|
||||
|
@ -2626,7 +2633,7 @@ py_test(
|
|||
tags = ["team:rllib", "exclusive", "multi_gpu", "examples"],
|
||||
size = "medium",
|
||||
srcs = ["examples/deterministic_training.py"],
|
||||
args = ["--as-test", "--stop-iters=1", "--framework=tf", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
|
||||
args = ["--as-test", "--stop-iters=1", "--framework=tf", "--num-gpus=1", "--num-gpus-per-worker=1"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -2635,7 +2642,7 @@ py_test(
|
|||
tags = ["team:rllib", "exclusive", "multi_gpu", "examples"],
|
||||
size = "medium",
|
||||
srcs = ["examples/deterministic_training.py"],
|
||||
args = ["--as-test", "--stop-iters=1", "--framework=tf2", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
|
||||
args = ["--as-test", "--stop-iters=1", "--framework=tf2", "--num-gpus=1", "--num-gpus-per-worker=1"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -2644,7 +2651,7 @@ py_test(
|
|||
tags = ["team:rllib", "exclusive", "multi_gpu", "examples"],
|
||||
size = "medium",
|
||||
srcs = ["examples/deterministic_training.py"],
|
||||
args = ["--as-test", "--stop-iters=1", "--framework=torch", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
|
||||
args = ["--as-test", "--stop-iters=1", "--framework=torch", "--num-gpus=1", "--num-gpus-per-worker=1"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -174,9 +174,9 @@ Quick First Experiment
|
|||
return self.cur_obs, reward, done, {}
|
||||
|
||||
|
||||
# Create an RLlib Trainer instance to learn how to act in the above
|
||||
# Create an RLlib Algorithm instance to learn how to act in the above
|
||||
# environment.
|
||||
trainer = PPO(
|
||||
algo = PPO(
|
||||
config={
|
||||
# Env class to use (here: our gym.Env sub-class from above).
|
||||
"env": ParrotEnv,
|
||||
|
@ -193,7 +193,7 @@ Quick First Experiment
|
|||
# (exact match between observation and action value),
|
||||
# we can expect to reach an optimal episode reward of 0.0.
|
||||
for i in range(5):
|
||||
results = trainer.train()
|
||||
results = algo.train()
|
||||
print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")
|
||||
|
||||
|
||||
|
@ -220,7 +220,7 @@ and `attention nets <https://github.com/ray-project/ray/blob/master/rllib/exampl
|
|||
while not done:
|
||||
# Compute a single action, given the current observation
|
||||
# from the environment.
|
||||
action = trainer.compute_single_action(obs)
|
||||
action = algo.compute_single_action(obs)
|
||||
# Apply the computed action in the environment.
|
||||
obs, reward, done, info = env.step(action)
|
||||
# Sum up rewards for reporting purposes.
|
||||
|
|
|
@ -28,8 +28,8 @@ def _setup_logger():
|
|||
|
||||
|
||||
def _register_all():
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.registry import ALGORITHMS, get_trainer_class
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
|
||||
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
||||
|
||||
for key in (
|
||||
|
@ -38,12 +38,12 @@ def _register_all():
|
|||
+ ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
|
||||
):
|
||||
logging.warning(key)
|
||||
register_trainable(key, get_trainer_class(key))
|
||||
register_trainable(key, get_algorithm_class(key))
|
||||
|
||||
def _see_contrib(name):
|
||||
"""Returns dummy agent class warning algo is in contrib/."""
|
||||
|
||||
class _SeeContrib(Trainer):
|
||||
class _SeeContrib(Algorithm):
|
||||
def setup(self, config):
|
||||
raise NameError("Please run `contrib/{}` instead.".format(name))
|
||||
|
||||
|
|
|
@ -1,15 +1,7 @@
|
|||
from ray.rllib.agents.callbacks import (
|
||||
DefaultCallbacks,
|
||||
MemoryTrackingCallbacks,
|
||||
MultiCallbacks,
|
||||
)
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm as Trainer, with_common_config
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig as TrainerConfig
|
||||
|
||||
__all__ = [
|
||||
"DefaultCallbacks",
|
||||
"MemoryTrackingCallbacks",
|
||||
"MultiCallbacks",
|
||||
"Trainer",
|
||||
"TrainerConfig",
|
||||
"with_common_config",
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import ray.rllib.agents.a3c.a2c as a2c # noqa
|
||||
from ray.rllib.algorithms.a2c.a2c import (
|
||||
A2CConfig,
|
||||
A2C as A2CTrainer,
|
||||
|
|
4
rllib/agents/a3c/a2c.py
Normal file
4
rllib/agents/a3c/a2c.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from ray.rllib.algorithms.a2c import ( # noqa
|
||||
A2C as A2CTrainer,
|
||||
A2C_DEFAULT_CONFIG,
|
||||
)
|
6
rllib/agents/a3c/a3c.py
Normal file
6
rllib/agents/a3c/a3c.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from ray.rllib.algorithms.a3c import ( # noqa
|
||||
a3c_tf_policy,
|
||||
a3c_torch_policy,
|
||||
A3C as A3CTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.algorithms.ars.ars import ARSTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.ars.ars import ARS as ARSTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.ars.ars_tf_policy import ARSTFPolicy
|
||||
from ray.rllib.algorithms.ars.ars_torch_policy import ARSTorchPolicy
|
||||
|
||||
|
|
|
@ -1,570 +1,13 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import tracemalloc
|
||||
from typing import Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.algorithms.callbacks import ( # noqa
|
||||
DefaultCallbacks,
|
||||
MemoryTrackingCallbacks,
|
||||
MultiCallbacks,
|
||||
RE3UpdateCallbacks,
|
||||
)
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.exploration.random_encoder import (
|
||||
_MovingMeanStd,
|
||||
compute_states_entropy,
|
||||
update_beta,
|
||||
)
|
||||
from ray.rllib.utils.typing import AgentID, EnvType, PolicyID
|
||||
from ray.tune.callback import _CallbackMeta
|
||||
|
||||
# Import psutil after ray so the packaged version is used.
|
||||
import psutil
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.evaluation import RolloutWorker
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class DefaultCallbacks(metaclass=_CallbackMeta):
|
||||
"""Abstract base class for RLlib callbacks (similar to Keras callbacks).
|
||||
|
||||
These callbacks can be used for custom metrics and custom postprocessing.
|
||||
|
||||
By default, all of these callbacks are no-ops. To configure custom training
|
||||
callbacks, subclass DefaultCallbacks and then set
|
||||
{"callbacks": YourCallbacksClass} in the trainer config.
|
||||
"""
|
||||
|
||||
def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
|
||||
if legacy_callbacks_dict:
|
||||
deprecation_warning(
|
||||
"callbacks dict interface",
|
||||
"a class extending rllib.agents.callbacks.DefaultCallbacks",
|
||||
old="ray.rllib.agents.callbacks",
|
||||
new="ray.rllib.algorithms.callbacks",
|
||||
error=False,
|
||||
)
|
||||
self.legacy_callbacks = legacy_callbacks_dict or {}
|
||||
|
||||
def on_sub_environment_created(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
sub_environment: EnvType,
|
||||
env_context: EnvContext,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Callback run when a new sub-environment has been created.
|
||||
|
||||
This method gets called after each sub-environment (usually a
|
||||
gym.Env) has been created, validated (RLlib built-in validation
|
||||
+ possible custom validation function implemented by overriding
|
||||
`Trainer.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
sub_environment: The sub-environment instance that has been
|
||||
created. This is usually a gym.Env object.
|
||||
env_context: The `EnvContext` object that has been passed to
|
||||
the env's constructor.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_trainer_init(
|
||||
self,
|
||||
*,
|
||||
trainer: "Trainer",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Callback run when a new trainer instance has finished setup.
|
||||
|
||||
This method gets called at the end of Trainer.setup() after all
|
||||
the initialization is done, and before actually training starts.
|
||||
|
||||
Args:
|
||||
trainer: Reference to the trainer instance.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_episode_start(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Callback run on the rollout worker before each episode starts.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
base_env: BaseEnv running the episode. The underlying
|
||||
sub environment objects can be retrieved by calling
|
||||
`base_env.get_sub_environments()`.
|
||||
policies: Mapping of policy id to policy objects. In single
|
||||
agent mode there will only be a single "default" policy.
|
||||
episode: Episode object which contains the episode's
|
||||
state. You can use the `episode.user_data` dict to store
|
||||
temporary data, and `episode.custom_metrics` to store custom
|
||||
metrics for the episode.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_episode_start"):
|
||||
self.legacy_callbacks["on_episode_start"](
|
||||
{
|
||||
"env": base_env,
|
||||
"policy": policies,
|
||||
"episode": episode,
|
||||
}
|
||||
)
|
||||
|
||||
def on_episode_step(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Optional[Dict[PolicyID, Policy]] = None,
|
||||
episode: Episode,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Runs on each episode step.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
base_env: BaseEnv running the episode. The underlying
|
||||
sub environment objects can be retrieved by calling
|
||||
`base_env.get_sub_environments()`.
|
||||
policies: Mapping of policy id to policy objects.
|
||||
In single agent mode there will only be a single
|
||||
"default_policy".
|
||||
episode: Episode object which contains episode
|
||||
state. You can use the `episode.user_data` dict to store
|
||||
temporary data, and `episode.custom_metrics` to store custom
|
||||
metrics for the episode.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_episode_step"):
|
||||
self.legacy_callbacks["on_episode_step"](
|
||||
{"env": base_env, "episode": episode}
|
||||
)
|
||||
|
||||
def on_episode_end(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Runs when an episode is done.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
base_env: BaseEnv running the episode. The underlying
|
||||
sub environment objects can be retrieved by calling
|
||||
`base_env.get_sub_environments()`.
|
||||
policies: Mapping of policy id to policy
|
||||
objects. In single agent mode there will only be a single
|
||||
"default_policy".
|
||||
episode: Episode object which contains episode
|
||||
state. You can use the `episode.user_data` dict to store
|
||||
temporary data, and `episode.custom_metrics` to store custom
|
||||
metrics for the episode.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_episode_end"):
|
||||
self.legacy_callbacks["on_episode_end"](
|
||||
{
|
||||
"env": base_env,
|
||||
"policy": policies,
|
||||
"episode": episode,
|
||||
}
|
||||
)
|
||||
|
||||
def on_postprocess_trajectory(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
episode: Episode,
|
||||
agent_id: AgentID,
|
||||
policy_id: PolicyID,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
postprocessed_batch: SampleBatch,
|
||||
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Called immediately after a policy's postprocess_fn is called.
|
||||
|
||||
You can use this callback to do additional postprocessing for a policy,
|
||||
including looking at the trajectory data of other agents in multi-agent
|
||||
settings.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
episode: Episode object.
|
||||
agent_id: Id of the current agent.
|
||||
policy_id: Id of the current policy for the agent.
|
||||
policies: Mapping of policy id to policy objects. In single
|
||||
agent mode there will only be a single "default_policy".
|
||||
postprocessed_batch: The postprocessed sample batch
|
||||
for this agent. You can mutate this object to apply your own
|
||||
trajectory postprocessing.
|
||||
original_batches: Mapping of agents to their unpostprocessed
|
||||
trajectory data. You should not mutate this object.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_postprocess_traj"):
|
||||
self.legacy_callbacks["on_postprocess_traj"](
|
||||
{
|
||||
"episode": episode,
|
||||
"agent_id": agent_id,
|
||||
"pre_batch": original_batches[agent_id],
|
||||
"post_batch": postprocessed_batch,
|
||||
"all_pre_batches": original_batches,
|
||||
}
|
||||
)
|
||||
|
||||
def on_sample_end(
|
||||
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
|
||||
) -> None:
|
||||
"""Called at the end of RolloutWorker.sample().
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
samples: Batch to be returned. You can mutate this
|
||||
object to modify the samples generated.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_sample_end"):
|
||||
self.legacy_callbacks["on_sample_end"](
|
||||
{
|
||||
"worker": worker,
|
||||
"samples": samples,
|
||||
}
|
||||
)
|
||||
|
||||
def on_learn_on_batch(
|
||||
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
|
||||
) -> None:
|
||||
"""Called at the beginning of Policy.learn_on_batch().
|
||||
|
||||
Note: This is called before 0-padding via
|
||||
`pad_batch_to_sequences_of_same_size`.
|
||||
|
||||
Also note, SampleBatch.INFOS column will not be available on
|
||||
train_batch within this callback if framework is tf1, due to
|
||||
the fact that tf1 static graph would mistake it as part of the
|
||||
input dict if present.
|
||||
It is available though, for tf2 and torch frameworks.
|
||||
|
||||
Args:
|
||||
policy: Reference to the current Policy object.
|
||||
train_batch: SampleBatch to be trained on. You can
|
||||
mutate this object to modify the samples generated.
|
||||
result: A results dict to add custom metrics to.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def on_train_result(self, *, trainer: "Trainer", result: dict, **kwargs) -> None:
|
||||
"""Called at the end of Trainable.train().
|
||||
|
||||
Args:
|
||||
trainer: Current trainer instance.
|
||||
result: Dict of results returned from trainer.train() call.
|
||||
You can mutate this object to add additional metrics.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_train_result"):
|
||||
self.legacy_callbacks["on_train_result"](
|
||||
{
|
||||
"trainer": trainer,
|
||||
"result": result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class MemoryTrackingCallbacks(DefaultCallbacks):
|
||||
"""MemoryTrackingCallbacks can be used to trace and track memory usage
|
||||
in rollout workers.
|
||||
|
||||
The Memory Tracking Callbacks uses tracemalloc and psutil to track
|
||||
python allocations during rollouts,
|
||||
in training or evaluation.
|
||||
|
||||
The tracking data is logged to the custom_metrics of an episode and
|
||||
can therefore be viewed in tensorboard
|
||||
(or in WandB etc..)
|
||||
|
||||
Add MemoryTrackingCallbacks callback to the tune config
|
||||
e.g. { ...'callbacks': MemoryTrackingCallbacks ...}
|
||||
|
||||
Note:
|
||||
This class is meant for debugging and should not be used
|
||||
in production code as tracemalloc incurs
|
||||
a significant slowdown in execution speed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Will track the top 10 lines where memory is allocated
|
||||
tracemalloc.start(10)
|
||||
|
||||
def on_episode_end(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
top_stats = snapshot.statistics("lineno")
|
||||
|
||||
for stat in top_stats[:10]:
|
||||
count = stat.count
|
||||
size = stat.size
|
||||
|
||||
trace = str(stat.traceback)
|
||||
|
||||
episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
|
||||
episode.custom_metrics[f"tracemalloc/{trace}/count"] = count
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
worker_rss = process.memory_info().rss
|
||||
worker_data = process.memory_info().data
|
||||
worker_vms = process.memory_info().vms
|
||||
episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss
|
||||
episode.custom_metrics["tracemalloc/worker/data"] = worker_data
|
||||
episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
|
||||
|
||||
|
||||
class MultiCallbacks(DefaultCallbacks):
|
||||
"""MultiCallbacks allows multiple callbacks to be registered at
|
||||
the same time in the config of the environment.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
'callbacks': MultiCallbacks([
|
||||
MyCustomStatsCallbacks,
|
||||
MyCustomVideoCallbacks,
|
||||
MyCustomTraceCallbacks,
|
||||
....
|
||||
])
|
||||
"""
|
||||
|
||||
IS_CALLBACK_CONTAINER = True
|
||||
|
||||
def __init__(self, callback_class_list):
|
||||
super().__init__()
|
||||
self._callback_class_list = callback_class_list
|
||||
|
||||
self._callback_list = []
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._callback_list = [
|
||||
callback_class() for callback_class in self._callback_class_list
|
||||
]
|
||||
|
||||
return self
|
||||
|
||||
def on_trainer_init(self, *, trainer: "Trainer", **kwargs) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_trainer_init(trainer=trainer, **kwargs)
|
||||
|
||||
def on_sub_environment_created(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
sub_environment: EnvType,
|
||||
env_context: EnvContext,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_sub_environment_created(
|
||||
worker=worker,
|
||||
sub_environment=sub_environment,
|
||||
env_context=env_context,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_episode_start(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_episode_start(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_index,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_episode_step(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Optional[Dict[PolicyID, Policy]] = None,
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_episode_step(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_index,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_episode_end(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_episode_end(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_index,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_postprocess_trajectory(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
episode: Episode,
|
||||
agent_id: AgentID,
|
||||
policy_id: PolicyID,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
postprocessed_batch: SampleBatch,
|
||||
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_postprocess_trajectory(
|
||||
worker=worker,
|
||||
episode=episode,
|
||||
agent_id=agent_id,
|
||||
policy_id=policy_id,
|
||||
policies=policies,
|
||||
postprocessed_batch=postprocessed_batch,
|
||||
original_batches=original_batches,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_sample_end(
|
||||
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_sample_end(worker=worker, samples=samples, **kwargs)
|
||||
|
||||
def on_learn_on_batch(
|
||||
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_learn_on_batch(
|
||||
policy=policy, train_batch=train_batch, result=result, **kwargs
|
||||
)
|
||||
|
||||
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_train_result(trainer=trainer, result=result, **kwargs)
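# --- Editor's usage sketch (not part of this commit) -------------------------
# Combining several callback classes, mirroring the MultiCallbacks docstring
# above. `MyStatsCallbacks` stands in for a user-defined subclass; with the
# legacy dict config, the composite object goes under the "callbacks" key.
class MyStatsCallbacks(DefaultCallbacks):
    pass

config = {
    "env": "CartPole-v0",
    "callbacks": MultiCallbacks([MemoryTrackingCallbacks, MyStatsCallbacks]),
}
# ------------------------------------------------------------------------------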
|
||||
|
||||
|
||||
# This Callback is used by the RE3 exploration strategy.
|
||||
# See rllib/examples/re3_exploration.py for details.
|
||||
class RE3UpdateCallbacks(DefaultCallbacks):
|
||||
"""Update input callbacks to mutate batch with states entropy rewards."""
|
||||
|
||||
_step = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
embeds_dim: int = 128,
|
||||
k_nn: int = 50,
|
||||
beta: float = 0.1,
|
||||
rho: float = 0.0001,
|
||||
beta_schedule: str = "constant",
|
||||
**kwargs,
|
||||
):
|
||||
self.embeds_dim = embeds_dim
|
||||
self.k_nn = k_nn
|
||||
self.beta = beta
|
||||
self.rho = rho
|
||||
self.beta_schedule = beta_schedule
|
||||
self._rms = _MovingMeanStd()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def on_learn_on_batch(
|
||||
self,
|
||||
*,
|
||||
policy: Policy,
|
||||
train_batch: SampleBatch,
|
||||
result: dict,
|
||||
**kwargs,
|
||||
):
|
||||
super().on_learn_on_batch(
|
||||
policy=policy, train_batch=train_batch, result=result, **kwargs
|
||||
)
|
||||
states_entropy = compute_states_entropy(
|
||||
train_batch[SampleBatch.OBS_EMBEDS], self.embeds_dim, self.k_nn
|
||||
)
|
||||
states_entropy = update_beta(
|
||||
self.beta_schedule, self.beta, self.rho, RE3UpdateCallbacks._step
|
||||
) * np.reshape(
|
||||
self._rms(states_entropy),
|
||||
train_batch[SampleBatch.OBS_EMBEDS].shape[:-1],
|
||||
)
|
||||
train_batch[SampleBatch.REWARDS] = (
|
||||
train_batch[SampleBatch.REWARDS] + states_entropy
|
||||
)
|
||||
if Postprocessing.ADVANTAGES in train_batch:
|
||||
train_batch[Postprocessing.ADVANTAGES] = (
|
||||
train_batch[Postprocessing.ADVANTAGES] + states_entropy
|
||||
)
|
||||
train_batch[Postprocessing.VALUE_TARGETS] = (
|
||||
train_batch[Postprocessing.VALUE_TARGETS] + states_entropy
|
||||
)
|
||||
|
||||
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
|
||||
# TODO(gjoliver): Remove explicit _step tracking and pass
|
||||
# trainer._iteration as a parameter to on_learn_on_batch() call.
|
||||
RE3UpdateCallbacks._step = result["training_iteration"]
|
||||
super().on_train_result(trainer=trainer, result=result, **kwargs)
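# --- Editor's sketch of the reward shaping above (not part of this commit) ---
# The state-entropy estimate acts as an intrinsic bonus: it is scaled by beta
# and simply added to the extrinsic rewards (and, where present, to advantages
# and value targets). The numbers below are made up for illustration.
import numpy as np

extrinsic_rewards = np.array([1.0, 0.0, 0.0, 1.0])
state_entropy = np.array([0.3, 0.8, 0.5, 0.1])  # normalized entropy estimates
beta = 0.1  # "constant" beta schedule
shaped_rewards = extrinsic_rewards + beta * state_entropy
# ------------------------------------------------------------------------------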
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import ray.rllib.agents.ddpg.apex as apex # noqa
|
||||
import ray.rllib.agents.ddpg.td3 as td3 # noqa
|
||||
from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPG as ApexDDPGTrainer
|
||||
from ray.rllib.algorithms.ddpg.ddpg import (
|
||||
DDPGConfig,
|
||||
|
|
rllib/agents/ddpg/apex.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.rllib.algorithms.apex_ddpg import ( # noqa
    ApexDDPG as ApexDDPGTrainer,
    APEX_DDPG_DEFAULT_CONFIG,
)
|
rllib/agents/ddpg/ddpg.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from ray.rllib.algorithms.ddpg import ( # noqa
    ddpg_tf_policy,
    ddpg_torch_policy,
    DDPG as DDPGTrainer,
    DEFAULT_CONFIG,
)
|
rllib/agents/ddpg/td3.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.rllib.algorithms.td3 import ( # noqa
    TD3 as TD3Trainer,
    TD3_DEFAULT_CONFIG,
)
|
|
@ -1,3 +1,5 @@
|
|||
import ray.rllib.agents.dqn.apex as apex # noqa
|
||||
import ray.rllib.agents.dqn.simple_q as simple_q # noqa
|
||||
from ray.rllib.algorithms.apex_dqn.apex_dqn import (
|
||||
ApexDQNConfig,
|
||||
ApexDQN as ApexTrainer,
|
||||
|
|
rllib/agents/dqn/apex.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.rllib.algorithms.apex_dqn import ( # noqa
    ApexDQN as ApexTrainer,
    APEX_DEFAULT_CONFIG,
)
|
rllib/agents/dqn/dqn.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from ray.rllib.algorithms.dqn import ( # noqa
    dqn_tf_policy,
    dqn_torch_policy,
    DQN as DQNTrainer,
    DEFAULT_CONFIG,
)
|
rllib/agents/dqn/simple_q.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.rllib.algorithms.simple_q import ( # noqa
    SimpleQ as SimpleQTrainer,
    DEFAULT_CONFIG,
)
|
|
@ -1,4 +1,4 @@
|
|||
from ray.rllib.algorithms.es.es import ESTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.es.es import ES as ESTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.es.es_tf_policy import ESTFPolicy
|
||||
from ray.rllib.algorithms.es.es_torch_policy import ESTorchPolicy
|
||||
|
||||
|
|
|
@ -1,158 +1,13 @@
|
|||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from ray.tune import result as tune_result
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
|
||||
class _MockTrainer(Trainer):
|
||||
"""Mock trainer for use in tests"""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return with_common_config(
|
||||
{
|
||||
"mock_error": False,
|
||||
"persistent_error": False,
|
||||
"test_variable": 1,
|
||||
"num_workers": 0,
|
||||
"user_checkpoint_freq": 0,
|
||||
"framework": "tf",
|
||||
}
|
||||
from ray.rllib.algorithms.mock import ( # noqa
|
||||
_MockTrainer,
|
||||
_ParameterTuningTrainer,
|
||||
_SigmoidFakeData,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
return None
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config):
|
||||
# Setup our config: Merge the user-supplied config (which could
|
||||
# be a partial config dict with the class' default).
|
||||
self.config = self.merge_trainer_configs(
|
||||
self.get_default_config(), config, self._allow_unknown_configs
|
||||
deprecation_warning(
|
||||
old="ray.rllib.agents.callbacks",
|
||||
new="ray.rllib.algorithms.callbacks",
|
||||
error=False,
|
||||
)
|
||||
self.config["env"] = self._env_id
|
||||
|
||||
self.validate_config(self.config)
|
||||
self.callbacks = self.config["callbacks"]()
|
||||
|
||||
# Add needed properties.
|
||||
self.info = None
|
||||
self.restored = False
|
||||
|
||||
@override(Trainer)
|
||||
def step(self):
|
||||
if (
|
||||
self.config["mock_error"]
|
||||
and self.iteration == 1
|
||||
and (self.config["persistent_error"] or not self.restored)
|
||||
):
|
||||
raise Exception("mock error")
|
||||
result = dict(
|
||||
episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}
|
||||
)
|
||||
if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
|
||||
if self.iteration % self.config["user_checkpoint_freq"] == 0:
|
||||
result.update({tune_result.SHOULD_CHECKPOINT: True})
|
||||
return result
|
||||
|
||||
@override(Trainer)
|
||||
def save_checkpoint(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self.info, f)
|
||||
return path
|
||||
|
||||
@override(Trainer)
|
||||
def load_checkpoint(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
info = pickle.load(f)
|
||||
self.info = info
|
||||
self.restored = True
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
def _get_env_id_and_creator(env_specifier, config):
|
||||
# No env to register.
|
||||
return None, None
|
||||
|
||||
def set_info(self, info):
|
||||
self.info = info
|
||||
return info
|
||||
|
||||
def get_info(self, sess=None):
|
||||
return self.info
|
||||
|
||||
|
||||
class _SigmoidFakeData(_MockTrainer):
|
||||
"""Trainer that returns sigmoid learning curves.
|
||||
|
||||
This can be helpful for evaluating early stopping algorithms."""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return with_common_config(
|
||||
{
|
||||
"width": 100,
|
||||
"height": 100,
|
||||
"offset": 0,
|
||||
"iter_time": 10,
|
||||
"iter_timesteps": 1,
|
||||
"num_workers": 0,
|
||||
}
|
||||
)
|
||||
|
||||
def step(self):
|
||||
i = max(0, self.iteration - self.config["offset"])
|
||||
v = np.tanh(float(i) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
return dict(
|
||||
episode_reward_mean=v,
|
||||
episode_len_mean=v,
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={},
|
||||
)
|
||||
|
||||
|
||||
class _ParameterTuningTrainer(_MockTrainer):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return with_common_config(
|
||||
{
|
||||
"reward_amt": 10,
|
||||
"dummy_param": 10,
|
||||
"dummy_param2": 15,
|
||||
"iter_time": 10,
|
||||
"iter_timesteps": 1,
|
||||
"num_workers": 0,
|
||||
}
|
||||
)
|
||||
|
||||
def step(self):
|
||||
return dict(
|
||||
episode_reward_mean=self.config["reward_amt"] * self.iteration,
|
||||
episode_len_mean=self.config["reward_amt"],
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={},
|
||||
)
|
||||
|
||||
|
||||
def _trainer_import_failed(trace):
|
||||
"""Returns dummy agent class for if PyTorch etc. is not installed."""
|
||||
|
||||
class _TrainerImportFailed(Trainer):
|
||||
_name = "TrainerImportFailed"
|
||||
|
||||
def setup(self, config):
|
||||
raise ImportError(trace)
|
||||
|
||||
return _TrainerImportFailed
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import ray.rllib.agents.ppo.appo as appo # noqa
|
||||
from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO as PPOTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
|
||||
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
|
||||
|
|
rllib/agents/ppo/appo.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.rllib.algorithms.appo import ( # noqa
    APPO as APPOTrainer,
    DEFAULT_CONFIG,
)
|
rllib/agents/ppo/ddppo.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.rllib.algorithms.ddppo import ( # noqa
    DDPPO as DDPPOTrainer,
    DEFAULT_CONFIG,
)
|
rllib/agents/ppo/ppo.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from ray.rllib.algorithms.ppo import ( # noqa
    ppo_tf_policy,
    ppo_torch_policy,
    PPO as PPOTrainer,
    DEFAULT_CONFIG,
)
|
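The shim modules above only re-export classes from their new ray.rllib.algorithms
homes, so the old and new import paths keep resolving to the same objects. A hedged
sketch of what that means for user code, with paths taken from the diffs above:

from ray.rllib.agents.ppo import PPOTrainer  # legacy path, kept working by the shim
from ray.rllib.algorithms.ppo import PPO     # new path introduced by this change

assert PPO is PPOTrainer  # the alias and the class are the same object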
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks, MultiCallbacks
|
||||
from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks
|
||||
import ray.rllib.algorithms.dqn as dqn
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@ import ray.rllib.algorithms.a3c as a3c
|
|||
import ray.rllib.algorithms.dqn as dqn
|
||||
from ray.rllib.algorithms.bc import BC, BCConfig
|
||||
import ray.rllib.algorithms.pg as pg
|
||||
from ray.rllib.agents.trainer import COMMON_CONFIG
|
||||
from ray.rllib.algorithms.algorithm import COMMON_CONFIG
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
from ray.rllib.examples.parallel_evaluation_and_training import AssertEvalCallback
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.test_utils import check, framework_iterator
|
||||
|
||||
|
||||
class TestTrainer(unittest.TestCase):
|
||||
class TestAlgorithm(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
ray.init(num_cpus=6)
|
||||
|
@ -35,20 +35,20 @@ class TestTrainer(unittest.TestCase):
|
|||
"""
|
||||
# Given:
|
||||
standard_config = copy.deepcopy(COMMON_CONFIG)
|
||||
trainer = pg.PG(env="CartPole-v0", config=standard_config)
|
||||
algo = pg.PG(env="CartPole-v0", config=standard_config)
|
||||
|
||||
# When (we validate config 2 times).
|
||||
# Try deprecated `Trainer._validate_config()` method (static).
|
||||
trainer._validate_config(standard_config, trainer)
|
||||
algo._validate_config(standard_config, algo)
|
||||
config_v1 = copy.deepcopy(standard_config)
|
||||
# Try new method: `Trainer.validate_config()` (non-static).
|
||||
trainer.validate_config(standard_config)
|
||||
algo.validate_config(standard_config)
|
||||
config_v2 = copy.deepcopy(standard_config)
|
||||
|
||||
# Make sure nothing changed.
|
||||
self.assertEqual(config_v1, config_v2)
|
||||
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
def test_add_delete_policy(self):
|
||||
config = pg.DEFAULT_CONFIG.copy()
|
||||
|
@ -84,9 +84,9 @@ class TestTrainer(unittest.TestCase):
|
|||
)
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
trainer = pg.PG(config=config)
|
||||
pol0 = trainer.get_policy("p0")
|
||||
r = trainer.train()
|
||||
algo = pg.PG(config=config)
|
||||
pol0 = algo.get_policy("p0")
|
||||
r = algo.train()
|
||||
self.assertTrue("p0" in r["info"][LEARNER_INFO])
|
||||
for i in range(1, 3):
|
||||
|
||||
|
@ -95,21 +95,21 @@ class TestTrainer(unittest.TestCase):
|
|||
|
||||
# Add a new policy.
|
||||
pid = f"p{i}"
|
||||
new_pol = trainer.add_policy(
|
||||
new_pol = algo.add_policy(
|
||||
pid,
|
||||
trainer.get_default_policy_class(config),
|
||||
algo.get_default_policy_class(config),
|
||||
# Test changing the mapping fn.
|
||||
policy_mapping_fn=new_mapping_fn,
|
||||
# Change the list of policies to train.
|
||||
policies_to_train=[f"p{i}", f"p{i-1}"],
|
||||
)
|
||||
pol_map = trainer.workers.local_worker().policy_map
|
||||
pol_map = algo.workers.local_worker().policy_map
|
||||
self.assertTrue(new_pol is not pol0)
|
||||
for j in range(i + 1):
|
||||
self.assertTrue(f"p{j}" in pol_map)
|
||||
self.assertTrue(len(pol_map) == i + 1)
|
||||
trainer.train()
|
||||
checkpoint = trainer.save()
|
||||
algo.train()
|
||||
checkpoint = algo.save()
|
||||
|
||||
# Test restoring from the checkpoint (which has more policies
|
||||
# than what's defined in the config dict).
|
||||
|
@ -124,7 +124,7 @@ class TestTrainer(unittest.TestCase):
|
|||
all(test.evaluation_workers.foreach_worker(_has_policy))
|
||||
)
|
||||
|
||||
# Make sure trainer can continue training the restored policy.
|
||||
# Make sure algorithm can continue training the restored policy.
|
||||
pol0 = test.get_policy("p0")
|
||||
test.train()
|
||||
# Test creating an action with the added (and restored) policy.
|
||||
|
@ -134,9 +134,9 @@ class TestTrainer(unittest.TestCase):
|
|||
self.assertTrue(pol0.action_space.contains(a))
|
||||
test.stop()
|
||||
|
||||
# Delete all added policies again from trainer.
|
||||
# Delete all added policies again from Algorithm.
|
||||
for i in range(2, 0, -1):
|
||||
trainer.remove_policy(
|
||||
algo.remove_policy(
|
||||
f"p{i}",
|
||||
# Note that the complete signature of a policy_mapping_fn
|
||||
# is: `agent_id, episode, worker, **kwargs`.
|
||||
|
@ -144,7 +144,7 @@ class TestTrainer(unittest.TestCase):
|
|||
policies_to_train=[f"p{i - 1}"],
|
||||
)
|
||||
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
def test_evaluation_option(self):
|
||||
# Use a custom callback that asserts that we are running the
|
||||
|
@ -164,18 +164,18 @@ class TestTrainer(unittest.TestCase):
|
|||
)
|
||||
|
||||
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
# Given evaluation_interval=2, r0, r2, r4 should not contain
|
||||
# evaluation metrics, while r1, r3 should.
|
||||
r0 = trainer.train()
|
||||
r0 = algo.train()
|
||||
print(r0)
|
||||
r1 = trainer.train()
|
||||
r1 = algo.train()
|
||||
print(r1)
|
||||
r2 = trainer.train()
|
||||
r2 = algo.train()
|
||||
print(r2)
|
||||
r3 = trainer.train()
|
||||
r3 = algo.train()
|
||||
print(r3)
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
self.assertFalse("evaluation" in r0)
|
||||
self.assertTrue("evaluation" in r1)
|
||||
|
@ -202,13 +202,13 @@ class TestTrainer(unittest.TestCase):
|
|||
.callbacks(callbacks_class=AssertEvalCallback)
|
||||
)
|
||||
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
# Should always see latest available eval results.
|
||||
r0 = trainer.train()
|
||||
r1 = trainer.train()
|
||||
r2 = trainer.train()
|
||||
r3 = trainer.train()
|
||||
trainer.stop()
|
||||
r0 = algo.train()
|
||||
r1 = algo.train()
|
||||
r2 = algo.train()
|
||||
r3 = algo.train()
|
||||
algo.stop()
|
||||
|
||||
# Eval results are not available at step 0.
|
||||
# But step 3 should still have it, even though no eval was
|
||||
|
@ -228,26 +228,26 @@ class TestTrainer(unittest.TestCase):
|
|||
)
|
||||
|
||||
for _ in framework_iterator(frameworks=("tf", "torch")):
|
||||
# Setup trainer w/o evaluation worker set and still call
|
||||
# Setup algorithm w/o evaluation worker set and still call
|
||||
# evaluate() -> Expect error.
|
||||
trainer_wo_env_on_driver = config.build()
|
||||
algo_wo_env_on_driver = config.build()
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Cannot evaluate w/o an evaluation worker set",
|
||||
trainer_wo_env_on_driver.evaluate,
|
||||
algo_wo_env_on_driver.evaluate,
|
||||
)
|
||||
trainer_wo_env_on_driver.stop()
|
||||
algo_wo_env_on_driver.stop()
|
||||
|
||||
# Try again using `create_env_on_driver=True`.
|
||||
# This force-adds the env on the local-worker, so this Trainer
|
||||
# can `evaluate` even though it doesn't have an evaluation-worker
|
||||
# set.
|
||||
config.create_env_on_local_worker = True
|
||||
trainer_w_env_on_driver = config.build()
|
||||
results = trainer_w_env_on_driver.evaluate()
|
||||
algo_w_env_on_driver = config.build()
|
||||
results = algo_w_env_on_driver.evaluate()
|
||||
assert "evaluation" in results
|
||||
assert "episode_reward_mean" in results["evaluation"]
|
||||
trainer_w_env_on_driver.stop()
|
||||
algo_w_env_on_driver.stop()
|
||||
config.create_env_on_local_worker = False
|
||||
|
||||
def test_space_inference_from_remote_workers(self):
|
||||
|
@ -264,20 +264,20 @@ class TestTrainer(unittest.TestCase):
|
|||
# No env on driver -> expect longer build time due to space
|
||||
# lookup from remote worker.
|
||||
t0 = time.time()
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
w_lookup = time.time() - t0
|
||||
print(f"No env on learner: {w_lookup}sec")
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
# Env on driver -> expect shorted build time due to no space
|
||||
# lookup required from remote worker.
|
||||
config.create_env_on_local_worker = True
|
||||
t0 = time.time()
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
wo_lookup = time.time() - t0
|
||||
print(f"Env on learner: {wo_lookup}sec")
|
||||
self.assertLess(wo_lookup, w_lookup)
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
# Spaces given -> expect shorter build time due to no space
|
||||
# lookup required from remote worker.
|
||||
|
@ -287,11 +287,11 @@ class TestTrainer(unittest.TestCase):
|
|||
action_space=env.action_space,
|
||||
)
|
||||
t0 = time.time()
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
wo_lookup = time.time() - t0
|
||||
print(f"Spaces given manually in config: {wo_lookup}sec")
|
||||
self.assertLess(wo_lookup, w_lookup)
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
def test_worker_validation_time(self):
|
||||
"""Tests the time taken by `validate_workers_after_construction=True`."""
|
||||
|
@ -302,17 +302,17 @@ class TestTrainer(unittest.TestCase):
|
|||
# >> 1 workers.
|
||||
config.num_workers = 1
|
||||
t0 = time.time()
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
total_time_1 = time.time() - t0
|
||||
print(f"Validating w/ 1 worker: {total_time_1}sec")
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
config.num_workers = 5
|
||||
t0 = time.time()
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
total_time_5 = time.time() - t0
|
||||
print(f"Validating w/ 5 workers: {total_time_5}sec")
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
check(total_time_5 / total_time_1, 1.0, atol=1.0)
|
||||
|
||||
|
@ -343,9 +343,9 @@ class TestTrainer(unittest.TestCase):
|
|||
.offline_data(input_=[input_file])
|
||||
)
|
||||
|
||||
bc_trainer = BC(config=offline_rl_config)
|
||||
bc_trainer.train()
|
||||
bc_trainer.stop()
|
||||
bc = BC(config=offline_rl_config)
|
||||
bc.train()
|
||||
bc.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -3,7 +3,7 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.rllib.agents.registry import get_trainer_class
|
||||
from ray.rllib.algorithms.registry import get_algorithm_class
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
@ -66,7 +66,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
|
|||
|
||||
def _do_test_fault_ignore(self, alg: str, config: dict):
|
||||
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
||||
agent_cls = get_trainer_class(alg)
|
||||
agent_cls = get_algorithm_class(alg)
|
||||
|
||||
# Test fault handling
|
||||
config["num_workers"] = 2
|
||||
|
@ -82,7 +82,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
|
|||
|
||||
def _do_test_fault_fatal(self, alg, config):
|
||||
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
||||
agent_cls = get_trainer_class(alg)
|
||||
agent_cls = get_algorithm_class(alg)
|
||||
|
||||
# Test raises real error when out of workers
|
||||
config["num_workers"] = 2
|
||||
|
@ -97,7 +97,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
|
|||
|
||||
def _do_test_fault_fatal_but_recreate(self, alg, config):
|
||||
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
||||
agent_cls = get_trainer_class(alg)
|
||||
agent_cls = get_algorithm_class(alg)
|
||||
|
||||
# Test raises real error when out of workers
|
||||
config["num_workers"] = 2
|
||||
|
|
File diff suppressed because it is too large
|
@@ -0,0 +1,8 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig


__all__ = [
    "Algorithm",
    "AlgorithmConfig",
]
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.a3c.a3c import A3CConfig, A3C
|
||||
from ray.rllib.execution.common import (
|
||||
STEPS_TRAINED_COUNTER,
|
||||
|
@ -22,9 +22,9 @@ from ray.rllib.utils.metrics import (
|
|||
WORKER_UPDATE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -39,7 +39,7 @@ class A2CConfig(A3CConfig):
|
|||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=2)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -62,7 +62,7 @@ class A2CConfig(A3CConfig):
|
|||
|
||||
def __init__(self):
|
||||
"""Initializes a A2CConfig instance."""
|
||||
super().__init__(trainer_class=A2C)
|
||||
super().__init__(algo_class=A2C)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -93,7 +93,7 @@ class A2CConfig(A3CConfig):
|
|||
memory. To enable, set this to a value less than the train batch size.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -107,11 +107,11 @@ class A2CConfig(A3CConfig):
|
|||
class A2C(A3C):
|
||||
@classmethod
|
||||
@override(A3C)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return A2CConfig().to_dict()
|
||||
|
||||
@override(A3C)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -130,8 +130,8 @@ class A2C(A3C):
|
|||
"Otherwise, microbatches of desired size won't be achievable."
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
@override(Algorithm)
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
|
||||
# Create a microbatch variable for collecting gradients on microbatches'.
|
||||
|
@ -146,11 +146,11 @@ class A2C(A3C):
|
|||
|
||||
@override(A3C)
|
||||
def training_step(self) -> ResultDict:
|
||||
# W/o microbatching: Identical to Trainer's default implementation.
|
||||
# Only difference to a default Trainer being the value function loss term
|
||||
# W/o microbatching: Identical to Algorithm's default implementation.
|
||||
# Only difference to a default Algorithm being the value function loss term
|
||||
# and its value computations alongside each action.
|
||||
if self.config["microbatch_size"] is None:
|
||||
return Trainer.training_step(self)
|
||||
return Algorithm.training_step(self)
|
||||
|
||||
# In microbatch mode, we want to compute gradients on experience
|
||||
# microbatches, average a number of these microbatches, and then
|
||||
|
|
|
@ -2,8 +2,8 @@ import logging
|
|||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from ray.actor import ActorHandle
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.execution.parallel_requests import (
|
||||
AsyncRequestsManager,
|
||||
|
@ -23,15 +23,15 @@ from ray.rllib.utils.metrics import (
|
|||
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
|
||||
from ray.rllib.utils.typing import (
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
PartialTrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class A3CConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which a A3C Trainer can be built.
|
||||
class A3CConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which a A3C Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray import tune
|
||||
|
@ -39,7 +39,7 @@ class A3CConfig(TrainerConfig):
|
|||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -60,9 +60,9 @@ class A3CConfig(TrainerConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a A3CConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or A3C)
|
||||
super().__init__(algo_class=algo_class or A3C)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -78,7 +78,7 @@ class A3CConfig(TrainerConfig):
|
|||
self.entropy_coeff_schedule = None
|
||||
self.sample_async = True
|
||||
|
||||
# Override some of TrainerConfig's default values with PPO-specific values.
|
||||
# Override some of AlgorithmConfig's default values with PPO-specific values.
|
||||
self.rollout_fragment_length = 10
|
||||
self.lr = 0.0001
|
||||
# Min time (in seconds) per reporting.
|
||||
|
@ -89,7 +89,7 @@ class A3CConfig(TrainerConfig):
|
|||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -125,7 +125,7 @@ class A3CConfig(TrainerConfig):
|
|||
to async buffering of batches.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -152,21 +152,21 @@ class A3CConfig(TrainerConfig):
|
|||
return self
|
||||
|
||||
|
||||
class A3C(Trainer):
|
||||
class A3C(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return A3CConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
@override(Algorithm)
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
self._worker_manager = AsyncRequestsManager(
|
||||
self.workers.remote_workers(), max_remote_requests_in_flight_per_worker=1
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -175,8 +175,8 @@ class A3C(Trainer):
|
|||
if config["num_workers"] <= 0 and config["sample_async"]:
|
||||
raise ValueError("`num_workers` for A3C must be >= 1!")
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy
|
||||
|
||||
|
@ -256,7 +256,7 @@ class A3C(Trainer):
|
|||
|
||||
return learner_info_builder.finalize()
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def on_worker_failures(
|
||||
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
|
||||
):
|
||||
|
|
rllib/algorithms/algorithm.py (new file, 2538 lines; diff suppressed because it is too large)
rllib/algorithms/algorithm_config.py (new file, 1252 lines; diff suppressed because it is too large)
|
@ -11,7 +11,7 @@ import ray
|
|||
from ray.actor import ActorHandle
|
||||
from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners
|
||||
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
import ray.rllib.algorithms.appo.appo as appo
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.execution.parallel_requests import (
|
||||
|
@ -36,10 +36,10 @@ from ray.rllib.utils.metrics import (
|
|||
)
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
PolicyID,
|
||||
PolicyState,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
ResultDict,
|
||||
)
|
||||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
|
@ -47,7 +47,7 @@ from ray.util.timer import _Timer
|
|||
|
||||
|
||||
class AlphaStarConfig(appo.APPOConfig):
|
||||
"""Defines a configuration class from which an AlphaStar Trainer can be built.
|
||||
"""Defines a configuration class from which an AlphaStar Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig
|
||||
|
@ -55,7 +55,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
... .resources(num_gpus=4)\
|
||||
... .rollouts(num_rollout_workers=64)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -78,9 +78,9 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a AlphaStarConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or AlphaStar)
|
||||
super().__init__(algo_class=algo_class or AlphaStar)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -141,7 +141,6 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
# values.
|
||||
self.vtrace_drop_last_ts = False
|
||||
self.min_time_s_per_iteration = 2
|
||||
self._disable_execution_plan_api = True
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
@ -194,14 +193,14 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
`ray.rllib.algorithms.alpha_star.league_builder::AlphaStarLeagueBuilder`
|
||||
(used by default by this algo) as an example.
|
||||
max_num_policies_to_train: The maximum number of trainable policies for this
|
||||
Trainer. Each trainable policy will exist as a independent remote actor,
|
||||
co-located with a replay buffer. This is besides its existence inside
|
||||
the RolloutWorkers for training and evaluation. Set to None for
|
||||
Algorithm. Each trainable policy will exist as a independent remote
|
||||
actor, co-located with a replay buffer. This is besides its existence
|
||||
inside the RolloutWorkers for training and evaluation. Set to None for
|
||||
automatically inferring this value from the number of trainable
|
||||
policies found in the `multiagent` config.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -244,7 +243,7 @@ class AlphaStar(appo.APPO):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls.get_default_config(), **config)
|
||||
# Construct a dummy LeagueBuilder, such that it gets the opportunity to
|
||||
|
@ -311,11 +310,11 @@ class AlphaStar(appo.APPO):
|
|||
|
||||
@classmethod
|
||||
@override(appo.APPO)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return AlphaStarConfig().to_dict()
|
||||
|
||||
@override(appo.APPO)
|
||||
def validate_config(self, config: TrainerConfigDict):
|
||||
def validate_config(self, config: AlgorithmConfigDict):
|
||||
# Create the LeagueBuilder object, allowing it to build the multiagent
|
||||
# config as well.
|
||||
self.league_builder = from_config(
|
||||
|
@ -324,7 +323,7 @@ class AlphaStar(appo.APPO):
|
|||
super().validate_config(config)
|
||||
|
||||
@override(appo.APPO)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
# Call super's setup to validate config, create RolloutWorkers
|
||||
# (train and eval), etc..
|
||||
num_gpus_saved = config["num_gpus"]
|
||||
|
@ -403,7 +402,7 @@ class AlphaStar(appo.APPO):
|
|||
ray_wait_timeout_s=self.config["timeout_s_learner_manager"],
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def step(self) -> ResultDict:
|
||||
# Perform a full step (including evaluation).
|
||||
result = super().step()
|
||||
|
@ -414,7 +413,7 @@ class AlphaStar(appo.APPO):
|
|||
|
||||
return result
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
# Trigger asynchronous rollouts on all RolloutWorkers.
|
||||
# - Rollout results are sent directly to correct replay buffer
|
||||
|
@ -495,7 +494,7 @@ class AlphaStar(appo.APPO):
|
|||
|
||||
return train_infos
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def add_policy(
|
||||
self,
|
||||
policy_id: PolicyID,
|
||||
|
@ -503,7 +502,7 @@ class AlphaStar(appo.APPO):
|
|||
*,
|
||||
observation_space: Optional[gym.spaces.Space] = None,
|
||||
action_space: Optional[gym.spaces.Space] = None,
|
||||
config: Optional[PartialTrainerConfigDict] = None,
|
||||
config: Optional[PartialAlgorithmConfigDict] = None,
|
||||
policy_state: Optional[PolicyState] = None,
|
||||
**kwargs,
|
||||
) -> Policy:
|
||||
|
@ -536,7 +535,7 @@ class AlphaStar(appo.APPO):
|
|||
|
||||
return new_policy
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def cleanup(self) -> None:
|
||||
super().cleanup()
|
||||
# Stop all policy- and replay actors.
|
||||
|
|
|
@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional, Type
|
|||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.policy.policy import PolicySpec
|
||||
from ray.rllib.utils.actors import create_colocated_actors
|
||||
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
|
||||
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import PolicyID, AlgorithmConfigDict
|
||||
|
||||
|
||||
class DistributedLearners:
|
||||
|
@ -30,7 +30,7 @@ class DistributedLearners:
|
|||
"""Initializes a DistributedLearners instance.
|
||||
|
||||
Args:
|
||||
config: The Trainer's config dict.
|
||||
config: The Algorithm's config dict.
|
||||
max_num_policies_to_train: Maximum number of policies that will ever be
|
||||
trainable. For these policies, we'll have to create remote
|
||||
policy actors, distributed across n "learner shards".
|
||||
|
@ -161,7 +161,7 @@ class _Shard:
|
|||
# Merge the policies config overrides with the main config.
|
||||
# Also, adjust `num_gpus` (to indicate an individual policy's
|
||||
# num_gpus, not the total number of GPUs).
|
||||
cfg = Trainer.merge_trainer_configs(
|
||||
cfg = Algorithm.merge_trainer_configs(
|
||||
self.config,
|
||||
dict(policy_spec.config, **{"num_gpus": self.num_gpus_per_policy}),
|
||||
)
|
||||
|
@ -207,7 +207,7 @@ class _Shard:
|
|||
self,
|
||||
policy_id: PolicyID,
|
||||
policy_spec: PolicySpec,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
):
|
||||
assert self.replay_actor is None
|
||||
assert len(self.policy_actors) == 0
|
||||
|
|
|
@ -5,26 +5,26 @@ import numpy as np
|
|||
import re
|
||||
from typing import Any, DefaultDict, Dict
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.examples.policy.random_policy import RandomPolicy
|
||||
from ray.rllib.policy.policy import PolicySpec
|
||||
from ray.rllib.utils.annotations import ExperimentalAPI, override
|
||||
from ray.rllib.utils.numpy import softmax
|
||||
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict, ResultDict
|
||||
from ray.rllib.utils.typing import PolicyID, AlgorithmConfigDict, ResultDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
class LeagueBuilder(metaclass=ABCMeta):
|
||||
def __init__(self, trainer: Trainer, trainer_config: TrainerConfigDict):
|
||||
def __init__(self, trainer: Algorithm, trainer_config: AlgorithmConfigDict):
|
||||
"""Initializes a LeagueBuilder instance.
|
||||
|
||||
Args:
|
||||
trainer: The Trainer object by which this league builder is used.
|
||||
Trainer calls `build_league()` after each training step.
|
||||
trainer: The Algorithm object by which this league builder is used.
|
||||
Algorithm calls `build_league()` after each training step.
|
||||
trainer_config: The (not yet validated) config dict to be
|
||||
used on the Trainer. Child classes of `LeagueBuilder`
|
||||
used on the Algorithm. Child classes of `LeagueBuilder`
|
||||
should preprocess this to add e.g. multiagent settings
|
||||
to this config.
|
||||
"""
|
||||
|
@ -67,8 +67,8 @@ class NoLeagueBuilder(LeagueBuilder):
|
|||
class AlphaStarLeagueBuilder(LeagueBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
trainer: Trainer,
|
||||
trainer_config: TrainerConfigDict,
|
||||
trainer: Algorithm,
|
||||
trainer_config: AlgorithmConfigDict,
|
||||
num_random_policies: int = 2,
|
||||
num_learning_league_exploiters: int = 4,
|
||||
num_learning_main_exploiters: int = 4,
|
||||
|
@ -86,11 +86,11 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
|
|||
M: Main self-play (main vs main).
|
||||
|
||||
Args:
|
||||
trainer: The Trainer object by which this league builder is used.
|
||||
Trainer calls `build_league()` after each training step to reconfigure
|
||||
trainer: The Algorithm object by which this league builder is used.
|
||||
Algorithm calls `build_league()` after each training step to reconfigure
|
||||
the league structure (e.g. to add/remove policies).
|
||||
trainer_config: The (not yet validated) config dict to be
|
||||
used on the Trainer. Child classes of `LeagueBuilder`
|
||||
used on the Algorithm. Child classes of `LeagueBuilder`
|
||||
should preprocess this to add e.g. multiagent settings
|
||||
to this config.
|
||||
num_random_policies: The number of random policies to add to the
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import logging
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.callbacks import DefaultCallbacks
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.replay_ops import (
|
||||
SimpleReplayBuffer,
|
||||
|
@ -37,7 +37,7 @@ from ray.rllib.utils.metrics import (
|
|||
SYNCH_WORKER_WEIGHTS_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
|
||||
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
|
||||
from ray.util.iter import LocalIterator
|
||||
|
||||
from ray.rllib.algorithms.alpha_zero.alpha_zero_policy import AlphaZeroPolicy
|
||||
|
@ -63,8 +63,8 @@ class AlphaZeroDefaultCallbacks(DefaultCallbacks):
|
|||
episode.user_data["initial_state"] = state
|
||||
|
||||
|
||||
class AlphaZeroConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which an AlphaZero Trainer can be built.
|
||||
class AlphaZeroConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which an AlphaZero Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig
|
||||
|
@ -72,7 +72,7 @@ class AlphaZeroConfig(TrainerConfig):
|
|||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -95,9 +95,9 @@ class AlphaZeroConfig(TrainerConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a PPOConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or AlphaZero)
|
||||
super().__init__(algo_class=algo_class or AlphaZero)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -134,7 +134,7 @@ class AlphaZeroConfig(TrainerConfig):
|
|||
"num_init_rewards": 100,
|
||||
}
|
||||
|
||||
# Override some of TrainerConfig's default values with AlphaZero-specific
|
||||
# Override some of AlgorithmConfig's default values with AlphaZero-specific
|
||||
# values.
|
||||
self.framework_str = "torch"
|
||||
self.callbacks_class = AlphaZeroDefaultCallbacks
|
||||
|
@ -154,7 +154,7 @@ class AlphaZeroConfig(TrainerConfig):
|
|||
|
||||
self.buffer_size = DEPRECATED_VALUE
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -220,7 +220,7 @@ class AlphaZeroConfig(TrainerConfig):
|
|||
from: https://arxiv.org/pdf/1807.01672.pdf
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -271,7 +271,7 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
|
|||
model = ModelCatalog.get_model_v2(
|
||||
obs_space, action_space, action_space.n, config["model"], "torch"
|
||||
)
|
||||
_, env_creator = Trainer._get_env_id_and_creator(config["env"], config)
|
||||
_, env_creator = Algorithm._get_env_id_and_creator(config["env"], config)
|
||||
if config["ranked_rewards"]["enable"]:
|
||||
# if r2 is enabled, tne env is wrapped to include a rewards buffer
|
||||
# used to normalize rewards
|
||||
|
@ -302,23 +302,23 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
|
|||
)
|
||||
|
||||
|
||||
class AlphaZero(Trainer):
|
||||
class AlphaZero(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return AlphaZeroConfig().to_dict()
|
||||
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
"""Checks and updates the config based on settings."""
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
validate_buffer_config(config)
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
return AlphaZeroPolicyWrapperClass
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
"""TODO:
|
||||
|
||||
|
@ -374,9 +374,9 @@ class AlphaZero(Trainer):
|
|||
return train_results
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def execution_plan(
|
||||
workers: WorkerSet, config: TrainerConfigDict, **kwargs
|
||||
workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
|
||||
) -> LocalIterator[dict]:
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from ray.actor import ActorHandle
|
||||
from ray.rllib.agents import Trainer
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQN
|
||||
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import PartialTrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
|
||||
from ray.rllib.utils.typing import ResultDict
|
||||
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
|
||||
from ray.util.iter import LocalIterator
|
||||
|
@ -44,9 +44,9 @@ class ApexDDPGConfig(DDPGConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes an ApexDDPGConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or ApexDDPG)
|
||||
super().__init__(algo_class=algo_class or ApexDDPG)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -174,11 +174,11 @@ class ApexDDPGConfig(DDPGConfig):
|
|||
class ApexDDPG(DDPG, ApexDQN):
|
||||
@classmethod
|
||||
@override(DDPG)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return ApexDDPGConfig().to_dict()
|
||||
|
||||
@override(DDPG)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
return ApexDQN.setup(self, config)
|
||||
|
||||
@override(DDPG)
|
||||
|
@ -186,7 +186,7 @@ class ApexDDPG(DDPG, ApexDQN):
|
|||
"""Use APEX-DQN's training iteration function."""
|
||||
return ApexDQN.training_step(self)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def on_worker_failures(
|
||||
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
|
||||
):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
Distributed Prioritized Experience Replay (Ape-X)
|
||||
=================================================
|
||||
|
||||
This file defines a DQN trainer using the Ape-X architecture.
|
||||
This file defines a DQN algorithm using the Ape-X architecture.
|
||||
|
||||
Ape-X uses a single GPU learner and many CPU workers for experience collection.
|
||||
Experience collection can scale to hundreds of CPU workers due to the
|
||||
|
@ -21,7 +21,7 @@ from typing import Dict, List, Type, Optional, Callable
|
|||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.rllib import Policy
|
||||
from ray.rllib.agents import Trainer
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
|
||||
from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
|
@ -48,16 +48,17 @@ from ray.rllib.utils.metrics import (
|
|||
TARGET_NET_UPDATE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.typing import (
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
ResultDict,
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
)
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
from ray.util.ml_utils.dict import merge_dicts
|
||||
|
||||
|
||||
class ApexDQNConfig(DQNConfig):
|
||||
"""Defines a configuration class from which an ApexDQN Trainer can be built.
|
||||
"""Defines a configuration class from which an ApexDQN Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
|
||||
|
@@ -75,9 +76,9 @@ class ApexDQNConfig(DQNConfig):
>>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1")
>>> trainer = config.build()
>>> algo = config.build()
>>> while True:
>>>     trainer.train()
>>>     algo.train()

Example:
>>> from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
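A runnable sketch of the renamed workflow the docstring above illustrates, assuming this commit of RLlib is installed (hyperparameter values here are arbitrary):

    from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig

    config = (
        ApexDQNConfig()
        .environment("CartPole-v1")
        .resources(num_gpus=0)
        .rollouts(num_rollout_workers=3)
    )
    algo = config.build()  # an ApexDQN Algorithm instance (formerly an ApexDQN "Trainer")
    for _ in range(3):
        print(algo.train()["episode_reward_mean"])
    algo.stop()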
@ -120,9 +121,9 @@ class ApexDQNConfig(DQNConfig):
|
|||
>>> .exploration(exploration_config=explore_config)
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a ApexConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or ApexDQN)
|
||||
super().__init__(algo_class=algo_class or ApexDQN)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -350,8 +351,8 @@ class ApexDQNConfig(DQNConfig):
|
|||
|
||||
|
||||
class ApexDQN(DQN):
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
@override(Trainable)
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
|
||||
# Shortcut: If execution_plan, thread and buffer will be created in there.
|
||||
|
@ -423,7 +424,7 @@ class ApexDQN(DQN):
|
|||
|
||||
@classmethod
|
||||
@override(DQN)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return ApexDQNConfig().to_dict()
|
||||
|
||||
@override(DQN)
|
||||
|
@ -641,7 +642,7 @@ class ApexDQN(DQN):
|
|||
STEPS_TRAINED_COUNTER
|
||||
]
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def on_worker_failures(
|
||||
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
|
||||
):
|
||||
|
@ -654,7 +655,7 @@ class ApexDQN(DQN):
|
|||
self._sampling_actor_manager.remove_workers(removed_workers)
|
||||
self._sampling_actor_manager.add_workers(new_workers)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
|
||||
result = super()._compile_iteration_results(
|
||||
step_ctx=step_ctx, iteration_results=iteration_results
|
||||
|
@ -679,7 +680,7 @@ class ApexDQN(DQN):
|
|||
return result
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls.get_default_config(), **config)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
Asynchronous Proximal Policy Optimization (APPO)
|
||||
================================================
|
||||
|
||||
This file defines the distributed Trainer class for the asynchronous version
|
||||
This file defines the distributed Algorithm class for the asynchronous version
|
||||
of proximal policy optimization (APPO).
|
||||
See `appo_[tf|torch]_policy.py` for the definition of the policy loss.
|
||||
|
||||
|
@ -24,16 +24,16 @@ from ray.rllib.utils.metrics import (
|
|||
)
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APPOConfig(ImpalaConfig):
|
||||
"""Defines a configuration class from which an APPO Trainer can be built.
|
||||
"""Defines a configuration class from which an APPO Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.appo import APPOConfig
|
||||
|
@@ -41,7 +41,7 @@ class APPOConfig(ImpalaConfig):
... .resources(num_gpus=1)\
... .rollouts(num_rollout_workers=16)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
@ -64,9 +64,9 @@ class APPOConfig(ImpalaConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a APPOConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or APPO)
|
||||
super().__init__(algo_class=algo_class or APPO)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -141,7 +141,7 @@ class APPOConfig(ImpalaConfig):
|
|||
`kl_coeff` automatically).
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -238,12 +238,12 @@ class APPO(Impala):
|
|||
|
||||
@classmethod
|
||||
@override(Impala)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return APPOConfig().to_dict()
|
||||
|
||||
@override(Impala)
|
||||
def get_default_policy_class(
|
||||
self, config: PartialTrainerConfigDict
|
||||
self, config: PartialAlgorithmConfigDict
|
||||
) -> Optional[Type[Policy]]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
|
||||
|
|
|
@ -10,7 +10,7 @@ import time
|
|||
from typing import Optional
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents import Trainer, TrainerConfig
|
||||
from ray.rllib.algorithms import Algorithm, AlgorithmConfig
|
||||
from ray.rllib.algorithms.ars.ars_tf_policy import ARSTFPolicy
|
||||
from ray.rllib.algorithms.es import optimizers, utils
|
||||
from ray.rllib.algorithms.es.es_tf_policy import rollout
|
||||
|
@ -26,7 +26,7 @@ from ray.rllib.utils.metrics import (
|
|||
NUM_ENV_STEPS_TRAINED,
|
||||
)
|
||||
from ray.rllib.utils.torch_utils import set_torch_seed
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,8 +43,8 @@ Result = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
class ARSConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which an ARS Trainer can be built.
|
||||
class ARSConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which an ARS Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.ars import ARSConfig
|
||||
|
@ -52,7 +52,7 @@ class ARSConfig(TrainerConfig):
|
|||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
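Following the docstring example above, a minimal sketch of building and training the renamed ARS Algorithm, assuming this commit (the worker count is arbitrary):

    from ray.rllib.algorithms.ars import ARSConfig

    config = ARSConfig().rollouts(num_rollout_workers=4).environment("CartPole-v1")
    algo = config.build()
    print(algo.train()["episode_reward_mean"])
    algo.stop()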
@ -77,7 +77,7 @@ class ARSConfig(TrainerConfig):
|
|||
|
||||
def __init__(self):
|
||||
"""Initializes a ARSConfig instance."""
|
||||
super().__init__(trainer_class=ARS)
|
||||
super().__init__(algo_class=ARS)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -93,10 +93,10 @@ class ARSConfig(TrainerConfig):
|
|||
self.report_length = 10
|
||||
self.offset = 0
|
||||
|
||||
# Override some of TrainerConfig's default values with ARS-specific values.
|
||||
# Override some of AlgorithmConfig's default values with ARS-specific values.
|
||||
self.num_workers = 2
|
||||
self.observation_filter = "MeanStdFilter"
|
||||
# ARS will use Trainer's evaluation WorkerSet (if evaluation_interval > 0).
|
||||
# ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0).
|
||||
# Therefore, we must be careful not to use more than 1 env per eval worker
|
||||
# (would break ARSPolicy's compute_single_action method) and to not do
|
||||
# obs-filtering.
|
||||
|
@ -106,7 +106,7 @@ class ARSConfig(TrainerConfig):
|
|||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -139,7 +139,7 @@ class ARSConfig(TrainerConfig):
|
|||
from humanoid) during rollouts.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -312,16 +312,16 @@ def get_policy_class(config):
|
|||
return policy_cls
|
||||
|
||||
|
||||
class ARS(Trainer):
|
||||
class ARS(Algorithm):
|
||||
"""Large-scale implementation of Augmented Random Search in Ray."""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return ARSConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -341,7 +341,7 @@ class ARS(Trainer):
|
|||
"`NoFilter` for ARS!"
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def setup(self, config):
|
||||
# Setup our config: Merge the user-supplied config (which could
|
||||
# be a partial config dict with the class' default).
|
||||
|
@ -387,7 +387,7 @@ class ARS(Trainer):
|
|||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def get_policy(self, policy=DEFAULT_POLICY_ID):
|
||||
if policy != DEFAULT_POLICY_ID:
|
||||
raise ValueError(
|
||||
|
@ -396,7 +396,7 @@ class ARS(Trainer):
|
|||
)
|
||||
return self.policy
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def step(self):
|
||||
config = self.config
|
||||
|
||||
|
@ -505,20 +505,20 @@ class ARS(Trainer):
|
|||
|
||||
return result
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def cleanup(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for w in self.workers:
|
||||
w.__ray_terminate__.remote()
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def compute_single_action(self, observation, *args, **kwargs):
|
||||
action, _, _ = self.policy.compute_actions([observation], update=True)
|
||||
if kwargs.get("full_fetch"):
|
||||
return action[0], [], {}
|
||||
return action[0]
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def _sync_weights_to_workers(self, *, worker_set=None, workers=None):
|
||||
# Broadcast the new policy weights to all evaluation workers.
|
||||
assert worker_set is not None
|
||||
|
|
|
@ -15,7 +15,7 @@ class TestARS(unittest.TestCase):
|
|||
ray.shutdown()
|
||||
|
||||
def test_ars_compilation(self):
|
||||
"""Test whether an ARSTrainer can be built on all frameworks."""
|
||||
"""Test whether an ARSAlgorithm can be built on all frameworks."""
|
||||
config = ars.ARSConfig()
|
||||
|
||||
# Keep it simple.
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
import logging
|
||||
from typing import Type, Union
|
||||
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.algorithms.bandit.bandit_tf_policy import BanditTFPolicy
|
||||
from ray.rllib.algorithms.bandit.bandit_torch_policy import BanditTorchPolicy
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BanditConfig(TrainerConfig):
|
||||
class BanditConfig(AlgorithmConfig):
|
||||
"""Defines a contextual bandit configuration class from which
|
||||
a contextual bandit algorithm can be built. Note this config is shared
|
||||
between BanditLinUCB and BanditLinTS. You likely
|
||||
|
@ -21,18 +21,18 @@ class BanditConfig(TrainerConfig):
|
|||
instead.
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class: Union["BanditLinTS", "BanditLinUCB"] = None):
|
||||
super().__init__(trainer_class=trainer_class)
|
||||
def __init__(self, algo_class: Union["BanditLinTS", "BanditLinUCB"] = None):
|
||||
super().__init__(algo_class=algo_class)
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# Override some of TrainerConfig's default values with bandit-specific values.
|
||||
# Override some of AlgorithmConfig's default values with bandit-specific values.
|
||||
self.framework_str = "torch"
|
||||
self.num_workers = 0
|
||||
self.rollout_fragment_length = 1
|
||||
self.train_batch_size = 1
|
||||
# Make sure, a `train()` call performs at least 100 env sampling
|
||||
# timesteps, before reporting results. Not setting this (default is 0)
|
||||
# would significantly slow down the Bandit Trainer.
|
||||
# would significantly slow down the Bandit Algorithm.
|
||||
self.min_sample_timesteps_per_iteration = 100
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
@ -46,16 +46,16 @@ class BanditLinTSConfig(BanditConfig):
|
|||
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
|
||||
>>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env=WheelBanditEnv)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(trainer_class=BanditLinTS)
|
||||
super().__init__(algo_class=BanditLinTS)
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# Override some of TrainerConfig's default values with bandit-specific values.
|
||||
# Override some of AlgorithmConfig's default values with bandit-specific values.
|
||||
self.exploration_config = {"type": "ThompsonSampling"}
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
@ -69,31 +69,31 @@ class BanditLinUCBConfig(BanditConfig):
|
|||
>>> from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv
|
||||
>>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env=WheelBanditEnv)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(trainer_class=BanditLinUCB)
|
||||
super().__init__(algo_class=BanditLinUCB)
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# Override some of TrainerConfig's default values with bandit-specific values.
|
||||
# Override some of AlgorithmConfig's default values with bandit-specific values.
|
||||
self.exploration_config = {"type": "UpperConfidenceBound"}
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
||||
class BanditLinTS(Trainer):
|
||||
"""Bandit Trainer using ThompsonSampling exploration."""
|
||||
class BanditLinTS(Algorithm):
|
||||
"""Bandit Algorithm using ThompsonSampling exploration."""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> BanditLinTSConfig:
|
||||
return BanditLinTSConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
return BanditTorchPolicy
|
||||
elif config["framework"] == "tf2":
|
||||
|
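The bandit config docstrings above describe the same build-and-train pattern for the contextual bandit algorithms. A minimal sketch, assuming this commit and that the module path matches the file shown here:

    from ray.rllib.algorithms.bandit.bandit import BanditLinUCBConfig
    from ray.rllib.examples.env.bandit_envs_discrete import WheelBanditEnv

    config = BanditLinUCBConfig().rollouts(num_rollout_workers=4)
    algo = config.build(env=WheelBanditEnv)
    print(algo.train()["episode_reward_mean"])
    algo.stop()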
@ -102,14 +102,14 @@ class BanditLinTS(Trainer):
|
|||
raise NotImplementedError("Only `framework=[torch|tf2]` supported!")
|
||||
|
||||
|
||||
class BanditLinUCB(Trainer):
|
||||
class BanditLinUCB(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> BanditLinUCBConfig:
|
||||
return BanditLinUCBConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
return BanditTorchPolicy
|
||||
elif config["framework"] == "tf2":
|
||||
|
|
|
@ -20,7 +20,7 @@ from ray.rllib.policy.tf_policy_template import build_tf_policy
|
|||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.tf_utils import make_tf_callable
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
|
||||
from ray.util.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -71,7 +71,7 @@ def validate_spaces(
|
|||
policy: Policy,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
"""Validates the observation- and action spaces used for the Policy.
|
||||
|
||||
|
|
|
@ -29,15 +29,15 @@ class TestBandits(unittest.TestCase):
|
|||
):
|
||||
for train_batch_size in [1, 10]:
|
||||
config.training(train_batch_size=train_batch_size)
|
||||
trainer = config.build()
|
||||
algo = config.build()
|
||||
results = None
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
for _ in range(num_iterations):
|
||||
results = algo.train()
|
||||
check_train_results(results)
|
||||
print(results)
|
||||
# Force good learning behavior (this is a very simple env).
|
||||
self.assertTrue(results["episode_reward_mean"] == 10.0)
|
||||
trainer.stop()
|
||||
algo.stop()
|
||||
|
||||
def test_bandit_lin_ucb_compilation(self):
|
||||
"""Test whether BanditLinUCB can be built on all frameworks."""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
|
||||
|
||||
class BCConfig(MARWILConfig):
|
||||
|
@ -38,8 +38,8 @@ class BCConfig(MARWILConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
super().__init__(trainer_class=trainer_class or BC)
|
||||
def __init__(self, algo_class=None):
|
||||
super().__init__(algo_class=algo_class or BC)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -62,11 +62,11 @@ class BC(MARWIL):
|
|||
|
||||
@classmethod
|
||||
@override(MARWIL)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return BCConfig().to_dict()
|
||||
|
||||
@override(MARWIL)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
|
rllib/algorithms/callbacks.py (new file, 609 lines)
@ -0,0 +1,609 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import tracemalloc
|
||||
from typing import Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.utils.annotations import (
|
||||
is_overridden,
|
||||
OverrideToImplementCustomLogic,
|
||||
PublicAPI,
|
||||
)
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, Deprecated
|
||||
from ray.rllib.utils.exploration.random_encoder import (
|
||||
_MovingMeanStd,
|
||||
compute_states_entropy,
|
||||
update_beta,
|
||||
)
|
||||
from ray.rllib.utils.typing import AgentID, EnvType, PolicyID
|
||||
from ray.tune.callback import _CallbackMeta
|
||||
|
||||
# Import psutil after ray so the packaged version is used.
|
||||
import psutil
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.evaluation import RolloutWorker
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class DefaultCallbacks(metaclass=_CallbackMeta):
|
||||
"""Abstract base class for RLlib callbacks (similar to Keras callbacks).
|
||||
|
||||
These callbacks can be used for custom metrics and custom postprocessing.
|
||||
|
||||
By default, all of these callbacks are no-ops. To configure custom training
|
||||
callbacks, subclass DefaultCallbacks and then set
|
||||
{"callbacks": YourCallbacksClass} in the algo config.
|
||||
"""
|
||||
|
||||
def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
|
||||
if legacy_callbacks_dict:
|
||||
deprecation_warning(
|
||||
"callbacks dict interface",
|
||||
"a class extending rllib.algorithms.callbacks.DefaultCallbacks",
|
||||
)
|
||||
self.legacy_callbacks = legacy_callbacks_dict or {}
|
||||
if is_overridden(self.on_trainer_init):
|
||||
deprecation_warning(
|
||||
old="on_trainer_init(trainer, **kwargs)",
|
||||
new="on_algorithm_init(algorithm, **kwargs)",
|
||||
error=True,
|
||||
)
|
||||
|
||||
def on_sub_environment_created(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
sub_environment: EnvType,
|
||||
env_context: EnvContext,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Callback run when a new sub-environment has been created.
|
||||
|
||||
This method gets called after each sub-environment (usually a
|
||||
gym.Env) has been created, validated (RLlib built-in validation
|
||||
+ possible custom validation function implemented by overriding
|
||||
`Algorithm.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
sub_environment: The sub-environment instance that has been
|
||||
created. This is usually a gym.Env object.
|
||||
env_context: The `EnvContext` object that has been passed to
|
||||
the env's constructor.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_algorithm_init(
|
||||
self,
|
||||
*,
|
||||
algorithm: "Algorithm",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Callback run when a new algorithm instance has finished setup.
|
||||
|
||||
This method gets called at the end of Algorithm.setup() after all
|
||||
the initialization is done, and before actually training starts.
|
||||
|
||||
Args:
|
||||
algorithm: Reference to the Algorithm instance.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_episode_start(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Callback run on the rollout worker before each episode starts.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
base_env: BaseEnv running the episode. The underlying
|
||||
sub environment objects can be retrieved by calling
|
||||
`base_env.get_sub_environments()`.
|
||||
policies: Mapping of policy id to policy objects. In single
|
||||
agent mode there will only be a single "default" policy.
|
||||
episode: Episode object which contains the episode's
|
||||
state. You can use the `episode.user_data` dict to store
|
||||
temporary data, and `episode.custom_metrics` to store custom
|
||||
metrics for the episode.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_episode_start"):
|
||||
self.legacy_callbacks["on_episode_start"](
|
||||
{
|
||||
"env": base_env,
|
||||
"policy": policies,
|
||||
"episode": episode,
|
||||
}
|
||||
)
|
||||
|
||||
def on_episode_step(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Optional[Dict[PolicyID, Policy]] = None,
|
||||
episode: Episode,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Runs on each episode step.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
base_env: BaseEnv running the episode. The underlying
|
||||
sub environment objects can be retrieved by calling
|
||||
`base_env.get_sub_environments()`.
|
||||
policies: Mapping of policy id to policy objects.
|
||||
In single agent mode there will only be a single
|
||||
"default_policy".
|
||||
episode: Episode object which contains episode
|
||||
state. You can use the `episode.user_data` dict to store
|
||||
temporary data, and `episode.custom_metrics` to store custom
|
||||
metrics for the episode.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_episode_step"):
|
||||
self.legacy_callbacks["on_episode_step"](
|
||||
{"env": base_env, "episode": episode}
|
||||
)
|
||||
|
||||
def on_episode_end(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Runs when an episode is done.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
base_env: BaseEnv running the episode. The underlying
|
||||
sub environment objects can be retrieved by calling
|
||||
`base_env.get_sub_environments()`.
|
||||
policies: Mapping of policy id to policy
|
||||
objects. In single agent mode there will only be a single
|
||||
"default_policy".
|
||||
episode: Episode object which contains episode
|
||||
state. You can use the `episode.user_data` dict to store
|
||||
temporary data, and `episode.custom_metrics` to store custom
|
||||
metrics for the episode.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_episode_end"):
|
||||
self.legacy_callbacks["on_episode_end"](
|
||||
{
|
||||
"env": base_env,
|
||||
"policy": policies,
|
||||
"episode": episode,
|
||||
}
|
||||
)
|
||||
|
||||
def on_postprocess_trajectory(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
episode: Episode,
|
||||
agent_id: AgentID,
|
||||
policy_id: PolicyID,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
postprocessed_batch: SampleBatch,
|
||||
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Called immediately after a policy's postprocess_fn is called.
|
||||
|
||||
You can use this callback to do additional postprocessing for a policy,
|
||||
including looking at the trajectory data of other agents in multi-agent
|
||||
settings.
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
episode: Episode object.
|
||||
agent_id: Id of the current agent.
|
||||
policy_id: Id of the current policy for the agent.
|
||||
policies: Mapping of policy id to policy objects. In single
|
||||
agent mode there will only be a single "default_policy".
|
||||
postprocessed_batch: The postprocessed sample batch
|
||||
for this agent. You can mutate this object to apply your own
|
||||
trajectory postprocessing.
|
||||
original_batches: Mapping of agents to their unpostprocessed
|
||||
trajectory data. You should not mutate this object.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_postprocess_traj"):
|
||||
self.legacy_callbacks["on_postprocess_traj"](
|
||||
{
|
||||
"episode": episode,
|
||||
"agent_id": agent_id,
|
||||
"pre_batch": original_batches[agent_id],
|
||||
"post_batch": postprocessed_batch,
|
||||
"all_pre_batches": original_batches,
|
||||
}
|
||||
)
|
||||
|
||||
def on_sample_end(
|
||||
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
|
||||
) -> None:
|
||||
"""Called at the end of RolloutWorker.sample().
|
||||
|
||||
Args:
|
||||
worker: Reference to the current rollout worker.
|
||||
samples: Batch to be returned. You can mutate this
|
||||
object to modify the samples generated.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
if self.legacy_callbacks.get("on_sample_end"):
|
||||
self.legacy_callbacks["on_sample_end"](
|
||||
{
|
||||
"worker": worker,
|
||||
"samples": samples,
|
||||
}
|
||||
)
|
||||
|
||||
def on_learn_on_batch(
|
||||
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
|
||||
) -> None:
|
||||
"""Called at the beginning of Policy.learn_on_batch().
|
||||
|
||||
Note: This is called before 0-padding via
|
||||
`pad_batch_to_sequences_of_same_size`.
|
||||
|
||||
Also note, SampleBatch.INFOS column will not be available on
|
||||
train_batch within this callback if framework is tf1, due to
|
||||
the fact that tf1 static graph would mistake it as part of the
|
||||
input dict if present.
|
||||
It is available though, for tf2 and torch frameworks.
|
||||
|
||||
Args:
|
||||
policy: Reference to the current Policy object.
|
||||
train_batch: SampleBatch to be trained on. You can
|
||||
mutate this object to modify the samples generated.
|
||||
result: A results dict to add custom metrics to.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def on_train_result(
|
||||
self,
|
||||
*,
|
||||
algorithm: Optional["Algorithm"] = None,
|
||||
result: dict,
|
||||
trainer=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Called at the end of Trainable.train().
|
||||
|
||||
Args:
|
||||
algorithm: Current Algorithm instance.
|
||||
result: Dict of results returned from trainer.train() call.
|
||||
You can mutate this object to add additional metrics.
|
||||
kwargs: Forward compatibility placeholder.
|
||||
"""
|
||||
if trainer is not None:
|
||||
algorithm = trainer
|
||||
|
||||
if self.legacy_callbacks.get("on_train_result"):
|
||||
self.legacy_callbacks["on_train_result"](
|
||||
{
|
||||
"trainer": algorithm,
|
||||
"result": result,
|
||||
}
|
||||
)
|
||||
|
||||
@OverrideToImplementCustomLogic
|
||||
@Deprecated(error=True)
|
||||
def on_trainer_init(self, *args, **kwargs):
|
||||
raise DeprecationWarning
|
||||
|
||||
|
||||
class MemoryTrackingCallbacks(DefaultCallbacks):
|
||||
"""MemoryTrackingCallbacks can be used to trace and track memory usage
|
||||
in rollout workers.
|
||||
|
||||
The Memory Tracking Callbacks uses tracemalloc and psutil to track
|
||||
python allocations during rollouts,
|
||||
in training or evaluation.
|
||||
|
||||
The tracking data is logged to the custom_metrics of an episode and
|
||||
can therefore be viewed in tensorboard
|
||||
(or in WandB etc..)
|
||||
|
||||
Add MemoryTrackingCallbacks callback to the tune config
|
||||
e.g. { ...'callbacks': MemoryTrackingCallbacks ...}
|
||||
|
||||
Note:
|
||||
This class is meant for debugging and should not be used
|
||||
in production code as tracemalloc incurs
|
||||
a significant slowdown in execution speed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Will track the top 10 lines where memory is allocated
|
||||
tracemalloc.start(10)
|
||||
|
||||
def on_episode_end(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
top_stats = snapshot.statistics("lineno")
|
||||
|
||||
for stat in top_stats[:10]:
|
||||
count = stat.count
|
||||
size = stat.size
|
||||
|
||||
trace = str(stat.traceback)
|
||||
|
||||
episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
|
||||
episode.custom_metrics[f"tracemalloc/{trace}/count"] = count
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
worker_rss = process.memory_info().rss
|
||||
worker_data = process.memory_info().data
|
||||
worker_vms = process.memory_info().vms
|
||||
episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss
|
||||
episode.custom_metrics["tracemalloc/worker/data"] = worker_data
|
||||
episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
|
||||
|
||||
|
||||
class MultiCallbacks(DefaultCallbacks):
|
||||
"""MultiCallbacks allows multiple callbacks to be registered at
|
||||
the same time in the config of the environment.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
'callbacks': MultiCallbacks([
|
||||
MyCustomStatsCallbacks,
|
||||
MyCustomVideoCallbacks,
|
||||
MyCustomTraceCallbacks,
|
||||
....
|
||||
])
|
||||
"""
|
||||
|
||||
IS_CALLBACK_CONTAINER = True
|
||||
|
||||
def __init__(self, callback_class_list):
|
||||
super().__init__()
|
||||
self._callback_class_list = callback_class_list
|
||||
|
||||
self._callback_list = []
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._callback_list = [
|
||||
callback_class() for callback_class in self._callback_class_list
|
||||
]
|
||||
|
||||
return self
|
||||
|
||||
def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_algorithm_init(algorithm=algorithm, **kwargs)
|
||||
|
||||
def on_sub_environment_created(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
sub_environment: EnvType,
|
||||
env_context: EnvContext,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_sub_environment_created(
|
||||
worker=worker,
|
||||
sub_environment=sub_environment,
|
||||
env_context=env_context,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_episode_start(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_episode_start(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_index,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_episode_step(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Optional[Dict[PolicyID, Policy]] = None,
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_episode_step(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_index,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_episode_end(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
episode: Episode,
|
||||
env_index: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_episode_end(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_index,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_postprocess_trajectory(
|
||||
self,
|
||||
*,
|
||||
worker: "RolloutWorker",
|
||||
episode: Episode,
|
||||
agent_id: AgentID,
|
||||
policy_id: PolicyID,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
postprocessed_batch: SampleBatch,
|
||||
original_batches: Dict[AgentID, Tuple[Policy, SampleBatch]],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_postprocess_trajectory(
|
||||
worker=worker,
|
||||
episode=episode,
|
||||
agent_id=agent_id,
|
||||
policy_id=policy_id,
|
||||
policies=policies,
|
||||
postprocessed_batch=postprocessed_batch,
|
||||
original_batches=original_batches,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_sample_end(
|
||||
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_sample_end(worker=worker, samples=samples, **kwargs)
|
||||
|
||||
def on_learn_on_batch(
|
||||
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
|
||||
) -> None:
|
||||
for callback in self._callback_list:
|
||||
callback.on_learn_on_batch(
|
||||
policy=policy, train_batch=train_batch, result=result, **kwargs
|
||||
)
|
||||
|
||||
def on_train_result(
|
||||
self, *, algorithm=None, result: dict, trainer=None, **kwargs
|
||||
) -> None:
|
||||
if trainer is not None:
|
||||
algorithm = trainer
|
||||
|
||||
for callback in self._callback_list:
|
||||
# TODO: Remove `trainer` arg at some point to fully deprecate the old term.
|
||||
callback.on_train_result(
|
||||
algorithm=algorithm, result=result, trainer=algorithm, **kwargs
|
||||
)
|
||||
|
||||
|
||||
# This Callback is used by the RE3 exploration strategy.
|
||||
# See rllib/examples/re3_exploration.py for details.
|
||||
class RE3UpdateCallbacks(DefaultCallbacks):
|
||||
"""Update input callbacks to mutate batch with states entropy rewards."""
|
||||
|
||||
_step = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
embeds_dim: int = 128,
|
||||
k_nn: int = 50,
|
||||
beta: float = 0.1,
|
||||
rho: float = 0.0001,
|
||||
beta_schedule: str = "constant",
|
||||
**kwargs,
|
||||
):
|
||||
self.embeds_dim = embeds_dim
|
||||
self.k_nn = k_nn
|
||||
self.beta = beta
|
||||
self.rho = rho
|
||||
self.beta_schedule = beta_schedule
|
||||
self._rms = _MovingMeanStd()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def on_learn_on_batch(
|
||||
self,
|
||||
*,
|
||||
policy: Policy,
|
||||
train_batch: SampleBatch,
|
||||
result: dict,
|
||||
**kwargs,
|
||||
):
|
||||
super().on_learn_on_batch(
|
||||
policy=policy, train_batch=train_batch, result=result, **kwargs
|
||||
)
|
||||
states_entropy = compute_states_entropy(
|
||||
train_batch[SampleBatch.OBS_EMBEDS], self.embeds_dim, self.k_nn
|
||||
)
|
||||
states_entropy = update_beta(
|
||||
self.beta_schedule, self.beta, self.rho, RE3UpdateCallbacks._step
|
||||
) * np.reshape(
|
||||
self._rms(states_entropy),
|
||||
train_batch[SampleBatch.OBS_EMBEDS].shape[:-1],
|
||||
)
|
||||
train_batch[SampleBatch.REWARDS] = (
|
||||
train_batch[SampleBatch.REWARDS] + states_entropy
|
||||
)
|
||||
if Postprocessing.ADVANTAGES in train_batch:
|
||||
train_batch[Postprocessing.ADVANTAGES] = (
|
||||
train_batch[Postprocessing.ADVANTAGES] + states_entropy
|
||||
)
|
||||
train_batch[Postprocessing.VALUE_TARGETS] = (
|
||||
train_batch[Postprocessing.VALUE_TARGETS] + states_entropy
|
||||
)
|
||||
|
||||
def on_train_result(
|
||||
self, *, result: dict, algorithm=None, trainer=None, **kwargs
|
||||
) -> None:
|
||||
if trainer is not None:
|
||||
algorithm = trainer
|
||||
# TODO(gjoliver): Remove explicit _step tracking and pass
|
||||
# trainer._iteration as a parameter to on_learn_on_batch() call.
|
||||
RE3UpdateCallbacks._step = result["training_iteration"]
|
||||
# TODO: Remove `trainer` arg at some point to fully deprecate the old term.
|
||||
super().on_train_result(
|
||||
algorithm=algorithm, result=result, trainer=algorithm, **kwargs
|
||||
)
|
|
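The new callbacks module above is the extension point for custom metrics and postprocessing. A minimal sketch of wiring a user-defined callback class into an algorithm config, assuming this commit (the environment, metric name, and `EpisodeLengthCallbacks` class are illustrative only):

    from ray.rllib.algorithms.callbacks import DefaultCallbacks
    from ray.rllib.algorithms.ppo import PPOConfig

    class EpisodeLengthCallbacks(DefaultCallbacks):
        # Hypothetical example: record each episode's length as a custom metric.
        def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
            episode.custom_metrics["episode_length"] = episode.length

    config = (
        PPOConfig()
        .environment("CartPole-v0")
        .callbacks(EpisodeLengthCallbacks)
    )
    algo = config.build()
    print(algo.train()["custom_metrics"])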
@ -32,7 +32,7 @@ from ray.rllib.utils.metrics import (
|
|||
SYNCH_WORKER_WEIGHTS_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
|
||||
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
tfp = try_import_tfp()
|
||||
|
@ -52,8 +52,8 @@ class CQLConfig(SACConfig):
|
|||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
super().__init__(trainer_class=trainer_class or CQL)
|
||||
def __init__(self, algo_class=None):
|
||||
super().__init__(algo_class=algo_class or CQL)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -99,7 +99,7 @@ class CQLConfig(SACConfig):
|
|||
min_q_weight: in Q weight multiplier.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -165,11 +165,11 @@ class CQL(SAC):
|
|||
|
||||
@classmethod
|
||||
@override(SAC)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return CQLConfig().to_dict()
|
||||
|
||||
@override(SAC)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# First check, whether old `timesteps_per_iteration` is used. If so
|
||||
# convert right away as for CQL, we must measure in training timesteps,
|
||||
# never sampling timesteps (CQL does not sample).
|
||||
|
@ -206,7 +206,7 @@ class CQL(SAC):
|
|||
try_import_tfp(error=True)
|
||||
|
||||
@override(SAC)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
return CQLTorchPolicy
|
||||
else:
|
||||
|
|
|
@ -35,7 +35,7 @@ from ray.rllib.utils.typing import (
|
|||
LocalOptimizer,
|
||||
ModelGradients,
|
||||
TensorType,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
)
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
@ -314,7 +314,7 @@ def setup_early_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
"""Call mixin classes' constructors before Policy's initialization.
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ from ray.rllib.policy.policy import Policy
|
|||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.typing import LocalOptimizer, TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict
|
||||
from ray.rllib.utils.torch_utils import (
|
||||
apply_grad_clipping,
|
||||
convert_to_torch_tensor,
|
||||
|
@ -340,7 +340,7 @@ def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]
|
|||
|
||||
|
||||
def cql_optimizer_fn(
|
||||
policy: Policy, config: TrainerConfigDict
|
||||
policy: Policy, config: AlgorithmConfigDict
|
||||
) -> Tuple[LocalOptimizer]:
|
||||
policy.cur_iter = 0
|
||||
opt_list = optimizer_fn(policy, config)
|
||||
|
@ -365,7 +365,7 @@ def cql_setup_late_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
setup_late_mixins(policy, obs_space, action_space, config)
|
||||
if config["lagrangian"]:
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
from typing import Type, List, Optional
|
||||
import tree
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
|
||||
from ray.rllib.execution.train_ops import (
|
||||
multi_gpu_train_one_step,
|
||||
train_one_step,
|
||||
|
@ -18,17 +18,17 @@ from ray.rllib.utils.metrics import (
|
|||
TARGET_NET_UPDATE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRRConfig(TrainerConfig):
    def __init__(self, trainer_class=None):
        super().__init__(trainer_class=trainer_class or CRR)
class CRRConfig(AlgorithmConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class=algo_class or CRR)

        # fmt: off
        # __sphinx_doc_begin__
@ -142,13 +142,13 @@ class CRRConfig(TrainerConfig):
|
|||
NUM_GRADIENT_UPDATES = "num_grad_updates"
|
||||
|
||||
|
||||
class CRR(Trainer):
|
||||
class CRR(Algorithm):
|
||||
|
||||
# TODO: we have a circular dependency for get
|
||||
# default config. config -> Trainer -> config
|
||||
# defining Config class in the same file for now as a workaround.
|
||||
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
# initial setup for handling the offline data in form of a replay buffer
|
||||
# Add the entire dataset to Replay Buffer (global variable)
|
||||
|
@ -194,12 +194,12 @@ class CRR(Trainer):
|
|||
self._counters[NUM_TARGET_UPDATES] = 0
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return CRRConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.algorithms.crr.torch import CRRTorchPolicy
|
||||
|
||||
|
@ -207,7 +207,7 @@ class CRR(Trainer):
|
|||
else:
|
||||
raise ValueError("Non-torch frameworks are not supported yet!")
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
|
||||
total_transitions = len(self.local_replay_buffer)
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from ray.rllib.agents import TrainerConfig
|
||||
from ray.rllib.algorithms import AlgorithmConfig
|
||||
from ray.rllib.algorithms.crr.torch import CRRModel
|
||||
from ray.rllib.algorithms.ddpg.noop_model import TorchNoopModel
|
||||
from ray.rllib.algorithms.sac.sac_torch_policy import TargetNetworkMixin
|
||||
|
@ -384,6 +384,6 @@ if __name__ == "__main__":
|
|||
|
||||
obs_space = gym.spaces.Box(np.array((-1, -1)), np.array((1, 1)))
|
||||
act_space = gym.spaces.Box(np.array((-1, -1)), np.array((1, 1)))
|
||||
config = TrainerConfig().framework(framework="torch").to_dict()
|
||||
config = AlgorithmConfig().framework(framework="torch").to_dict()
|
||||
print(config["framework"])
|
||||
CRRTorchPolicy(obs_space, act_space, config=config)
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import logging
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.algorithms.simple_q.simple_q import SimpleQ, SimpleQConfig
|
||||
from ray.rllib.algorithms.ddpg.ddpg_tf_policy import DDPGTFPolicy
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
|
||||
|
@ -44,9 +44,9 @@ class DDPGConfig(SimpleQConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a DDPGConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or DDPG)
|
||||
super().__init__(algo_class=algo_class or DDPG)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -129,7 +129,7 @@ class DDPGConfig(SimpleQConfig):
|
|||
# Deprecated.
|
||||
self.worker_side_prioritization = DEPRECATED_VALUE
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -256,12 +256,12 @@ class DDPGConfig(SimpleQConfig):
|
|||
class DDPG(SimpleQ):
|
||||
@classmethod
|
||||
@override(SimpleQ)
|
||||
# TODO make this return a TrainerConfig
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
# TODO make this return an AlgorithmConfig
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return DDPGConfig().to_dict()
|
||||
|
||||
@override(SimpleQ)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.algorithms.ddpg.ddpg_torch_policy import DDPGTorchPolicy
|
||||
|
||||
|
@ -270,7 +270,7 @@ class DDPG(SimpleQ):
|
|||
return DDPGTFPolicy
|
||||
|
||||
@override(SimpleQ)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
|
|
@ -29,7 +29,7 @@ from ray.rllib.utils.framework import get_variable, try_import_tf
|
|||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable
|
||||
from ray.rllib.utils.typing import (
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
TensorType,
|
||||
LocalOptimizer,
|
||||
ModelGradients,
|
||||
|
@ -45,7 +45,7 @@ def build_ddpg_models(
|
|||
policy: Policy,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> ModelV2:
|
||||
if policy.config["use_state_preprocessor"]:
|
||||
default_model = None # catalog decides
|
||||
|
@ -379,7 +379,7 @@ def setup_early_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
"""Call mixin classes' constructors before Policy's initialization.
|
||||
|
||||
|
@ -425,13 +425,13 @@ def setup_mid_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
|
||||
|
||||
|
||||
class TargetNetworkMixin:
|
||||
def __init__(self, config: TrainerConfigDict):
|
||||
def __init__(self, config: AlgorithmConfigDict):
|
||||
@make_tf_callable(self.get_session())
|
||||
def update_target_fn(tau):
|
||||
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
|
||||
|
@ -466,7 +466,7 @@ def setup_late_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
TargetNetworkMixin.__init__(policy, config)
|
||||
|
||||
|
@ -475,7 +475,7 @@ def validate_spaces(
|
|||
policy: Policy,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
if not isinstance(action_space, Box):
|
||||
raise UnsupportedSpaceException(
|
||||
|
|
|
@ -28,7 +28,7 @@ from ray.rllib.utils.torch_utils import (
|
|||
l2_loss,
|
||||
)
|
||||
from ray.rllib.utils.typing import (
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
TensorType,
|
||||
LocalOptimizer,
|
||||
GradInfoDict,
|
||||
|
@ -43,7 +43,7 @@ def build_ddpg_models_and_action_dist(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> Tuple[ModelV2, ActionDistribution]:
|
||||
model = build_ddpg_models(policy, obs_space, action_space, config)
|
||||
|
||||
|
@ -202,7 +202,7 @@ def ddpg_actor_critic_loss(
|
|||
|
||||
|
||||
def make_ddpg_optimizers(
|
||||
policy: Policy, config: TrainerConfigDict
|
||||
policy: Policy, config: AlgorithmConfigDict
|
||||
) -> Tuple[LocalOptimizer]:
|
||||
"""Create separate optimizers for actor & critic losses."""
|
||||
|
||||
|
@ -248,7 +248,7 @@ def before_init_fn(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
# Create global step for counting the number of update operations.
|
||||
policy.global_step = 0
|
||||
|
@ -285,7 +285,7 @@ def setup_late_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
|
||||
TargetNetworkMixin.__init__(policy)
|
||||
|
|
|
@@ -47,21 +47,20 @@ class TestDDPG(unittest.TestCase):

# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = config.build(env="Pendulum-v1")
algo = config.build(env="Pendulum-v1")
for i in range(num_iterations):
results = trainer.train()
results = algo.train()
check_train_results(results)
print(results)
check_compute_single_action(trainer)
check_compute_single_action(algo)
# Ensure apply_gradient_fn is being called and updating global_step
pol = trainer.get_policy()
pol = algo.get_policy()
if config.framework_str == "tf":
a = pol.get_session().run(pol.global_step)
else:
a = pol.global_step
check(a, 500)
trainer.stop()
algo.stop()

def test_ddpg_exploration_and_with_random_prerun(self):
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""

@@ -74,21 +73,21 @@ class TestDDPG(unittest.TestCase):
config = ddpg.DDPGConfig().rollouts(num_rollout_workers=0)
config.seed = 42
# Default OUNoise setup.
trainer = config.build(env="Pendulum-v1")
algo = config.build(env="Pendulum-v1")
# Setting explore=False should always return the same action.
a_ = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, 1)
a_ = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, 1)
for i in range(50):
a = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, i + 2)
a = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, i + 2)
check(a, a_)
# explore=None (default: explore) should return different actions.
actions = []
for i in range(50):
actions.append(trainer.compute_single_action(obs))
check(trainer.get_policy().global_timestep, i + 52)
actions.append(algo.compute_single_action(obs))
check(algo.get_policy().global_timestep, i + 52)
check(np.std(actions), 0.0, false=True)
trainer.stop()
algo.stop()

# Check randomness at beginning.
config.exploration_config.update(

@@ -102,30 +101,30 @@ class TestDDPG(unittest.TestCase):
}
)

trainer = ddpg.DDPG(config=config, env="Pendulum-v1")
algo = ddpg.DDPG(config=config, env="Pendulum-v1")
# ts=0 (get a deterministic action as per explore=False).
deterministic_action = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, 1)
deterministic_action = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, 1)
# ts=1-49 (in random window).
random_a = []
for i in range(1, 50):
random_a.append(trainer.compute_single_action(obs, explore=True))
check(trainer.get_policy().global_timestep, i + 1)
random_a.append(algo.compute_single_action(obs, explore=True))
check(algo.get_policy().global_timestep, i + 1)
check(random_a[-1], deterministic_action, false=True)
self.assertTrue(np.std(random_a) > 0.5)

# ts > 50 (a=deterministic_action + scale * N[0,1])
for i in range(50):
a = trainer.compute_single_action(obs, explore=True)
check(trainer.get_policy().global_timestep, i + 51)
a = algo.compute_single_action(obs, explore=True)
check(algo.get_policy().global_timestep, i + 51)
check(a, deterministic_action, rtol=0.1)

# ts >> 50 (BUT: explore=False -> expect deterministic action).
for i in range(50):
a = trainer.compute_single_action(obs, explore=False)
check(trainer.get_policy().global_timestep, i + 101)
a = algo.compute_single_action(obs, explore=False)
check(algo.get_policy().global_timestep, i + 101)
check(a, deterministic_action)
trainer.stop()
algo.stop()

def test_ddpg_loss_function(self):
"""Tests DDPG loss function results across all frameworks."""

@@ -147,7 +146,7 @@ class TestDDPG(unittest.TestCase):
# Use very simple nets.
config.actor_hiddens = [10]
config.critic_hiddens = [10]
# Make sure, timing differences do not affect trainer.train().
# Make sure, timing differences do not affect Algorithm.train().
config.min_time_s_per_iteration = 0
config.min_sample_timesteps_per_iteration = 100

@@ -222,9 +221,9 @@ class TestDDPG(unittest.TestCase):
for fw, sess in framework_iterator(
config, frameworks=("tf", "torch"), session=True
):
# Generate Trainer and get its default Policy object.
trainer = config.build(env=env)
policy = trainer.get_policy()
# Generate Algorithm and get its default Policy object.
algo = config.build(env=env)
policy = algo.get_policy()
p_sess = None
if sess:
p_sess = policy.get_session()

@@ -367,9 +366,9 @@ class TestDDPG(unittest.TestCase):
tf_inputs.append(in_)
# Set a fake-batch to use
# (instead of sampling from replay buffer).
buf = trainer.local_replay_buffer
buf = algo.local_replay_buffer
patch_buffer_with_fake_sampling_method(buf, in_)
trainer.train()
algo.train()
updated_weights = policy.get_weights()
# Net must have changed.
if tf_updated_weights:

@@ -388,9 +387,9 @@ class TestDDPG(unittest.TestCase):
in_ = tf_inputs[update_iteration]
# Set a fake-batch to use
# (instead of sampling from replay buffer).
buf = trainer.local_replay_buffer
buf = algo.local_replay_buffer
patch_buffer_with_fake_sampling_method(buf, in_)
trainer.train()
algo.train()
# Compare updated model and target weights.
for tf_key in tf_weights.keys():
tf_var = tf_weights[tf_key]

@@ -407,7 +406,7 @@ class TestDDPG(unittest.TestCase):
else:
check(tf_var, torch_var, atol=0.1)

trainer.stop()
algo.stop()

def _get_batch_helper(self, obs_size, actions, batch_size):
return SampleBatch(

@@ -42,9 +42,9 @@ from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import (
EnvType,
PartialTrainerConfigDict,
PartialAlgorithmConfigDict,
ResultDict,
TrainerConfigDict,
AlgorithmConfigDict,
)
from ray.tune.logger import Logger

@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)

class DDPPOConfig(PPOConfig):
"""Defines a configuration class from which a DDPPO Trainer can be built.
"""Defines a configuration class from which a DDPPO Algorithm can be built.

Example:
>>> from ray.rllib.algorithms.ddppo import DDPPOConfig

@@ -60,7 +60,7 @@ class DDPPOConfig(PPOConfig):
... .resources(num_gpus=1)\
... .rollouts(num_workers=10)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()

@@ -83,9 +83,9 @@ class DDPPOConfig(PPOConfig):
... )
"""

def __init__(self, trainer_class=None):
def __init__(self, algo_class=None):
"""Initializes a DDPPOConfig instance."""
super().__init__(trainer_class=trainer_class or DDPPO)
super().__init__(algo_class=algo_class or DDPPO)

# fmt: off
# __sphinx_doc_begin__

@@ -93,7 +93,7 @@ class DDPPOConfig(PPOConfig):
self.keep_local_weights_in_sync = True
self.torch_distributed_backend = "gloo"

# Override some of PPO/Trainer's default values with DDPPO-specific values.
# Override some of PPO/Algorithm's default values with DDPPO-specific values.
# During the sampling phase, each rollout worker will collect a batch
# `rollout_fragment_length * num_envs_per_worker` steps in size.
self.rollout_fragment_length = 100

@@ -144,7 +144,7 @@ class DDPPOConfig(PPOConfig):
distributed.

Returns:
This updated TrainerConfig object.
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)

@@ -160,7 +160,7 @@ class DDPPOConfig(PPOConfig):
class DDPPO(PPO):
def __init__(
self,
config: Optional[PartialTrainerConfigDict] = None,
config: Optional[PartialAlgorithmConfigDict] = None,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
remote_checkpoint_dir: Optional[str] = None,

@@ -196,12 +196,12 @@ class DDPPO(PPO):

@classmethod
@override(PPO)
def get_default_config(cls) -> TrainerConfigDict:
def get_default_config(cls) -> AlgorithmConfigDict:
return DDPPOConfig().to_dict()

@override(PPO)
def validate_config(self, config):
"""Validates the Trainer's config dict.
"""Validates the Algorithm's config dict.

Args:
config: The Trainer's config to check.

@@ -251,7 +251,7 @@ class DDPPO(PPO):
raise ValueError("DDPPO doesn't support KL penalties like PPO-1")

@override(PPO)
def setup(self, config: PartialTrainerConfigDict):
def setup(self, config: PartialAlgorithmConfigDict):
super().setup(config)

# Initialize torch process group for

@ -2,7 +2,7 @@
|
|||
Deep Q-Networks (DQN, Rainbow, Parametric DQN)
|
||||
==============================================
|
||||
|
||||
This file defines the distributed Trainer class for the Deep Q-Networks
|
||||
This file defines the distributed Algorithm class for the Deep Q-Networks
|
||||
algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
|
||||
|
||||
Detailed documentation:
|
||||
|
@ -32,7 +32,7 @@ from ray.rllib.utils.annotations import override
|
|||
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
|
||||
from ray.rllib.utils.typing import (
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
)
|
||||
from ray.rllib.utils.metrics import (
|
||||
NUM_ENV_STEPS_SAMPLED,
|
||||
|
@ -53,7 +53,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class DQNConfig(SimpleQConfig):
|
||||
"""Defines a configuration class from which a DQN Trainer can be built.
|
||||
"""Defines a configuration class from which a DQN Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.dqn.dqn import DQNConfig
|
||||
|
@ -115,9 +115,9 @@ class DQNConfig(SimpleQConfig):
|
|||
>>> .exploration(exploration_config=explore_config)
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a DQNConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or DQN)
|
||||
super().__init__(algo_class=algo_class or DQN)
|
||||
|
||||
# DQN specific config settings.
|
||||
# fmt: off
|
||||
|
@ -248,7 +248,7 @@ class DQNConfig(SimpleQConfig):
|
|||
zero, there is still a chance of drawing the sample.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -281,7 +281,7 @@ class DQNConfig(SimpleQConfig):
|
|||
return self
|
||||
|
||||
|
||||
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
|
||||
def calculate_rr_weights(config: AlgorithmConfigDict) -> List[float]:
|
||||
"""Calculate the round robin weights for the rollout and train steps"""
|
||||
if not config["training_intensity"]:
|
||||
return [1, 1]
|
||||
|
@ -311,11 +311,11 @@ def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
|
|||
class DQN(SimpleQ):
|
||||
@classmethod
|
||||
@override(SimpleQ)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
@override(SimpleQ)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -325,7 +325,7 @@ class DQN(SimpleQ):
|
|||
|
||||
@override(SimpleQ)
|
||||
def get_default_policy_class(
|
||||
self, config: TrainerConfigDict
|
||||
self, config: AlgorithmConfigDict
|
||||
) -> Optional[Type[Policy]]:
|
||||
if config["framework"] == "torch":
|
||||
return DQNTorchPolicy
|
||||
|
|
|
@ -27,7 +27,7 @@ from ray.rllib.utils.tf_utils import (
|
|||
minimize_and_clip,
|
||||
reduce_mean_ignore_inf,
|
||||
)
|
||||
from ray.rllib.utils.typing import ModelGradients, TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
@ -152,7 +152,7 @@ def build_q_model(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> ModelV2:
|
||||
"""Build q_model and target_model for DQN
|
||||
|
||||
|
@ -160,7 +160,7 @@ def build_q_model(
|
|||
policy: The Policy, which will use the model for optimization.
|
||||
obs_space (gym.spaces.Space): The policy's observation space.
|
||||
action_space (gym.spaces.Space): The policy's action space.
|
||||
config (TrainerConfigDict):
|
||||
config (AlgorithmConfigDict):
|
||||
|
||||
Returns:
|
||||
ModelV2: The Model for the Policy to use.
|
||||
|
@ -328,7 +328,7 @@ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> Tensor
|
|||
|
||||
|
||||
def adam_optimizer(
|
||||
policy: Policy, config: TrainerConfigDict
|
||||
policy: Policy, config: AlgorithmConfigDict
|
||||
) -> "tf.keras.optimizers.Optimizer":
|
||||
if policy.config["framework"] in ["tf2", "tfe"]:
|
||||
return tf.keras.optimizers.Adam(
|
||||
|
@ -372,7 +372,7 @@ def setup_late_mixins(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ from ray.rllib.utils.torch_utils import (
|
|||
reduce_mean_ignore_inf,
|
||||
softmax_cross_entropy_with_logits,
|
||||
)
|
||||
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
F = None
|
||||
|
@ -145,7 +145,7 @@ def build_q_model_and_distribution(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> Tuple[ModelV2, TorchDistributionWrapper]:
|
||||
"""Build q_model and target_model for DQN
|
||||
|
||||
|
@ -153,7 +153,7 @@ def build_q_model_and_distribution(
|
|||
policy: The policy, which will use the model for optimization.
|
||||
obs_space (gym.spaces.Space): The policy's observation space.
|
||||
action_space (gym.spaces.Space): The policy's action space.
|
||||
config (TrainerConfigDict):
|
||||
config (AlgorithmConfigDict):
|
||||
|
||||
Returns:
|
||||
(q_model, TorchCategorical)
|
||||
|
@ -354,7 +354,7 @@ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> Tensor
|
|||
|
||||
|
||||
def adam_optimizer(
|
||||
policy: Policy, config: TrainerConfigDict
|
||||
policy: Policy, config: AlgorithmConfigDict
|
||||
) -> "torch.optim.Optimizer":
|
||||
|
||||
# By this time, the models have been moved to the GPU - if any - and we
|
||||
|
@ -384,7 +384,7 @@ def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
|
|||
|
||||
|
||||
def setup_early_mixins(
|
||||
policy: Policy, obs_space, action_space, config: TrainerConfigDict
|
||||
policy: Policy, obs_space, action_space, config: AlgorithmConfigDict
|
||||
) -> None:
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
@ -393,7 +393,7 @@ def before_loss_init(
|
|||
policy: Policy,
|
||||
obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict,
|
||||
config: AlgorithmConfigDict,
|
||||
) -> None:
|
||||
ComputeTDErrorMixin.__init__(policy)
|
||||
TargetNetworkMixin.__init__(policy)
|
||||
|
|
|
@ -3,9 +3,9 @@ import numpy as np
|
|||
import random
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.algorithms.dreamer.dreamer_torch_policy import DreamerTorchPolicy
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
@ -18,17 +18,17 @@ from ray.rllib.utils.annotations import override
|
|||
from ray.rllib.utils.deprecation import Deprecated
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
SampleBatchType,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
ResultDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DreamerConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which a Dreamer Trainer can be built.
|
||||
class DreamerConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which a Dreamer Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.dreamer import DreamerConfig
|
||||
|
@ -36,7 +36,7 @@ class DreamerConfig(TrainerConfig):
|
|||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -61,7 +61,7 @@ class DreamerConfig(TrainerConfig):
|
|||
|
||||
def __init__(self):
|
||||
"""Initializes a PPOConfig instance."""
|
||||
super().__init__(trainer_class=Dreamer)
|
||||
super().__init__(algo_class=Dreamer)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -92,7 +92,7 @@ class DreamerConfig(TrainerConfig):
|
|||
"action_init_std": 5.0,
|
||||
}
|
||||
|
||||
# Override some of TrainerConfig's default values with PPO-specific values.
|
||||
# Override some of AlgorithmConfig's default values with PPO-specific values.
|
||||
# .rollouts()
|
||||
self.num_workers = 0
|
||||
self.num_envs_per_worker = 1
|
||||
|
@ -112,7 +112,7 @@ class DreamerConfig(TrainerConfig):
|
|||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -293,14 +293,14 @@ class DreamerIteration:
|
|||
return _postprocess_gif(gif=gif)
|
||||
|
||||
|
||||
class Dreamer(Trainer):
|
||||
class Dreamer(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return DreamerConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -323,12 +323,12 @@ class Dreamer(Trainer):
|
|||
if config["action_repeat"] > 1:
|
||||
config["horizon"] = config["horizon"] / config["action_repeat"]
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict):
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict):
|
||||
return DreamerTorchPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
@override(Algorithm)
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
# `training_iteration` implementation: Setup buffer in `setup`, not
|
||||
# in `execution_plan` (deprecated).
|
||||
|
@ -344,7 +344,7 @@ class Dreamer(Trainer):
|
|||
self.local_replay_buffer.add(samples)
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def execution_plan(workers, config, **kwargs):
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
|
@ -376,7 +376,7 @@ class Dreamer(Trainer):
|
|||
)
|
||||
return rollouts
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
local_worker = self.workers.local_worker()
|
||||
|
||||
|
|
|
@@ -8,7 +8,7 @@ from ray.rllib.utils.test_utils import framework_iterator

class TestDreamer(unittest.TestCase):
"""Sanity tests for DreamerTrainer."""
"""Sanity tests for Dreamer."""

def setUp(self):
ray.init()

@@ -17,7 +17,7 @@ class TestDreamer(unittest.TestCase):
ray.shutdown()

def test_dreamer_compilation(self):
"""Test whether an DreamerTrainer can be built with all frameworks."""
"""Test whether a Dreamer can be built with all frameworks."""
config = dreamer.DreamerConfig()
config.environment(
env=RandomEnv,

|
@ -9,7 +9,7 @@ import time
|
|||
from typing import Optional
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents import Trainer, TrainerConfig
|
||||
from ray.rllib.algorithms import Algorithm, AlgorithmConfig
|
||||
from ray.rllib.algorithms.es import optimizers, utils
|
||||
from ray.rllib.algorithms.es.es_tf_policy import ESTFPolicy, rollout
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
|
@ -24,7 +24,7 @@ from ray.rllib.utils.metrics import (
|
|||
NUM_ENV_STEPS_TRAINED,
|
||||
)
|
||||
from ray.rllib.utils.torch_utils import set_torch_seed
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -41,8 +41,8 @@ Result = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
class ESConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which an ES Trainer can be built.
|
||||
class ESConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which an ES Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.es import ESConfig
|
||||
|
@ -50,7 +50,7 @@ class ESConfig(TrainerConfig):
|
|||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -75,7 +75,7 @@ class ESConfig(TrainerConfig):
|
|||
|
||||
def __init__(self):
|
||||
"""Initializes a ESConfig instance."""
|
||||
super().__init__(trainer_class=ES)
|
||||
super().__init__(algo_class=ES)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -91,11 +91,11 @@ class ESConfig(TrainerConfig):
|
|||
self.noise_size = 250000000
|
||||
self.report_length = 10
|
||||
|
||||
# Override some of TrainerConfig's default values with ES-specific values.
|
||||
# Override some of AlgorithmConfig's default values with ES-specific values.
|
||||
self.train_batch_size = 10000
|
||||
self.num_workers = 10
|
||||
self.observation_filter = "MeanStdFilter"
|
||||
# ARS will use Trainer's evaluation WorkerSet (if evaluation_interval > 0).
|
||||
# ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0).
|
||||
# Therefore, we must be careful not to use more than 1 env per eval worker
|
||||
# (would break ARSPolicy's compute_single_action method) and to not do
|
||||
# obs-filtering.
|
||||
|
@ -105,7 +105,7 @@ class ESConfig(TrainerConfig):
|
|||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -137,7 +137,7 @@ class ESConfig(TrainerConfig):
|
|||
report_length: How many of the last rewards we average over.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -319,16 +319,16 @@ def get_policy_class(config):
|
|||
return policy_cls
|
||||
|
||||
|
||||
class ES(Trainer):
|
||||
class ES(Algorithm):
|
||||
"""Large-scale implementation of Evolution Strategies in Ray."""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return ESConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -348,7 +348,7 @@ class ES(Trainer):
|
|||
"`NoFilter` for ES!"
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def setup(self, config):
|
||||
# Setup our config: Merge the user-supplied config (which could
|
||||
# be a partial config dict with the class' default).
|
||||
|
@ -393,7 +393,7 @@ class ES(Trainer):
|
|||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def get_policy(self, policy=DEFAULT_POLICY_ID):
|
||||
if policy != DEFAULT_POLICY_ID:
|
||||
raise ValueError(
|
||||
|
@ -402,7 +402,7 @@ class ES(Trainer):
|
|||
)
|
||||
return self.policy
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def step(self):
|
||||
config = self.config
|
||||
|
||||
|
@ -503,7 +503,7 @@ class ES(Trainer):
|
|||
|
||||
return result
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def compute_single_action(self, observation, *args, **kwargs):
|
||||
action, _, _ = self.policy.compute_actions([observation], update=False)
|
||||
if kwargs.get("full_fetch"):
|
||||
|
@ -514,7 +514,7 @@ class ES(Trainer):
|
|||
def compute_action(self, observation, *args, **kwargs):
|
||||
return self.compute_single_action(observation, *args, **kwargs)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def _sync_weights_to_workers(self, *, worker_set=None, workers=None):
|
||||
# Broadcast the new policy weights to all evaluation workers.
|
||||
assert worker_set is not None
|
||||
|
@ -522,7 +522,7 @@ class ES(Trainer):
|
|||
weights = ray.put(self.policy.get_flat_weights())
|
||||
worker_set.foreach_policy(lambda p, pid: p.set_flat_weights(ray.get(weights)))
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def cleanup(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for w in self.workers:
|
||||
|
|
|
@@ -7,7 +7,7 @@ from ray.rllib.utils.test_utils import check_compute_single_action, framework_it

class TestES(unittest.TestCase):
def test_es_compilation(self):
"""Test whether an ESTrainer can be built on all frameworks."""
"""Test whether an ESAlgorithm can be built on all frameworks."""
ray.init(num_cpus=4)
config = es.ESConfig()
# Keep it simple.

|
@ -8,7 +8,8 @@ from typing import Optional, Type, List, Dict, Union, Callable, Any
|
|||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.rllib import SampleBatch
|
||||
from ray.rllib.agents.trainer import Trainer, TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
|
||||
from ray.rllib.execution.learner_thread import LearnerThread
|
||||
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
|
||||
|
@ -38,9 +39,9 @@ from ray.rllib.utils.metrics import (
|
|||
|
||||
# from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
PartialAlgorithmConfigDict,
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
SampleBatchType,
|
||||
T,
|
||||
)
|
||||
|
@ -55,7 +56,7 @@ from ray.types import ObjectRef
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImpalaConfig(TrainerConfig):
|
||||
class ImpalaConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which an Impala can be built.
|
||||
|
||||
Example:
|
||||
|
@ -64,7 +65,7 @@ class ImpalaConfig(TrainerConfig):
|
|||
... .resources(num_gpus=4)\
|
||||
... .rollouts(num_rollout_workers=64)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -87,9 +88,9 @@ class ImpalaConfig(TrainerConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a ImpalaConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or Impala)
|
||||
super().__init__(algo_class=algo_class or Impala)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -127,7 +128,7 @@ class ImpalaConfig(TrainerConfig):
|
|||
self._lr_vf = 0.0005
|
||||
self.after_train_step = None
|
||||
|
||||
# Override some of TrainerConfig's default values with ARS-specific values.
|
||||
# Override some of AlgorithmConfig's default values with ARS-specific values.
|
||||
self.rollout_fragment_length = 50
|
||||
self.train_batch_size = 500
|
||||
self.num_workers = 2
|
||||
|
@ -140,7 +141,7 @@ class ImpalaConfig(TrainerConfig):
|
|||
# Deprecated value.
|
||||
self.num_data_loader_buffers = DEPRECATED_VALUE
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -266,7 +267,7 @@ class ImpalaConfig(TrainerConfig):
|
|||
in flight, or enable compression in your experiment of timesteps.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -433,8 +434,8 @@ class BroadcastUpdateLearnerWeights:
|
|||
self.workers.local_worker().set_global_vars(_get_global_vars())
|
||||
|
||||
|
||||
class Impala(Trainer):
|
||||
"""Importance weighted actor/learner architecture (IMPALA) Trainer
|
||||
class Impala(Algorithm):
|
||||
"""Importance weighted actor/learner architecture (IMPALA) Algorithm
|
||||
|
||||
== Overview of data flow in IMPALA ==
|
||||
1. Policy evaluation in parallel across `num_workers` actors produces
|
||||
|
@ -448,13 +449,13 @@ class Impala(Trainer):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return ImpalaConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(
|
||||
self, config: PartialTrainerConfigDict
|
||||
self, config: PartialAlgorithmConfigDict
|
||||
) -> Optional[Type[Policy]]:
|
||||
if config["framework"] == "torch":
|
||||
if config["vtrace"]:
|
||||
|
@ -488,7 +489,7 @@ class Impala(Trainer):
|
|||
|
||||
return A3CTFPolicy
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config):
|
||||
# Call the super class' validation method first.
|
||||
super().validate_config(config)
|
||||
|
@ -536,8 +537,8 @@ class Impala(Trainer):
|
|||
)
|
||||
config["_tf_policy_handles_more_than_one_loss"] = True
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
@override(Algorithm)
|
||||
def setup(self, config: PartialAlgorithmConfigDict):
|
||||
super().setup(config)
|
||||
|
||||
if self.config["_disable_execution_plan_api"]:
|
||||
|
@ -607,7 +608,7 @@ class Impala(Trainer):
|
|||
self._learner_thread.start()
|
||||
self.workers_that_need_updates = set()
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
unprocessed_sample_batches = self.get_samples_from_workers()
|
||||
|
||||
|
@ -629,7 +630,7 @@ class Impala(Trainer):
|
|||
return train_results
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def execution_plan(workers, config, **kwargs):
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
|
@ -687,7 +688,7 @@ class Impala(Trainer):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls.get_default_config(), **config)
|
||||
|
||||
|
@ -887,7 +888,7 @@ class Impala(Trainer):
|
|||
# Update global vars of the local worker.
|
||||
self.workers.local_worker().set_global_vars(global_vars)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def on_worker_failures(
|
||||
self, removed_workers: List[ActorHandle], new_workers: List[ActorHandle]
|
||||
):
|
||||
|
@ -900,7 +901,7 @@ class Impala(Trainer):
|
|||
self._sampling_actor_manager.remove_workers(removed_workers)
|
||||
self._sampling_actor_manager.add_workers(new_workers)
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
|
||||
result = super()._compile_iteration_results(
|
||||
step_ctx=step_ctx, iteration_results=iteration_results
|
||||
|
@ -915,7 +916,7 @@ class Impala(Trainer):
|
|||
class AggregatorWorker:
|
||||
"""A worker for doing tree aggregation of collected episodes"""
|
||||
|
||||
def __init__(self, config: TrainerConfigDict):
|
||||
def __init__(self, config: AlgorithmConfigDict):
|
||||
self.config = config
|
||||
self._mixin_buffer = MixInMultiAgentReplayBuffer(
|
||||
capacity=(
|
||||
|
|
|
@@ -82,7 +82,7 @@ class VTraceLoss:
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
config: Algorithm config dict.
"""

# Compute vtrace on the CPU for better perf.

@@ -80,7 +80,7 @@ class VTraceLoss:
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
config: Algorithm config dict.
"""
import ray.rllib.algorithms.impala.vtrace_torch as vtrace

@ -12,21 +12,21 @@ and the README for how to run with the multi-agent particle envs.
|
|||
import logging
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.algorithms.dqn.dqn import DQN
|
||||
from ray.rllib.algorithms.maddpg.maddpg_tf_policy import MADDPGTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import Deprecated, override
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class MADDPGConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which a MADDPG Trainer can be built.
|
||||
class MADDPGConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which a MADDPG Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig
|
||||
|
@ -44,9 +44,9 @@ class MADDPGConfig(TrainerConfig):
|
|||
>>> .resources(num_gpus=0)\
|
||||
>>> .rollouts(num_rollout_workers=4)\
|
||||
>>> .environment("CartPole-v1")
|
||||
>>> trainer = config.build()
|
||||
>>> algo = config.build()
|
||||
>>> while True:
|
||||
>>> trainer.train()
|
||||
>>> algo.train()
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig
|
||||
|
@ -61,9 +61,9 @@ class MADDPGConfig(TrainerConfig):
|
|||
>>> )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a DQNConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or MADDPG)
|
||||
super().__init__(algo_class=algo_class or MADDPG)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -97,7 +97,7 @@ class MADDPGConfig(TrainerConfig):
|
|||
self.actor_feature_reg = 0.001
|
||||
self.grad_norm_clipping = 0.5
|
||||
|
||||
# Changes to Trainer's default:
|
||||
# Changes to Algorithm's default:
|
||||
self.rollout_fragment_length = 100
|
||||
self.train_batch_size = 1024
|
||||
self.num_workers = 1
|
||||
|
@ -105,7 +105,7 @@ class MADDPGConfig(TrainerConfig):
|
|||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -201,7 +201,7 @@ class MADDPGConfig(TrainerConfig):
|
|||
value.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
|
@ -277,11 +277,11 @@ def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
|
|||
class MADDPG(DQN):
|
||||
@classmethod
|
||||
@override(DQN)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return MADDPGConfig().to_dict()
|
||||
|
||||
@override(DQN)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
"""Adds the `before_learn_on_batch` hook to the config.
|
||||
|
||||
This hook is called explicitly prior to TrainOneStep() in the execution
|
||||
|
@ -299,7 +299,7 @@ class MADDPG(DQN):
|
|||
config["before_learn_on_batch"] = f
|
||||
|
||||
@override(DQN)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
return MADDPGTFPolicy
|
||||
|
||||
|
||||
|
|
|
@@ -46,12 +46,12 @@ class TestMADDPG(unittest.TestCase):

# Only working for tf right now.
for _ in framework_iterator(config, frameworks="tf"):
trainer = config.build()
algo = config.build()
for i in range(num_iterations):
results = trainer.train()
results = algo.train()
check_train_results(results)
print(results)
trainer.stop()
algo.stop()


if __name__ == "__main__":

@ -2,8 +2,8 @@ import logging
|
|||
import numpy as np
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.common import (
|
||||
|
@ -20,20 +20,20 @@ from ray.rllib.utils.annotations import override
|
|||
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.sgd import standardized
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
from ray.rllib.utils.typing import AlgorithmConfigDict
|
||||
from ray.util.iter import from_actors, LocalIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MAMLConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which a MAML Trainer can be built.
|
||||
class MAMLConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which a MAML Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.maml import MAMLConfig
|
||||
>>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build a Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -56,9 +56,9 @@ class MAMLConfig(TrainerConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a PGConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or MAML)
|
||||
super().__init__(algo_class=algo_class or MAML)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -77,7 +77,7 @@ class MAMLConfig(TrainerConfig):
|
|||
self.inner_lr = 0.1
|
||||
self.use_meta_env = True
|
||||
|
||||
# Override some of TrainerConfig's default values with MAML-specific values.
|
||||
# Override some of AlgorithmConfig's default values with MAML-specific values.
|
||||
self.rollout_fragment_length = 200
|
||||
self.create_env_on_local_worker = True
|
||||
self.lr = 1e-3
|
||||
|
@ -136,7 +136,7 @@ class MAMLConfig(TrainerConfig):
|
|||
use_meta_env: Use Meta Env Template.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -252,14 +252,14 @@ def inner_adaptation(workers, samples):
|
|||
e.learn_on_batch.remote(samples[i])
|
||||
|
||||
|
||||
class MAML(Trainer):
|
||||
class MAML(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return MAMLConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -281,8 +281,8 @@ class MAML(Trainer):
|
|||
"(local) worker! Set `create_env_on_driver` to True."
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy
|
||||
|
||||
|
@ -297,9 +297,9 @@ class MAML(Trainer):
|
|||
return MAMLEagerTFPolicy
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def execution_plan(
|
||||
workers: WorkerSet, config: TrainerConfigDict, **kwargs
|
||||
workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
|
||||
) -> LocalIterator[dict]:
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
|
||||
from ray.rllib.execution.rollout_ops import (
|
||||
synchronous_parallel_sample,
|
||||
|
@ -21,13 +21,13 @@ from ray.rllib.utils.metrics import (
|
|||
)
|
||||
from ray.rllib.utils.typing import (
|
||||
ResultDict,
|
||||
TrainerConfigDict,
|
||||
AlgorithmConfigDict,
|
||||
)
|
||||
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
|
||||
|
||||
|
||||
class MARWILConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which a MARWIL Trainer can be built.
|
||||
class MARWILConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which a MARWIL Algorithm can be built.
|
||||
|
||||
|
||||
Example:
|
||||
|
@ -36,7 +36,7 @@ class MARWILConfig(TrainerConfig):
|
|||
>>> config = MARWILConfig().training(beta=1.0, lr=0.00001, gamma=0.99)\
|
||||
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build()
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -61,9 +61,9 @@ class MARWILConfig(TrainerConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a MARWILConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or MARWIL)
|
||||
super().__init__(algo_class=algo_class or MARWIL)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -91,7 +91,7 @@ class MARWILConfig(TrainerConfig):
|
|||
self.vf_coeff = 1.0
|
||||
self.grad_clip = None
|
||||
|
||||
# Override some of TrainerConfig's default values with MARWIL-specific values.
|
||||
# Override some of AlgorithmConfig's default values with MARWIL-specific values.
|
||||
|
||||
# You should override input_ to point to an offline dataset
|
||||
# (see trainer.py and trainer_config.py).
|
||||
|
@ -114,7 +114,7 @@ class MARWILConfig(TrainerConfig):
|
|||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -181,7 +181,7 @@ class MARWILConfig(TrainerConfig):
|
|||
grad_clip: If specified, clip the global norm of gradients by this amount.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -206,14 +206,14 @@ class MARWILConfig(TrainerConfig):
|
|||
return self
|
||||
|
||||
|
||||
class MARWIL(Trainer):
|
||||
class MARWIL(Algorithm):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return MARWILConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -231,8 +231,8 @@ class MARWIL(Trainer):
|
|||
"calculate accum., discounted returns)!"
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.algorithms.marwil.marwil_torch_policy import (
|
||||
MARWILTorchPolicy,
|
||||
|
@ -250,7 +250,7 @@ class MARWIL(Trainer):
|
|||
|
||||
return MARWILTF2Policy
|
||||
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def training_step(self) -> ResultDict:
|
||||
# Collect SampleBatches from sample workers.
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
|
|
|
@@ -18,7 +18,7 @@ torch, _ = try_import_torch()

class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
"""PyTorch policy class used with MarwilTrainer."""
"""PyTorch policy class used with Marwil."""

def __init__(self, observation_space, action_space, config):
config = dict(

@@ -31,7 +31,7 @@ class TestMARWIL(unittest.TestCase):
ray.shutdown()

def test_marwil_compilation_and_learning_from_offline_file(self):
"""Test whether a MARWILTrainer can be built with all frameworks.
"""Test whether a MARWILAlgorithm can be built with all frameworks.

Learns from a historic-data file.
To generate this data, first run:

@@ -83,7 +83,7 @@ class TestMARWIL(unittest.TestCase):

if not learnt:
raise ValueError(
"MARWILTrainer did not reach {} reward from expert "
"MARWILAlgorithm did not reach {} reward from expert "
"offline data!".format(min_reward)
)

@ -5,8 +5,8 @@ from typing import List, Optional, Type
|
|||
import ray
|
||||
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
|
||||
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.algorithms.algorithm import Algorithm
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.wrappers.model_vector_env import model_vector_env
|
||||
from ray.rllib.evaluation.metrics import (
|
||||
|
@ -29,14 +29,14 @@ from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
|||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.sgd import standardized
|
||||
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
|
||||
from ray.rllib.utils.typing import EnvType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import EnvType, AlgorithmConfigDict
|
||||
from ray.util.iter import from_actors, LocalIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MBMPOConfig(TrainerConfig):
|
||||
"""Defines a configuration class from which an MBMPO Trainer can be built.
|
||||
class MBMPOConfig(AlgorithmConfig):
|
||||
"""Defines a configuration class from which an MBMPO Algorithm can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.mbmpo import MBMPOConfig
|
||||
|
@ -44,7 +44,7 @@ class MBMPOConfig(TrainerConfig):
|
|||
... .resources(num_gpus=4)\
|
||||
... .rollouts(num_rollout_workers=64)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
|
@ -67,9 +67,9 @@ class MBMPOConfig(TrainerConfig):
|
|||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
def __init__(self, algo_class=None):
|
||||
"""Initializes a MBMPOConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or MBMPO)
|
||||
super().__init__(algo_class=algo_class or MBMPO)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@ -127,7 +127,7 @@ class MBMPOConfig(TrainerConfig):
|
|||
# How many iterations through MAML per MBMPO iteration.
|
||||
self.num_maml_steps = 10
|
||||
|
||||
# Override some of TrainerConfig's default values with MBMPO-specific
|
||||
# Override some of AlgorithmConfig's default values with MBMPO-specific
|
||||
# values.
|
||||
self.batch_mode = "complete_episodes"
|
||||
# Size of batches collected from each worker.
|
||||
|
@ -149,7 +149,7 @@ class MBMPOConfig(TrainerConfig):
|
|||
self.vf_share_layers = DEPRECATED_VALUE
|
||||
self._disable_execution_plan_api = False
|
||||
|
||||
@override(TrainerConfig)
|
||||
@override(AlgorithmConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
|
@ -198,7 +198,7 @@ class MBMPOConfig(TrainerConfig):
|
|||
num_maml_steps: How many iterations through MAML per MBMPO iteration.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
This updated AlgorithmConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
@ -360,7 +360,7 @@ def inner_adaptation(workers: WorkerSet, samples: List[SampleBatch]):
|
|||
"""Performs one gradient descend step on each remote worker.
|
||||
|
||||
Args:
|
||||
workers: The WorkerSet of the Trainer.
|
||||
workers: The WorkerSet of the Algorithm.
|
||||
samples (List[SampleBatch]): The list of SampleBatches to perform
|
||||
a training step on (one for each remote worker).
|
||||
"""
|
||||
|
@ -422,7 +422,7 @@ def sync_stats(workers: WorkerSet) -> None:
|
|||
e.foreach_policy.remote(set_func, normalizations=normalization_dict)
|
||||
|
||||
|
||||
def post_process_samples(samples, config: TrainerConfigDict):
|
||||
def post_process_samples(samples, config: AlgorithmConfigDict):
|
||||
# Instead of using NN for value function, we use regression
|
||||
split_lst = []
|
||||
for sample in samples:
|
||||
|
@ -446,10 +446,10 @@ def post_process_samples(samples, config: TrainerConfigDict):
|
|||
return samples, split_lst
|
||||
|
||||
|
||||
class MBMPO(Trainer):
|
||||
"""Model-Based Meta Policy Optimization (MB-MPO) Trainer.
|
||||
class MBMPO(Algorithm):
|
||||
"""Model-Based Meta Policy Optimization (MB-MPO) Algorithm.
|
||||
|
||||
This file defines the distributed Trainer class for model-based meta
|
||||
This file defines the distributed Algorithm class for model-based meta
|
||||
policy optimization.
|
||||
See `mbmpo_[tf|torch]_policy.py` for the definition of the policy loss.
|
||||
|
||||
|
@ -458,12 +458,12 @@ class MBMPO(Trainer):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
@override(Algorithm)
|
||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
@override(Algorithm)
|
||||
def validate_config(self, config: AlgorithmConfigDict) -> None:
|
||||
# Call super's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
|
@ -491,16 +491,16 @@ class MBMPO(Trainer):
|
|||
"(local) worker! Set `create_env_on_driver` to True."
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
@override(Algorithm)
|
||||
def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
|
||||
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
|
||||
|
||||
return MBMPOTorchPolicy
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def execution_plan(
|
||||
workers: WorkerSet, config: TrainerConfigDict, **kwargs
|
||||
workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
|
||||
) -> LocalIterator[dict]:
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
|
@ -576,7 +576,7 @@ class MBMPO(Trainer):
|
|||
return train_op
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
@override(Algorithm)
|
||||
def validate_env(env: EnvType, env_context: EnvContext) -> None:
|
||||
"""Validates the local_worker's env object (after creation).
|
||||
|
||||
|
|
rllib/algorithms/mock.py (new file, 158 lines)

@@ -0,0 +1,158 @@
import os
import pickle
import numpy as np

from ray.tune import result as tune_result
from ray.rllib.algorithms.algorithm import Algorithm, with_common_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AlgorithmConfigDict


class _MockTrainer(Algorithm):
    """Mock trainer for use in tests"""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return with_common_config(
            {
                "mock_error": False,
                "persistent_error": False,
                "test_variable": 1,
                "num_workers": 0,
                "user_checkpoint_freq": 0,
                "framework": "tf",
            }
        )

    @classmethod
    def default_resource_request(cls, config):
        return None

    @override(Algorithm)
    def setup(self, config):
        # Setup our config: Merge the user-supplied config (which could
        # be a partial config dict with the class' default).
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs
        )
        self.config["env"] = self._env_id

        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()

        # Add needed properties.
        self.info = None
        self.restored = False

    @override(Algorithm)
    def step(self):
        if (
            self.config["mock_error"]
            and self.iteration == 1
            and (self.config["persistent_error"] or not self.restored)
        ):
            raise Exception("mock error")
        result = dict(
            episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}
        )
        if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
            if self.iteration % self.config["user_checkpoint_freq"] == 0:
                result.update({tune_result.SHOULD_CHECKPOINT: True})
        return result

    @override(Algorithm)
    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.info, f)
        return path

    @override(Algorithm)
    def load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            info = pickle.load(f)
        self.info = info
        self.restored = True

    @staticmethod
    @override(Algorithm)
    def _get_env_id_and_creator(env_specifier, config):
        # No env to register.
        return None, None

    def set_info(self, info):
        self.info = info
        return info

    def get_info(self, sess=None):
        return self.info


class _SigmoidFakeData(_MockTrainer):
    """Trainer that returns sigmoid learning curves.

    This can be helpful for evaluating early stopping algorithms."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return with_common_config(
            {
                "width": 100,
                "height": 100,
                "offset": 0,
                "iter_time": 10,
                "iter_timesteps": 1,
                "num_workers": 0,
            }
        )

    def step(self):
        i = max(0, self.iteration - self.config["offset"])
        v = np.tanh(float(i) / self.config["width"])
        v *= self.config["height"]
        return dict(
            episode_reward_mean=v,
            episode_len_mean=v,
            timesteps_this_iter=self.config["iter_timesteps"],
            time_this_iter_s=self.config["iter_time"],
            info={},
        )


class _ParameterTuningTrainer(_MockTrainer):
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return with_common_config(
            {
                "reward_amt": 10,
                "dummy_param": 10,
                "dummy_param2": 15,
                "iter_time": 10,
                "iter_timesteps": 1,
                "num_workers": 0,
            }
        )

    def step(self):
        return dict(
            episode_reward_mean=self.config["reward_amt"] * self.iteration,
            episode_len_mean=self.config["reward_amt"],
            timesteps_this_iter=self.config["iter_timesteps"],
            time_this_iter_s=self.config["iter_time"],
            info={},
        )


def _algorithm_import_failed(trace):
    """Returns dummy Algorithm class for if PyTorch etc. is not installed."""

    class _TrainerImportFailed(Algorithm):
        _name = "TrainerImportFailed"

        def setup(self, config):
            raise ImportError(trace)

    return _TrainerImportFailed
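A side note on _SigmoidFakeData: the "sigmoid learning curve" it fakes is just the tanh expression in its step() method. The standalone snippet below is not part of the diff; it reproduces that arithmetic with the default width/height/offset values, so you can see the reward values such a mock would report without starting Ray.

import numpy as np

# Defaults taken from _SigmoidFakeData.get_default_config() above.
width, height, offset = 100, 100, 0

for iteration in range(0, 600, 100):
    i = max(0, iteration - offset)
    reward = np.tanh(float(i) / width) * height
    print(f"iteration={iteration:3d}  episode_reward_mean={reward:6.2f}")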
rllib/algorithms/pg/pg.py

@@ -1,21 +1,21 @@
 from typing import Type

-from ray.rllib.agents.trainer import Trainer
-from ray.rllib.agents.trainer_config import TrainerConfig
+from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import Deprecated
-from ray.rllib.utils.typing import TrainerConfigDict
+from ray.rllib.utils.typing import AlgorithmConfigDict


-class PGConfig(TrainerConfig):
-    """Defines a configuration class from which a PG Trainer can be built.
+class PGConfig(AlgorithmConfig):
+    """Defines a configuration class from which a PG Algorithm can be built.

     Example:
         >>> from ray.rllib.algorithms.pg import PGConfig
         >>> config = PGConfig().training(lr=0.01).resources(num_gpus=1)
         >>> print(config.to_dict())
-        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> # Build an Algorithm object from the config and run 1 training iteration.
         >>> trainer = config.build(env="CartPole-v1")
         >>> trainer.train()
@@ -41,11 +41,11 @@ class PGConfig(TrainerConfig):

     def __init__(self):
         """Initializes a PGConfig instance."""
-        super().__init__(trainer_class=PG)
+        super().__init__(algo_class=PG)

         # fmt: off
         # __sphinx_doc_begin__
-        # Override some of TrainerConfig's default values with PG-specific values.
+        # Override some of AlgorithmConfig's default values with PG-specific values.
         self.num_workers = 0
         self.lr = 0.0004
         self._disable_preprocessor_api = True
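For custom config classes the same rename applies to the constructor argument: trainer_class becomes algo_class. The sketch below is hypothetical (MyPGConfig is not in the diff) and simply mirrors what PGConfig.__init__ does in the hunk above.

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.pg import PG


class MyPGConfig(AlgorithmConfig):
    def __init__(self, algo_class=None):
        # Pass the Algorithm class via the renamed `algo_class` argument.
        super().__init__(algo_class=algo_class or PG)
        # Override a couple of defaults, mirroring PGConfig.
        self.num_workers = 0
        self.lr = 0.0004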
@@ -53,7 +53,7 @@ class PGConfig(TrainerConfig):
         # fmt: on


-class PG(Trainer):
+class PG(Algorithm):
     """Policy Gradient (PG) Trainer.

     Defines the distributed Trainer class for policy gradients.

@@ -69,11 +69,11 @@ class PG(Trainer):
     """

     @classmethod
-    @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
+    @override(Algorithm)
+    def get_default_config(cls) -> AlgorithmConfigDict:
         return PGConfig().to_dict()

-    @override(Trainer)
+    @override(Algorithm)
     def get_default_policy_class(self, config) -> Type[Policy]:
         if config["framework"] == "torch":
             from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy
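The PGConfig docstring above already shows the intended usage; the sketch below restates it as plain code, assuming a local run (num_gpus=0 instead of the docstring's num_gpus=1). After this commit, config.build() returns an Algorithm instance.

from ray.rllib.algorithms.pg import PGConfig

config = PGConfig().training(lr=0.01).resources(num_gpus=0)
algo = config.build(env="CartPole-v1")  # returns a PG Algorithm instance
result = algo.train()
print(result.get("episode_reward_mean"))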
rllib/algorithms/ppo/ppo.py

@@ -2,7 +2,7 @@
 Proximal Policy Optimization (PPO)
 ==================================

-This file defines the distributed Trainer class for proximal policy
+This file defines the distributed Algorithm class for proximal policy
 optimization.
 See `ppo_[tf|torch]_policy.py` for the definition of the policy loss.
@@ -14,8 +14,8 @@ from typing import List, Optional, Type, Union
 import math

 from ray.util.debug import log_once
-from ray.rllib.agents.trainer import Trainer
-from ray.rllib.agents.trainer_config import TrainerConfig
+from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.execution.rollout_ops import (
     standardize_fields,
 )
@@ -27,9 +27,13 @@ from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.deprecation import Deprecated
+from ray.rllib.utils.deprecation import (
+    Deprecated,
+    DEPRECATED_VALUE,
+    deprecation_warning,
+)
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
-from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
+from ray.rllib.utils.typing import AlgorithmConfigDict, ResultDict
 from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
 from ray.rllib.utils.metrics import (
     NUM_AGENT_STEPS_SAMPLED,
@@ -40,8 +44,8 @@ from ray.rllib.utils.metrics import (
 logger = logging.getLogger(__name__)


-class PPOConfig(TrainerConfig):
-    """Defines a configuration class from which a PPO Trainer can be built.
+class PPOConfig(AlgorithmConfig):
+    """Defines a configuration class from which a PPO Algorithm can be built.

     Example:
         >>> from ray.rllib.algorithms.ppo import PPOConfig
@@ -49,7 +53,7 @@ class PPOConfig(TrainerConfig):
         ...     .resources(num_gpus=0)\
         ...     .rollouts(num_workers=4)
         >>> print(config.to_dict())
-        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> # Build an Algorithm object from the config and run 1 training iteration.
         >>> trainer = config.build(env="CartPole-v1")
         >>> trainer.train()
@@ -72,9 +76,9 @@ class PPOConfig(TrainerConfig):
         ...     )
     """

-    def __init__(self, trainer_class=None):
+    def __init__(self, algo_class=None):
         """Initializes a PPOConfig instance."""
-        super().__init__(trainer_class=trainer_class or PPO)
+        super().__init__(algo_class=algo_class or PPO)

         # fmt: off
         # __sphinx_doc_begin__
@@ -95,7 +99,7 @@ class PPOConfig(TrainerConfig):
         self.grad_clip = None
         self.kl_target = 0.01

-        # Override some of TrainerConfig's default values with PPO-specific values.
+        # Override some of AlgorithmConfig's default values with PPO-specific values.
         self.rollout_fragment_length = 200
         self.train_batch_size = 4000
         self.lr = 5e-5
@@ -103,7 +107,10 @@ class PPOConfig(TrainerConfig):
         # __sphinx_doc_end__
         # fmt: on

-    @override(TrainerConfig)
+        # Deprecated keys.
+        self.vf_share_layers = DEPRECATED_VALUE
+
+    @override(AlgorithmConfig)
     def training(
         self,
         *,
@@ -122,6 +129,8 @@ class PPOConfig(TrainerConfig):
         vf_clip_param: Optional[float] = None,
         grad_clip: Optional[float] = None,
         kl_target: Optional[float] = None,
+        # Deprecated.
+        vf_share_layers=None,
         **kwargs,
     ) -> "PPOConfig":
         """Sets the training related configuration.
@@ -155,7 +164,7 @@ class PPOConfig(TrainerConfig):
             kl_target: Target value for KL divergence.

         Returns:
-            This updated TrainerConfig object.
+            This updated AlgorithmConfig object.
         """
         # Pass kwargs onto super's `training()` method.
         super().training(**kwargs)
@@ -191,6 +200,14 @@ class PPOConfig(TrainerConfig):
         if kl_target is not None:
             self.kl_target = kl_target

+        if vf_share_layers is not None:
+            self.model["vf_share_layers"] = vf_share_layers
+            deprecation_warning(
+                old="ppo.DEFAULT_CONFIG['vf_share_layers']",
+                new="PPOConfig().training(model={'vf_share_layers': ...})",
+                error=False,
+            )
+
         return self

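The deprecation_warning added above points users from the old top-level vf_share_layers key to the model sub-dict. The sketch below is not part of the diff and only uses the call named in the warning message itself; it assumes the base training() method accepts a partial model dict.

from ray.rllib.algorithms.ppo import PPOConfig

# Old (deprecated): ppo.DEFAULT_CONFIG["vf_share_layers"] = True
# New: route the setting through the model dict on the config object.
config = PPOConfig().training(model={"vf_share_layers": True})
print(config.to_dict()["model"])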
@@ -268,20 +285,20 @@ def warn_about_bad_reward_scales(config, result):
     return result


-class PPO(Trainer):
-    # TODO: Change the return value of this method to return a TrainerConfig object
+class PPO(Algorithm):
+    # TODO: Change the return value of this method to return an AlgorithmConfig object
     # instead.
     @classmethod
-    @override(Trainer)
-    def get_default_config(cls) -> TrainerConfigDict:
+    @override(Algorithm)
+    def get_default_config(cls) -> AlgorithmConfigDict:
         return PPOConfig().to_dict()

-    @override(Trainer)
-    def validate_config(self, config: TrainerConfigDict) -> None:
-        """Validates the Trainer's config dict.
+    @override(Algorithm)
+    def validate_config(self, config: AlgorithmConfigDict) -> None:
+        """Validates the Algorithm's config dict.

         Args:
-            config: The Trainer's config to check.
+            config: The Algorithm's config to check.

         Raises:
             ValueError: In case something is wrong with the config.
@@ -364,8 +381,8 @@ class PPO(Trainer):
                 "simple_optimizer=True if this doesn't work for you."
             )

-    @override(Trainer)
-    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
+    @override(Algorithm)
+    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
         if config["framework"] == "torch":
             from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
Some files were not shown because too many files have changed in this diff.