[rllib] Rename Agent to Trainer (#4556)

Eric Liang 2019-04-07 00:36:18 -07:00 committed by GitHub
parent 820c71b7d0
commit 37208216ae
63 changed files with 1212 additions and 1092 deletions


@ -93,8 +93,8 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin
rllib.rst
rllib-training.rst
rllib-env.rst
rllib-algorithms.rst
rllib-models.rst
rllib-algorithms.rst
rllib-offline.rst
rllib-dev.rst
rllib-concepts.rst

(Diffs for two SVG images suppressed: one grew from 72 KiB to 76 KiB, the other stayed at 75 KiB.)


@ -31,13 +31,13 @@ Contributing Algorithms
These are the guidelines for merging new algorithms into RLlib:
* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__):
- must subclass Agent and implement the ``_train()`` method
- must subclass Trainer and implement the ``_train()`` method
- must include a lightweight test (`example <https://github.com/ray-project/ray/blob/6bb110393008c9800177490688c6ed38b2da52a9/test/jenkins_tests/run_multi_node_tests.sh#L45>`__) to ensure the algorithm runs
- should include tuned hyperparameter examples and documentation
- should offer functionality not present in existing algorithms
* Fully integrated algorithms (`rllib/agents <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents>`__) have the following additional requirements:
- must fully implement the Agent API
- must fully implement the Trainer API
- must offer substantial new functionality not possible to add to other algorithms
- should support custom models and preprocessors
- should use RLlib abstractions and support distributed execution
@ -46,14 +46,14 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a
How to add an algorithm to ``contrib``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Agent <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Trainer <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:
.. literalinclude:: ../../python/ray/rllib/contrib/random_agent/random_agent.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
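For reference, a minimal trainer following this pattern could look like the sketch below (illustrative only; the class name, config key, and reward bookkeeping are assumptions, not the actual ``random_agent.py``):

.. code-block:: python

    from ray.rllib.agents.trainer import Trainer, with_common_config
    from ray.rllib.utils.annotations import override

    DEFAULT_CONFIG = with_common_config({
        "rollouts_per_iteration": 10,  # illustrative algorithm-specific option
    })


    class MyRandomTrainer(Trainer):
        """Toy trainer that takes random actions (sketch only)."""

        _name = "MyRandomTrainer"
        _default_config = DEFAULT_CONFIG

        @override(Trainer)
        def _init(self, config, env_creator):
            self.env = env_creator(config["env_config"])

        @override(Trainer)
        def _train(self):
            returns = []
            for _ in range(self.config["rollouts_per_iteration"]):
                obs, done, total = self.env.reset(), False, 0.0
                while not done:
                    obs, rew, done, _ = self.env.step(
                        self.env.action_space.sample())
                    total += rew
                returns.append(total)
            return {
                "episode_reward_mean": sum(returns) / len(returns),
                "episodes_this_iter": len(returns),
            }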
Second, register the agent with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/contrib/registry.py>`__.
Second, register the trainer with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/contrib/registry.py>`__.
.. code-block:: python
@ -66,12 +66,12 @@ Second, register the agent with a name in `contrib/registry.py <https://github.c
return RandomAgent2
CONTRIBUTED_ALGORITHMS = {
"contrib/RandomAgent": _import_random_agent,
"contrib/RandomAgent2": _import_random_agent_2,
"contrib/RandomAgent": _import_random_trainer,
"contrib/RandomAgent2": _import_random_trainer_2,
# ...
}
After registration, you can run and visualize agent progress using ``rllib train``:
After registration, you can run and visualize training progress using ``rllib train``:
.. code-block:: bash


@ -26,7 +26,7 @@ MARWIL **Yes** `+parametric`_ **Yes** **Yes** **Yes**
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__. Custom env classes passed directly to the agent must take a single ``env_config`` parameter in their constructor:
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:
.. code-block:: python
@ -43,7 +43,7 @@ You can pass either a string name or a Python class to specify an environment. B
return <obs>, <reward: float>, <done: bool>, <info: dict>
ray.init()
trainer = ppo.PPOAgent(env=MyEnv, config={
trainer = ppo.PPOTrainer(env=MyEnv, config={
"env_config": {}, # config to pass to env class
})
@ -60,7 +60,7 @@ You can also register a custom env creator function with a string name. This fun
return MyEnv(...) # return an env instance
register_env("my_env", env_creator)
trainer = ppo.PPOAgent(env="my_env")
trainer = ppo.PPOTrainer(env="my_env")
For a full runnable code example using the custom environment API, see `custom_env.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__.
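In the same spirit, a condensed, self-contained sketch (the corridor environment below is an illustration, not the example from ``custom_env.py``):

.. code-block:: python

    import gym
    import numpy as np
    from gym.spaces import Box, Discrete

    import ray
    from ray.rllib.agents import ppo


    class MyEnv(gym.Env):
        """Walk right to the end of a corridor (illustrative env)."""

        def __init__(self, env_config):
            self.end_pos = env_config.get("corridor_length", 5)
            self.cur_pos = 0
            self.action_space = Discrete(2)
            self.observation_space = Box(
                0.0, float(self.end_pos), shape=(1,), dtype=np.float32)

        def reset(self):
            self.cur_pos = 0
            return np.array([self.cur_pos], dtype=np.float32)

        def step(self, action):
            if action == 0 and self.cur_pos > 0:
                self.cur_pos -= 1
            elif action == 1:
                self.cur_pos += 1
            done = self.cur_pos >= self.end_pos
            reward = 1.0 if done else -0.1
            return np.array([self.cur_pos], dtype=np.float32), reward, done, {}


    ray.init()
    trainer = ppo.PPOTrainer(env=MyEnv, config={
        "env_config": {"corridor_length": 5},  # passed to MyEnv.__init__
    })
    print(trainer.train())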
@ -71,7 +71,7 @@ For a full runnable code example using the custom environment API, see `custom_e
Configuring Environments
------------------------
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your agent. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
.. code-block:: python
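    # Sketch only (illustrative; the rendered docs show their own example):
    # pick a different environment variant per rollout worker by reading
    # env_config.worker_index inside the env creator.
    import gym

    ENV_NAMES = ["CartPole-v0", "CartPole-v1"]  # hypothetical ensemble

    def env_creator(env_config):
        # worker_index is 0 on the driver and 1..num_workers on remote workers
        name = ENV_NAMES[env_config.worker_index % len(ENV_NAMES)]
        return gym.make(name)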
@ -128,7 +128,7 @@ Multi-Agent and Hierarchical
.. note::
Learn more about multi-agent reinforcement learning in RLlib by reading the `blog post <https://bair.berkeley.edu/blog/2018/12/12/rllib/>`__.
Learn more about multi-agent reinforcement learning in RLlib by checking out some of the `code examples <rllib-examples.html#multi-agent-and-hierarchical>`__ or reading the `blog post <https://bair.berkeley.edu/blog/2018/12/12/rllib/>`__.
A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib is as follows: (1) as a user you define the number of policies available up front, and (2) you provide a function that maps agent ids to policy ids. This is summarized by the figure below:
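To make (1) and (2) concrete, both are supplied through the ``multiagent`` section of the trainer config. A sketch (the env name, spaces, and agent id scheme below are placeholders):

.. code-block:: python

    from gym.spaces import Box, Discrete

    from ray.rllib.agents.ppo import PPOTrainer

    # Placeholder spaces; a real multi-agent env defines its own per agent.
    obs_space = Box(-1.0, 1.0, shape=(4,))
    act_space = Discrete(2)

    trainer = PPOTrainer(env="my_multiagent_env", config={
        "multiagent": {
            # (1) the policies available up front; None means "use this
            # trainer's default policy graph"
            "policy_graphs": {
                "car": (None, obs_space, act_space, {}),
                "traffic_light": (None, obs_space, act_space, {}),
            },
            # (2) map each agent id to one of the policy ids above
            "policy_mapping_fn": lambda agent_id: (
                "traffic_light" if agent_id.startswith("light_") else "car"),
        },
    })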


@ -127,7 +127,7 @@ Custom TF models should subclass the common RLlib `model class <https://github.c
ModelCatalog.register_custom_model("my_model", MyModelClass)
ray.init()
agent = ppo.PPOAgent(env="CartPole-v0", config={
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
"model": {
"custom_model": "my_model",
"custom_options": {}, # extra options to pass to your model
@ -220,7 +220,7 @@ Similarly, you can create and register custom PyTorch models for use with PyTorc
ModelCatalog.register_custom_model("my_model", CustomTorchModel)
ray.init()
agent = a3c.A2CAgent(env="CartPole-v0", config={
trainer = a3c.A2CTrainer(env="CartPole-v0", config={
"use_pytorch": True,
"model": {
"custom_model": "my_model",
@ -249,7 +249,7 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
ray.init()
agent = ppo.PPOAgent(env="CartPole-v0", config={
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
"model": {
"custom_preprocessor": "my_prep",
"custom_options": {}, # extra options to pass to your preprocessor
@ -315,7 +315,7 @@ Depending on your use case it may make sense to use just the masking, just actio
Customizing Policy Graphs
-------------------------
For deeper customization of algorithms, you can modify the policy graphs of the agent classes. Here's an example of extending the DDPG policy graph to specify custom sub-network modules:
For deeper customization of algorithms, you can modify the policy graphs of the trainer classes. Here's an example of extending the DDPG policy graph to specify custom sub-network modules:
.. code-block:: python
@ -349,15 +349,15 @@ For deeper customization of algorithms, you can modify the policy graphs of the
self.config["critic_hiddens"],
self.config["critic_hidden_activation"]).value
Then, you can create an agent with your custom policy graph by:
Then, you can create a trainer with your custom policy graph by:
.. code-block:: python
from ray.rllib.agents.ddpg.ddpg import DDPGAgent
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer
from custom_policy_graph import CustomDDPGPolicyGraph
DDPGAgent._policy_graph = CustomDDPGPolicyGraph
agent = DDPGAgent(...)
DDPGTrainer._policy_graph = CustomDDPGPolicyGraph
trainer = DDPGTrainer(...)
In this example we overrode existing methods of the DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, and `_build_actor_critic_loss`, but you could also replace the graph class entirely.


@ -69,13 +69,13 @@ This example plot shows the Q-value metric in addition to importance sampling (I
.. code-block:: python
agent = DQNAgent(...)
... # train agent offline
trainer = DQNTrainer(...)
... # train policy offline
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
estimator = WeightedImportanceSamplingEstimator(agent.get_policy(), gamma=0.99)
estimator = WeightedImportanceSamplingEstimator(trainer.get_policy(), gamma=0.99)
reader = JsonReader("/path/to/data")
for _ in range(1000):
batch = reader.next()
@ -155,7 +155,7 @@ Input API
You can configure experience input for an agent using the following options:
.. literalinclude:: ../../python/ray/rllib/agents/agent.py
.. literalinclude:: ../../python/ray/rllib/agents/trainer.py
:language: python
:start-after: === Offline Datasets ===
:end-before: Specify where experiences should be saved
@ -170,7 +170,7 @@ Output API
You can configure experience output for an agent using the following options:
.. literalinclude:: ../../python/ray/rllib/agents/agent.py
.. literalinclude:: ../../python/ray/rllib/agents/trainer.py
:language: python
:start-after: shuffle_buffer_size
:end-before: === Multiagent ===
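As a quick illustration of these input and output options (a sketch; the dataset path is a placeholder):

.. code-block:: python

    from ray.rllib.agents.dqn import DQNTrainer

    # Train from previously saved experiences, scoring the policy with the
    # "is"/"wis" estimators described above.
    trainer = DQNTrainer(
        env="CartPole-v0",
        config={
            "input": "/tmp/cartpole-out",       # placeholder path
            "input_evaluation": ["is", "wis"],
        })

    # Or: simulate as usual, but also save experiences to the trainer log dir.
    trainer = DQNTrainer(env="CartPole-v0", config={"output": "logdir"})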


@ -7,16 +7,19 @@ ray.rllib.agents
.. automodule:: ray.rllib.agents
:members:
.. autoclass:: ray.rllib.agents.a3c.A2CAgent
.. autoclass:: ray.rllib.agents.a3c.A3CAgent
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGAgent
.. autoclass:: ray.rllib.agents.ddpg.DDPGAgent
.. autoclass:: ray.rllib.agents.dqn.ApexAgent
.. autoclass:: ray.rllib.agents.dqn.DQNAgent
.. autoclass:: ray.rllib.agents.es.ESAgent
.. autoclass:: ray.rllib.agents.pg.PGAgent
.. autoclass:: ray.rllib.agents.impala.ImpalaAgent
.. autoclass:: ray.rllib.agents.ppo.PPOAgent
.. autoclass:: ray.rllib.agents.a3c.A2CTrainer
.. autoclass:: ray.rllib.agents.a3c.A3CTrainer
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGTrainer
.. autoclass:: ray.rllib.agents.ddpg.DDPGTrainer
.. autoclass:: ray.rllib.agents.dqn.ApexTrainer
.. autoclass:: ray.rllib.agents.dqn.DQNTrainer
.. autoclass:: ray.rllib.agents.es.ESTrainer
.. autoclass:: ray.rllib.agents.pg.PGTrainer
.. autoclass:: ray.rllib.agents.impala.ImpalaTrainer
.. autoclass:: ray.rllib.agents.ppo.APPOTrainer
.. autoclass:: ray.rllib.agents.ppo.PPOTrainer
.. autoclass:: ray.rllib.agents.marwil.MARWILTrainer
ray.rllib.env
-------------


@ -4,13 +4,13 @@ RLlib Training APIs
Getting Started
---------------
At a high level, RLlib provides an ``Agent`` class which
holds a policy for environment interaction. Through the agent interface, the policy can
be trained, checkpointed, or an action computed.
At a high level, RLlib provides a ``Trainer`` class which
holds a policy for environment interaction. Through the trainer interface, the policy can
be trained, checkpointed, or used to compute actions. In multi-agent training, the trainer manages the querying and optimization of multiple policies at once.
.. image:: rllib-api.svg
You can train a simple DQN agent with the following command:
You can train a simple DQN policy with the following command:
.. code-block:: bash
@ -39,15 +39,15 @@ with ``--env`` (any OpenAI gym environment including ones registered by the user
can be used) and for choosing the algorithm with ``--run``
(available options are ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
Evaluating Trained Agents
~~~~~~~~~~~~~~~~~~~~~~~~~
Evaluating Trained Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~
In order to save checkpoints from which to evaluate agents,
In order to save checkpoints from which to evaluate policies,
set ``--checkpoint-freq`` (number of training iterations between checkpoints)
when running ``rllib train``.
An example of evaluating a previously trained DQN agent is as follows:
An example of evaluating a previously trained DQN policy is as follows:
.. code-block:: bash
@ -55,7 +55,7 @@ An example of evaluating a previously trained DQN agent is as follows:
~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
--run DQN --env CartPole-v0 --steps 10000
The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint
The ``rollout.py`` helper script reconstructs a DQN policy from the checkpoint
located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1``
and renders its behavior in the environment specified by ``--env``.
@ -65,7 +65,7 @@ Configuration
Specifying Parameters
~~~~~~~~~~~~~~~~~~~~~
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/agent.py>`__. See the
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/trainer.py>`__. See the
`algorithms documentation <rllib-algorithms.html>`__ for more information.
In an example below, we train A2C by specifying 8 workers through the config flag.
@ -77,16 +77,16 @@ In an example below, we train A2C by specifying 8 workers through the config fla
Specifying Resources
~~~~~~~~~~~~~~~~~~~~
You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``.
You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most algorithms. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting ``num_gpus: 0.2``.
.. image:: rllib-config.svg
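For example, a sketch combining the options above:

.. code-block:: python

    from ray.rllib.agents.dqn import DQNTrainer

    trainer = DQNTrainer(env="CartPole-v0", config={
        "num_workers": 4,          # parallel rollout workers (Ray actors)
        "num_gpus": 0.2,           # fractional GPU for the driver process
        "num_cpus_per_worker": 1,
        "num_gpus_per_worker": 0,
    })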
Common Parameters
~~~~~~~~~~~~~~~~~
The following is a list of the common agent hyperparameters:
The following is a list of the common algorithm hyperparameters:
.. literalinclude:: ../../python/ray/rllib/agents/agent.py
.. literalinclude:: ../../python/ray/rllib/agents/trainer.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
@ -122,25 +122,25 @@ Here is an example of the basic usage (for a more complete example, see `custom_
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 1
agent = ppo.PPOAgent(config=config, env="CartPole-v0")
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
# Can optionally call agent.restore(path) to load a checkpoint.
# Can optionally call trainer.restore(path) to load a checkpoint.
for i in range(1000):
# Perform one iteration of training the policy with PPO
result = agent.train()
result = trainer.train()
print(pretty_print(result))
if i % 100 == 0:
checkpoint = agent.save()
checkpoint = trainer.save()
print("checkpoint saved at", checkpoint)
.. note::
It's recommended that you run RLlib agents with `Tune <tune.html>`__, for easy experiment management and visualization of results. Just set ``"run": AGENT_NAME, "env": ENV_NAME`` in the experiment config.
It's recommended that you run RLlib trainers with `Tune <tune.html>`__, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.
All RLlib agents are compatible with the `Tune API <tune-usage.html>`__. This enables them to be easily used in experiments with `Tune <tune.html>`__. For example, the following code performs a simple hyperparam sweep of PPO:
All RLlib trainers are compatible with the `Tune API <tune-usage.html>`__. This enables them to be easily used in experiments with `Tune <tune.html>`__. For example, the following code performs a simple hyperparam sweep of PPO:
.. code-block:: python
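    # Hedged sketch of such a sweep; the stopping criterion and learning
    # rates below are illustrative, not the values from the rendered docs.
    import ray
    from ray import tune

    ray.init()
    tune.run(
        "PPO",
        stop={"episode_reward_mean": 200},
        config={
            "env": "CartPole-v0",
            "num_workers": 1,
            "lr": tune.grid_search([0.01, 0.001, 0.0001]),
        },
    )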
@ -176,27 +176,27 @@ Tune will schedule the trials to run in parallel on your Ray cluster:
Custom Training Workflows
~~~~~~~~~~~~~~~~~~~~~~~~~
In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your agent once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__.
In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__.
Accessing Policy State
~~~~~~~~~~~~~~~~~~~~~~
It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list.
It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib, trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions, and those will be returned as a list.
You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default_policy"].get_weights()``:
You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.local_evaluator.policy_map["default_policy"].get_weights()``:
.. code-block:: python
# Get weights of the default local policy
agent.get_policy().get_weights()
trainer.get_policy().get_weights()
# Same as above
agent.local_evaluator.policy_map["default_policy"].get_weights()
trainer.local_evaluator.policy_map["default_policy"].get_weights()
# Get list of weights of each evaluator, including remote replicas
agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights())
trainer.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights())
# Same as above
agent.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights())
trainer.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights())
Global Coordination
~~~~~~~~~~~~~~~~~~~
@ -252,8 +252,8 @@ You can provide callback functions to be called at points during policy evaluati
episode.custom_metrics["pole_angle"] = pole_angle
def on_train_result(info):
print("agent.train() result: {} -> {} episodes".format(
info["agent"].__name__, info["result"]["episodes_this_iter"]))
print("trainer.train() result: {} -> {} episodes".format(
info["trainer"].__name__, info["result"]["episodes_this_iter"]))
ray.init()
trials = tune.run(
@ -278,18 +278,18 @@ Example: Curriculum Learning
Let's look at two ways to use the above APIs to implement `curriculum learning <https://bair.berkeley.edu/blog/2017/12/20/reverse-curriculum/>`__. In curriculum learning, the agent task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time:
Approach 1: Use the Agent API and update the environment between calls to ``train()``. This example shows the agent being run inside a Tune function:
Approach 1: Use the Trainer API and update the environment between calls to ``train()``. This example shows the trainer being run inside a Tune function:
.. code-block:: python
import ray
from ray import tune
from ray.rllib.agents.ppo import PPOAgent
from ray.rllib.agents.ppo import PPOTrainer
def train(config, reporter):
agent = PPOAgent(config=config, env=YourEnv)
trainer = PPOTrainer(config=config, env=YourEnv)
while True:
result = agent.train()
result = trainer.train()
reporter(**result)
if result["episode_reward_mean"] > 200:
phase = 2
@ -297,7 +297,7 @@ Approach 1: Use the Agent API and update the environment between calls to ``trai
phase = 1
else:
phase = 0
agent.optimizer.foreach_evaluator(
trainer.optimizer.foreach_evaluator(
lambda ev: ev.foreach_env(
lambda env: env.set_phase(phase)))
@ -330,8 +330,8 @@ Approach 2: Use the callbacks API to update the environment on new training resu
phase = 1
else:
phase = 0
agent = info["agent"]
agent.optimizer.foreach_evaluator(
trainer = info["trainer"]
trainer.optimizer.foreach_evaluator(
lambda ev: ev.foreach_env(
lambda env: env.set_phase(phase)))
@ -382,7 +382,7 @@ You can use the `data output API <rllib-offline.html>`__ to save episode traces
Log Verbosity
~~~~~~~~~~~~~
You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example:
You can control the trainer log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example:
.. code-block:: bash
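The same flag can also be set programmatically in the trainer config (a sketch):

.. code-block:: python

    from ray.rllib.agents.ppo import PPOTrainer

    trainer = PPOTrainer(env="CartPole-v0", config={"log_level": "DEBUG"})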


@ -42,6 +42,16 @@ Environments
* `Interfacing with External Agents <rllib-env.html#interfacing-with-external-agents>`__
* `Advanced Integrations <rllib-env.html#advanced-integrations>`__
Models and Preprocessors
------------------------
* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
* `Custom Models (TensorFlow) <rllib-models.html#custom-models-tensorflow>`__
* `Custom Models (PyTorch) <rllib-models.html#custom-models-pytorch>`__
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Customizing Policy Graphs <rllib-models.html#customizing-policy-graphs>`__
Algorithms
----------
@ -79,16 +89,6 @@ Algorithms
- `Advantage Re-Weighted Imitation Learning (MARWIL) <rllib-algorithms.html#advantage-re-weighted-imitation-learning-marwil>`__
Models and Preprocessors
------------------------
* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
* `Custom Models (TensorFlow) <rllib-models.html#custom-models-tensorflow>`__
* `Custom Models (PyTorch) <rllib-models.html#custom-models-pytorch>`__
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Customizing Policy Graphs <rllib-models.html#customizing-policy-graphs>`__
Offline Datasets
----------------
* `Working with Offline Datasets <rllib-offline.html>`__


@ -1,3 +1,4 @@
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.agent import Agent
__all__ = ["Agent", "with_common_config"]
__all__ = ["Agent", "Trainer", "with_common_config"]


@ -1,4 +1,10 @@
from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG
from ray.rllib.agents.a3c.a2c import A2CAgent
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG
from ray.rllib.agents.a3c.a2c import A2CTrainer
from ray.rllib.utils import renamed_class
__all__ = ["A2CAgent", "A3CAgent", "DEFAULT_CONFIG"]
A2CAgent = renamed_class(A2CTrainer)
A3CAgent = renamed_class(A3CTrainer)
__all__ = [
"A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG"
]


@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
@ -17,13 +17,13 @@ A2C_DEFAULT_CONFIG = merge_dicts(
)
class A2CAgent(A3CAgent):
"""Synchronous variant of the A3CAgent."""
class A2CTrainer(A3CTrainer):
"""Synchronous variant of the A3CTrainer."""
_agent_name = "A2C"
_name = "A2C"
_default_config = A2C_DEFAULT_CONFIG
@override(A3CAgent)
@override(A3CTrainer)
def _make_optimizer(self):
return SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators,


@ -5,7 +5,7 @@ from __future__ import print_function
import time
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils.annotations import override
@ -38,14 +38,14 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class A3CAgent(Agent):
class A3CTrainer(Trainer):
"""A3C implementations in TensorFlow and PyTorch."""
_agent_name = "A3C"
_name = "A3C"
_default_config = DEFAULT_CONFIG
_policy_graph = A3CPolicyGraph
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
if config["use_pytorch"]:
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
@ -63,7 +63,7 @@ class A3CAgent(Agent):
env_creator, policy_cls, config["num_workers"])
self.optimizer = self._make_optimizer()
@override(Agent)
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
start = time.time()


@ -2,797 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import copy
import logging
import os
import pickle
import six
import tempfile
import tensorflow as tf
from types import FunctionType
from ray.rllib.agents.trainer import Trainer
from ray.rllib.utils import renamed_class
import ray
from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
logger = logging.getLogger(__name__)
# Max number of times to retry a worker failure. We shouldn't try too many
# times in a row since that would indicate a persistent cluster issue.
MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable
# __sphinx_doc_begin__
COMMON_CONFIG = {
# === Debugging ===
# Whether to write episode stats and videos to the agent log dir
"monitor": False,
# Set the ray.rllib.* log level for the agent process and its evaluators.
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
# periodically print out summaries of relevant internal dataflow (this is
# also printed out once at startup at the INFO level).
"log_level": "INFO",
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
# metrics can be attached to the episode by updating the episode object's
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
"callbacks": {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_train_result": None, # arg: {"agent": ..., "result": ...}
},
# Whether to attempt to continue training if a worker crashes.
"ignore_worker_failures": False,
# === Policy ===
# Arguments to pass to model. See models/catalog.py for a full list of the
# available model options.
"model": MODEL_DEFAULTS,
# Arguments to pass to the policy optimizer. These vary by optimizer.
"optimizer": {},
# === Environment ===
# Discount factor of the MDP
"gamma": 0.99,
# Number of steps after which the episode is forced to terminate. Defaults
# to `env.spec.max_episode_steps` (if present) for Gym envs.
"horizon": None,
# Calculate rewards but don't reset the environment when the horizon is
# hit. This allows value estimation and RNN state to span across logical
# episodes denoted by horizon. This only has an effect if horizon != inf.
"soft_horizon": False,
# Arguments to pass to the env creator
"env_config": {},
# Environment name can also be passed via config
"env": None,
# Whether to clip rewards prior to experience postprocessing. Setting to
# None means clip for Atari only.
"clip_rewards": None,
# Whether to np.clip() actions to the action space low/high range spec.
"clip_actions": True,
# Whether to use rllib or deepmind preprocessors by default
"preprocessor_pref": "deepmind",
# === Resources ===
# Number of actors used for parallelism
"num_workers": 2,
# Number of GPUs to allocate to the driver. Note that not all algorithms
# can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs).
"num_gpus": 0,
# Number of CPUs to allocate per worker.
"num_cpus_per_worker": 1,
# Number of GPUs to allocate per worker. This can be fractional.
"num_gpus_per_worker": 0,
# Any custom resources to allocate per worker.
"custom_resources_per_worker": {},
# Number of CPUs to allocate for the driver. Note: this only takes effect
# when running in Tune.
"num_cpus_for_driver": 1,
# === Execution ===
# Number of environments to evaluate vectorwise per worker.
"num_envs_per_worker": 1,
# Default sample batch size (unroll length). Batches of this size are
# collected from workers until train_batch_size is met. When using
# multiple envs per worker, this is multiplied by num_envs_per_worker.
"sample_batch_size": 200,
# Training batch size, if applicable. Should be >= sample_batch_size.
# Samples batches will be concatenated together to this size for training.
"train_batch_size": 200,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "truncate_episodes",
# (Deprecated) Use a background thread for sampling (slightly off-policy)
"sample_async": False,
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
"observation_filter": "NoFilter",
# Whether to synchronize the statistics of remote filters.
"synchronize_filters": True,
# Configure TF for single-process operation by default
"tf_session_args": {
# note: overriden by `local_evaluator_tf_session_args`
"intra_op_parallelism_threads": 2,
"inter_op_parallelism_threads": 2,
"gpu_options": {
"allow_growth": True,
},
"log_device_placement": False,
"device_count": {
"CPU": 1
},
"allow_soft_placement": True, # required by PPO multi-gpu
},
# Override the following tf session args on the local evaluator
"local_evaluator_tf_session_args": {
# Allow a higher level of parallelism by default, but not unlimited
# since that can cause crashes with many concurrent drivers.
"intra_op_parallelism_threads": 8,
"inter_op_parallelism_threads": 8,
},
# Whether to LZ4 compress individual observations
"compress_observations": False,
# Drop metric batches from unresponsive workers after this many seconds
"collect_metrics_timeout": 180,
# Smooth metrics over this many episodes.
"metrics_smoothing_episodes": 100,
# If using num_envs_per_worker > 1, whether to create those new envs in
# remote processes instead of in the same worker. This adds overheads, but
# can make sense if your envs can take much time to step / reset
# (e.g., for StarCraft). Use this cautiously; overheads are significant.
"remote_worker_envs": False,
# Timeout that remote workers are waiting when polling environments.
# 0 (continue when at least one env is ready) is a reasonable default,
# but optimal value could be obtained by measuring your environment
# step / reset and model inference perf.
"remote_env_batch_wait_ms": 0,
# === Offline Datasets ===
# Specify how to generate experiences:
# - "sampler": generate experiences via online simulation (default)
# - a local directory or file glob expression (e.g., "/tmp/*.json")
# - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
# "s3://bucket/2.json"])
# - a dict with string keys and sampling probabilities as values (e.g.,
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
# - a function that returns a rllib.offline.InputReader
"input": "sampler",
# Specify how to evaluate the current policy. This only has an effect when
# reading offline experiences. Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
# this data for evaluation only and not for learning.
"input_evaluation": ["is", "wis"],
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behaviour* policy, which is typically undesirable for
# on-policy algorithms.
"postprocess_inputs": False,
# If positive, input batches will be shuffled via a sliding window buffer
# of this number of batches. Use this if the input data is not in random
# enough order. Input is delayed until the shuffle buffer is filled.
"shuffle_buffer_size": 0,
# Specify where experiences should be saved:
# - None: don't save any experiences
# - "logdir" to save to the agent log dir
# - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
# - a function that returns a rllib.offline.OutputWriter
"output": None,
# What sample batch columns to LZ4 compress in the output data.
"output_compress_columns": ["obs", "new_obs"],
# Max output file size before rolling over to a new file.
"output_max_file_size": 64 * 1024 * 1024,
# === Multiagent ===
"multiagent": {
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
# act_space, config). See policy_evaluator.py for more info.
"policy_graphs": {},
# Function mapping agent ids to policy ids.
"policy_mapping_fn": None,
# Optional whitelist of policies to train, or None for all policies.
"policies_to_train": None,
},
}
# __sphinx_doc_end__
# yapf: enable
@DeveloperAPI
def with_common_config(extra_config):
"""Returns the given config dict merged with common agent confs."""
return with_base_config(COMMON_CONFIG, extra_config)
def with_base_config(base_config, extra_config):
"""Returns the given config dict merged with a base agent conf."""
config = copy.deepcopy(base_config)
config.update(extra_config)
return config
@PublicAPI
class Agent(Trainable):
"""All RLlib agents extend this base class.
Agent objects retain internal model state between calls to train(), so
you should create a new agent instance for each training session.
Attributes:
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
"""
_allow_unknown_configs = False
_allow_unknown_subkeys = [
"tf_session_args", "env_config", "model", "optimizer", "multiagent",
"custom_resources_per_worker"
]
@PublicAPI
def __init__(self, config=None, env=None, logger_creator=None):
"""Initialize an RLLib agent.
Args:
config (dict): Algorithm-specific configuration data.
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
config = config or {}
# Vars to synchronize to evaluators on each train call
self.global_vars = {"timestep": 0}
# Agents allow env ids to be passed directly to the constructor.
self._env_id = self._register_if_needed(env or config.get("env"))
# Create a default logger creator if no logger_creator is specified
if logger_creator is None:
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id,
timestr)
def default_logger_creator(config):
"""Creates a Unified logger with a default logdir prefix
containing the agent name and the env id
"""
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
logdir = tempfile.mkdtemp(
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
return UnifiedLogger(config, logdir, None)
logger_creator = default_logger_creator
Trainable.__init__(self, config, logger_creator)
@classmethod
@override(Trainable)
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
Agent._validate_config(cf)
# TODO(ekl): add custom resources here once tune supports them
return Resources(
cpu=cf["num_cpus_for_driver"],
gpu=cf["num_gpus"],
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
@override(Trainable)
@PublicAPI
def train(self):
"""Overrides super.train to synchronize global vars."""
if self._has_policy_optimizer():
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
for ev in self.optimizer.remote_evaluators:
ev.set_global_vars.remote(self.global_vars)
logger.debug("updated global vars: {}".format(self.global_vars))
result = None
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
try:
result = Trainable.train(self)
except RayError as e:
if self.config["ignore_worker_failures"]:
logger.exception(
"Error in train call, attempting to recover")
self._try_recover()
else:
logger.info(
"Worker crashed during call to train(). To attempt to "
"continue training without the failed worker, set "
"`'ignore_worker_failures': True`.")
raise e
else:
break
if result is None:
raise RuntimeError("Failed to recover from worker crash")
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
and hasattr(self, "local_evaluator")):
FilterManager.synchronize(
self.local_evaluator.filters,
self.remote_evaluators,
update_remote=self.config["synchronize_filters"])
logger.debug("synchronized filters: {}".format(
self.local_evaluator.filters))
if self._has_policy_optimizer():
result["num_healthy_workers"] = len(
self.optimizer.remote_evaluators)
return result
@override(Trainable)
def _log_result(self, result):
if self.config["callbacks"].get("on_train_result"):
self.config["callbacks"]["on_train_result"]({
"agent": self,
"result": result,
})
# log after the callback is invoked, so that the user has a chance
# to mutate the result
Trainable._log_result(self, result)
@override(Trainable)
def _setup(self, config):
env = self._env_id
if env:
config["env"] = env
if _global_registry.contains(ENV_CREATOR, env):
self.env_creator = _global_registry.get(ENV_CREATOR, env)
else:
import gym # soft dependency
self.env_creator = lambda env_config: gym.make(env)
else:
self.env_creator = lambda env_config: None
# Merge the supplied config with the class default
merged_config = copy.deepcopy(self._default_config)
merged_config = deep_update(merged_config, config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
self.raw_user_config = config
self.config = merged_config
Agent._validate_config(self.config)
if self.config.get("log_level"):
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
with tf.Graph().as_default():
self._init(self.config, self.env_creator)
@override(Trainable)
def _stop(self):
# Call stop on all evaluators to release resources
if hasattr(self, "local_evaluator"):
self.local_evaluator.stop()
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
ev.stop.remote()
# workaround for https://github.com/ray-project/ray/issues/1516
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
if hasattr(self, "optimizer"):
self.optimizer.stop()
@override(Trainable)
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
return checkpoint_path
@override(Trainable)
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path, "rb"))
self.__setstate__(extra_data)
@DeveloperAPI
def _init(self, config, env_creator):
"""Subclasses should override this for custom initialization."""
raise NotImplementedError
@PublicAPI
def compute_action(self,
observation,
state=None,
prev_action=None,
prev_reward=None,
info=None,
policy_id=DEFAULT_POLICY_ID):
"""Computes an action for the specified policy.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Arguments:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any. If state is not None,
then all of compute_single_action(...) is returned
(computed action, rnn state, logits dictionary).
Otherwise compute_single_action(...)[0] is
returned (computed action).
prev_action (obj): previous action value, if any
prev_reward (int): previous reward, if any
info (dict): info object, if any
policy_id (str): policy to query (only applies to multi-agent).
"""
if state is None:
state = []
preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
observation)
filtered_obs = self.local_evaluator.filters[policy_id](
preprocessed, update=False)
if state:
return self.get_policy(policy_id).compute_single_action(
filtered_obs,
state,
prev_action,
prev_reward,
info,
clip_actions=self.config["clip_actions"])
return self.get_policy(policy_id).compute_single_action(
filtered_obs,
state,
prev_action,
prev_reward,
info,
clip_actions=self.config["clip_actions"])[0]
@property
def iteration(self):
"""Current training iter, auto-incremented with each train() call."""
return self._iteration
@property
def _agent_name(self):
"""Subclasses should override this to declare their name."""
raise NotImplementedError
@property
def _default_config(self):
"""Subclasses should override this to declare their default config."""
raise NotImplementedError
@PublicAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
Arguments:
policy_id (str): id of policy graph to return.
"""
return self.local_evaluator.get_policy(policy_id)
@PublicAPI
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
Arguments:
policies (list): Optional list of policies to return weights for,
or None for all policies.
"""
return self.local_evaluator.get_weights(policies)
@PublicAPI
def set_weights(self, weights):
"""Set policy weights by policy id.
Arguments:
weights (dict): Map of policy ids to weights to set.
"""
self.local_evaluator.set_weights(weights)
@DeveloperAPI
def make_local_evaluator(self,
env_creator,
policy_graph,
extra_config=None):
"""Convenience method to return configured local evaluator."""
return self._make_evaluator(
PolicyEvaluator,
env_creator,
policy_graph,
0,
merge_dicts(
# important: allow local tf to use more CPUs for optimization
merge_dicts(
self.config, {
"tf_session_args": self.
config["local_evaluator_tf_session_args"]
}),
extra_config or {}))
@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
"""Convenience method to return a number of remote evaluators."""
remote_args = {
"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"],
"resources": self.config["custom_resources_per_worker"],
}
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
self.config) for i in range(count)
]
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
"""Export policy model with given policy_id to local directory.
Arguments:
export_dir (string): Writable local directory.
policy_id (string): Optional policy id to export.
Example:
>>> agent = MyAgent()
>>> for _ in range(10):
>>> agent.train()
>>> agent.export_policy_model("/tmp/export_dir")
"""
self.local_evaluator.export_policy_model(export_dir, policy_id)
@DeveloperAPI
def export_policy_checkpoint(self,
export_dir,
filename_prefix="model",
policy_id=DEFAULT_POLICY_ID):
"""Export tensorflow policy model checkpoint to local directory.
Arguments:
export_dir (string): Writable local directory.
filename_prefix (string): file name prefix of checkpoint files.
policy_id (string): Optional policy id to export.
Example:
>>> agent = MyAgent()
>>> for _ in range(10):
>>> agent.train()
>>> agent.export_policy_checkpoint("/tmp/export_dir")
"""
self.local_evaluator.export_policy_checkpoint(
export_dir, filename_prefix, policy_id)
@DeveloperAPI
def collect_metrics(self, selected_evaluators=None):
"""Collects metrics from the remote evaluators of this agent.
This is the same data as returned by a call to train().
"""
return self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"],
min_history=self.config["metrics_smoothing_episodes"],
selected_evaluators=selected_evaluators)
@classmethod
def resource_help(cls, config):
return ("\n\nYou can adjust the resource requests of RLlib agents by "
"setting `num_workers`, `num_gpus`, and other configs. See "
"the DEFAULT_CONFIG defined by each agent for more info.\n\n"
"The config of this agent is: {}".format(config))
@staticmethod
def _validate_config(config):
if "gpu" in config:
raise ValueError(
"The `gpu` config is deprecated, please use `num_gpus=0|1` "
"instead.")
if "gpu_fraction" in config:
raise ValueError(
"The `gpu_fraction` config is deprecated, please use "
"`num_gpus=<fraction>` instead.")
if "use_gpu_for_workers" in config:
raise ValueError(
"The `use_gpu_for_workers` config is deprecated, please use "
"`num_gpus_per_worker=1` instead.")
if type(config["input_evaluation"]) != list:
raise ValueError(
"`input_evaluation` must be a list of strings, got {}".format(
config["input_evaluation"]))
def _try_recover(self):
"""Try to identify and blacklist any unhealthy workers.
This method is called after an unexpected remote error is encountered
from a worker. It issues check requests to all current workers and
blacklists any that respond with error. If no healthy workers remain,
an error is raised.
"""
if not self._has_policy_optimizer():
raise NotImplementedError(
"Recovery is not supported for this algorithm")
logger.info("Health checking all workers...")
checks = []
for ev in self.optimizer.remote_evaluators:
_, obj_id = ev.sample_with_count.remote()
checks.append(obj_id)
healthy_evaluators = []
for i, obj_id in enumerate(checks):
ev = self.optimizer.remote_evaluators[i]
try:
ray.get(obj_id)
healthy_evaluators.append(ev)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
logger.exception("Blacklisting worker {}".format(i + 1))
try:
ev.__ray_terminate__.remote()
except Exception:
logger.exception("Error terminating unhealthy worker")
if len(healthy_evaluators) < 1:
raise RuntimeError(
"Not enough healthy workers remain to continue.")
self.optimizer.reset(healthy_evaluators)
def _has_policy_optimizer(self):
return hasattr(self, "optimizer") and isinstance(
self.optimizer, PolicyOptimizer)
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
return tf.Session(
config=tf.ConfigProto(**config["tf_session_args"]))
if isinstance(config["input"], FunctionType):
input_creator = config["input"]
elif config["input"] == "sampler":
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: ShuffledInput(
MixedInput(config["input"], ioctx), config[
"shuffle_buffer_size"]))
else:
input_creator = (lambda ioctx: ShuffledInput(
JsonReader(config["input"], ioctx), config[
"shuffle_buffer_size"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
elif config["output"] is None:
output_creator = (lambda ioctx: NoopOutput())
elif config["output"] == "logdir":
output_creator = (lambda ioctx: JsonWriter(
ioctx.log_dir,
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
else:
output_creator = (lambda ioctx: JsonWriter(
config["output"],
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
if config["input"] == "sampler":
input_evaluation = []
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy graph if 'None' is specified in multiagent
if self.config["multiagent"]["policy_graphs"]:
tmp = self.config["multiagent"]["policy_graphs"]
_validate_multiagent_config(tmp, allow_none_graph=True)
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy_graph, v[1], v[2], v[3])
policy_graph = tmp
return cls(
env_creator,
policy_graph,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
policies_to_train=self.config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
if config["tf_session_args"] else None),
batch_steps=config["sample_batch_size"],
batch_mode=config["batch_mode"],
episode_horizon=config["horizon"],
preprocessor_pref=config["preprocessor_pref"],
sample_async=config["sample_async"],
compress_observations=config["compress_observations"],
num_envs=config["num_envs_per_worker"],
observation_filter=config["observation_filter"],
clip_rewards=config["clip_rewards"],
clip_actions=config["clip_actions"],
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
monitor_path=self.logdir if config["monitor"] else None,
log_dir=self.logdir,
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
soft_horizon=config["soft_horizon"],
_fake_sampler=config.get("_fake_sampler", False))
@override(Trainable)
def _export_model(self, export_formats, export_dir):
ExportFormat.validate(export_formats)
exported = {}
if ExportFormat.CHECKPOINT in export_formats:
path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
self.export_policy_checkpoint(path)
exported[ExportFormat.CHECKPOINT] = path
if ExportFormat.MODEL in export_formats:
path = os.path.join(export_dir, ExportFormat.MODEL)
self.export_policy_model(path)
exported[ExportFormat.MODEL] = path
return exported
def __getstate__(self):
state = {}
if hasattr(self, "local_evaluator"):
state["evaluator"] = self.local_evaluator.save()
if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
state["optimizer"] = self.optimizer.save()
return state
def __setstate__(self, state):
if "evaluator" in state:
self.local_evaluator.restore(state["evaluator"])
remote_state = ray.put(state["evaluator"])
for r in self.remote_evaluators:
r.restore.remote(remote_state)
if "optimizer" in state:
self.optimizer.restore(state["optimizer"])
def _register_if_needed(self, env_object):
if isinstance(env_object, six.string_types):
return env_object
elif isinstance(env_object, type):
name = env_object.__name__
register_env(name, lambda config: env_object(config))
return name
raise ValueError(
"{} is an invalid env specification. ".format(env_object) +
"You can specify a custom env as either a class "
"(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")
Agent = renamed_class(Trainer)
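For reference, ``renamed_class`` (imported above from ``ray.rllib.utils``) is what keeps the old ``*Agent`` names importable. A minimal sketch of such a deprecation alias, assuming it only warns and then delegates to the new class (the actual helper may differ):

    import warnings

    def renamed_class(cls):
        """Return a drop-in alias for `cls` that warns about the old name."""

        class DeprecationWrapper(cls):
            def __init__(self, *args, **kwargs):
                old_name = cls.__name__.replace("Trainer", "Agent")
                warnings.warn(
                    "{} has been renamed to {}".format(old_name, cls.__name__),
                    DeprecationWarning)
                cls.__init__(self, *args, **kwargs)

        DeprecationWrapper.__name__ = cls.__name__.replace("Trainer", "Agent")
        return DeprecationWrapper

The same pattern gives ``DQNAgent = renamed_class(DQNTrainer)`` and the other aliases in the ``__init__.py`` diffs below.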


@ -1,3 +1,6 @@
from ray.rllib.agents.ars.ars import (ARSAgent, DEFAULT_CONFIG)
from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_class
__all__ = ["ARSAgent", "DEFAULT_CONFIG"]
ARSAgent = renamed_class(ARSTrainer)
__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"]


@ -12,7 +12,7 @@ import numpy as np
import time
import ray
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.ars import optimizers
from ray.rllib.agents.ars import policies
@ -157,13 +157,13 @@ class Worker(object):
eval_lengths=eval_lengths)
class ARSAgent(Agent):
class ARSTrainer(Trainer):
"""Large-scale implementation of Augmented Random Search in Ray."""
_agent_name = "ARS"
_name = "ARS"
_default_config = DEFAULT_CONFIG
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
env = env_creator(config["env_config"])
from ray.rllib import models
@ -195,7 +195,7 @@ class ARSAgent(Agent):
self.reward_list = []
self.tstart = time.time()
@override(Agent)
@override(Trainer)
def _train(self):
config = self.config
@ -291,13 +291,13 @@ class ARSAgent(Agent):
return result
@override(Agent)
@override(Trainer)
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self.workers:
w.__ray_terminate__.remote()
@override(Agent)
@override(Trainer)
def compute_action(self, observation):
return self.policy.compute(observation, update=True)[0]
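For orientation, a short usage sketch of the renamed class above; the CartPole env, the iteration count, and the `num_workers` override are illustrative choices, not part of this change:

import gym
import ray
from ray.rllib.agents.ars import ARSTrainer

ray.init()
trainer = ARSTrainer(env="CartPole-v0", config={"num_workers": 2})

for i in range(3):
    result = trainer.train()
    print("iter {}: reward {}".format(i, result["episode_reward_mean"]))

# compute_action() (overridden above) queries the trained policy directly.
obs = gym.make("CartPole-v0").reset()
print(trainer.compute_action(obs))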

View file

@ -2,7 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ddpg.apex import ApexDDPGAgent
from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG
from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
__all__ = ["DDPGAgent", "ApexDDPGAgent", "DEFAULT_CONFIG"]
ApexDDPGAgent = renamed_class(ApexDDPGTrainer)
DDPGAgent = renamed_class(DDPGTrainer)
__all__ = [
"DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer",
"DEFAULT_CONFIG"
]

View file

@ -2,7 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
@ -32,17 +33,17 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
)
class ApexDDPGAgent(DDPGAgent):
class ApexDDPGTrainer(DDPGTrainer):
"""DDPG variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_agent_name = "APEX_DDPG"
_name = "APEX_DDPG"
_default_config = APEX_DDPG_DEFAULT_CONFIG
@override(DDPGAgent)
@override(DDPGTrainer)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
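Since the docstring above points at `num_workers` as the main knob for scaling beyond a single node, a hedged config sketch; the address placeholder, env, and worker counts are examples only:

import ray
from ray.rllib.agents.ddpg import ApexDDPGTrainer

# Connect to an existing cluster; scale out by raising num_workers.
ray.init(redis_address="<head-node-ip>:6379")
trainer = ApexDDPGTrainer(
    env="Pendulum-v0",
    config={
        "num_workers": 32,  # default targets one 32-core machine
        "num_gpus": 1,      # learner GPU, if available
    })
result = trainer.train()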

View file

@ -2,8 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import with_common_config
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
@ -132,13 +132,13 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class DDPGAgent(DQNAgent):
class DDPGTrainer(DQNTrainer):
"""DDPG implementation in TensorFlow."""
_agent_name = "DDPG"
_name = "DDPG"
_default_config = DEFAULT_CONFIG
_policy_graph = DDPGPolicyGraph
@override(DQNAgent)
@override(DQNTrainer)
def _make_exploration_schedule(self, worker_index):
# Override DQN's schedule to take into account `noise_scale`
if self.config["per_worker_exploration"]:

View file

@ -2,7 +2,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.dqn.apex import ApexAgent
from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
__all__ = ["ApexAgent", "DQNAgent", "DEFAULT_CONFIG"]
DQNAgent = renamed_class(DQNTrainer)
ApexAgent = renamed_class(ApexTrainer)
__all__ = [
"DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG"
]

View file

@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
@ -36,17 +36,17 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# yapf: enable
class ApexAgent(DQNAgent):
class ApexTrainer(DQNTrainer):
"""DQN variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_agent_name = "APEX"
_name = "APEX"
_default_config = APEX_DEFAULT_CONFIG
@override(DQNAgent)
@override(DQNTrainer)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \

View file

@ -7,7 +7,7 @@ import time
from ray import tune
from ray.rllib import optimizers
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
@ -137,15 +137,15 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class DQNAgent(Agent):
class DQNTrainer(Trainer):
"""DQN implementation in TensorFlow."""
_agent_name = "DQN"
_name = "DQN"
_default_config = DEFAULT_CONFIG
_policy_graph = DQNPolicyGraph
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
self._validate_config()
@ -161,7 +161,7 @@ class DQNAgent(Agent):
]
for k in self._optimizer_shared_configs:
if self._agent_name != "DQN" and k in [
if self._name != "DQN" and k in [
"schedule_max_timesteps", "beta_annealing_fraction",
"final_prioritized_replay_beta"
]:
@ -238,7 +238,7 @@ class DQNAgent(Agent):
self.last_target_update_ts = 0
self.num_target_updates = 0
@override(Agent)
@override(Trainer)
def _train(self):
start_timestep = self.global_timestep
@ -326,7 +326,7 @@ class DQNAgent(Agent):
final_p=self.config["exploration_final_eps"])
def __getstate__(self):
state = Agent.__getstate__(self)
state = Trainer.__getstate__(self)
state.update({
"num_target_updates": self.num_target_updates,
"last_target_update_ts": self.last_target_update_ts,
@ -334,7 +334,7 @@ class DQNAgent(Agent):
return state
def __setstate__(self, state):
Agent.__setstate__(self, state)
Trainer.__setstate__(self, state)
self.num_target_updates = state["num_target_updates"]
self.last_target_update_ts = state["last_target_update_ts"]
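The `__getstate__`/`__setstate__` overrides above add the target-network bookkeeping to the base `Trainer` checkpoint state, so the usual Trainable save/restore flow round-trips it. A usage sketch (env and iteration count are illustrative):

import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(env="CartPole-v0")
trainer.train()

# _save() pickles __getstate__(), which includes num_target_updates and
# last_target_update_ts as defined above.
checkpoint_path = trainer.save()

restored = DQNTrainer(env="CartPole-v0")
restored.restore(checkpoint_path)  # calls __setstate__() with the same dict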

View file

@ -1,3 +1,6 @@
from ray.rllib.agents.es.es import (ESAgent, DEFAULT_CONFIG)
from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_class
__all__ = ["ESAgent", "DEFAULT_CONFIG"]
ESAgent = renamed_class(ESTrainer)
__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]

View file

@ -11,7 +11,7 @@ import numpy as np
import time
import ray
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import policies
@ -163,13 +163,13 @@ class Worker(object):
eval_lengths=eval_lengths)
class ESAgent(Agent):
class ESTrainer(Trainer):
"""Large-scale implementation of Evolution Strategies in Ray."""
_agent_name = "ES"
_name = "ES"
_default_config = DEFAULT_CONFIG
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
policy_params = {"action_noise_std": 0.01}
@ -200,7 +200,7 @@ class ESAgent(Agent):
self.reward_list = []
self.tstart = time.time()
@override(Agent)
@override(Trainer)
def _train(self):
config = self.config
@ -288,11 +288,11 @@ class ESAgent(Agent):
return result
@override(Agent)
@override(Trainer)
def compute_action(self, observation):
return self.policy.compute(observation, update=False)[0]
@override(Agent)
@override(Trainer)
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self.workers:

View file

@ -1,3 +1,6 @@
from ray.rllib.agents.impala.impala import ImpalaAgent, DEFAULT_CONFIG
from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
__all__ = ["ImpalaAgent", "DEFAULT_CONFIG"]
ImpalaAgent = renamed_class(ImpalaTrainer)
__all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"]

View file

@ -6,7 +6,7 @@ import time
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.utils.annotations import override
@ -100,14 +100,14 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class ImpalaAgent(Agent):
class ImpalaTrainer(Trainer):
"""IMPALA implementation using DeepMind's V-trace."""
_agent_name = "IMPALA"
_name = "IMPALA"
_default_config = DEFAULT_CONFIG
_policy_graph = VTracePolicyGraph
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
for k in OPTIMIZER_SHARED_CONFIGS:
if k not in config["optimizer"]:
@ -136,7 +136,7 @@ class ImpalaAgent(Agent):
@override(Trainable)
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
Agent._validate_config(cf)
Trainer._validate_config(cf)
return Resources(
cpu=cf["num_cpus_for_driver"],
gpu=cf["num_gpus"],
@ -144,7 +144,7 @@ class ImpalaAgent(Agent):
cf["num_aggregation_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
@override(Agent)
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
start = time.time()

View file

@ -2,6 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.marwil.marwil import MARWILAgent, DEFAULT_CONFIG
from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
__all__ = ["MARWILAgent", "DEFAULT_CONFIG"]
__all__ = ["MARWILTrainer", "DEFAULT_CONFIG"]

View file

@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph
from ray.rllib.optimizers import SyncBatchReplayOptimizer
from ray.rllib.utils.annotations import override
@ -39,14 +39,14 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class MARWILAgent(Agent):
class MARWILTrainer(Trainer):
"""MARWIL implementation in TensorFlow."""
_agent_name = "MARWIL"
_name = "MARWIL"
_default_config = DEFAULT_CONFIG
_policy_graph = MARWILPolicyGraph
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy_graph)
@ -59,7 +59,7 @@ class MARWILAgent(Agent):
"train_batch_size": config["train_batch_size"],
})
@override(Agent)
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
fetches = self.optimizer.step()

View file

@ -6,13 +6,13 @@ import os
import pickle
import numpy as np
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
class _MockAgent(Agent):
"""Mock agent for use in tests"""
class _MockTrainer(Trainer):
"""Mock trainer for use in tests"""
_agent_name = "MockAgent"
_name = "MockTrainer"
_default_config = with_common_config({
"mock_error": False,
"persistent_error": False,
@ -57,12 +57,12 @@ class _MockAgent(Agent):
return self.info
class _SigmoidFakeData(_MockAgent):
"""Agent that returns sigmoid learning curves.
class _SigmoidFakeData(_MockTrainer):
"""Trainer that returns sigmoid learning curves.
This can be helpful for evaluating early stopping algorithms."""
_agent_name = "SigmoidFakeData"
_name = "SigmoidFakeData"
_default_config = with_common_config({
"width": 100,
"height": 100,
@ -84,9 +84,9 @@ class _SigmoidFakeData(_MockAgent):
info={})
class _ParameterTuningAgent(_MockAgent):
class _ParameterTuningTrainer(_MockTrainer):
_agent_name = "ParameterTuningAgent"
_name = "ParameterTuningTrainer"
_default_config = with_common_config({
"reward_amt": 10,
"dummy_param": 10,
@ -108,8 +108,8 @@ class _ParameterTuningAgent(_MockAgent):
def _agent_import_failed(trace):
"""Returns dummy agent class for if PyTorch etc. is not installed."""
class _AgentImportFailed(Agent):
_agent_name = "AgentImportFailed"
class _AgentImportFailed(Trainer):
_name = "AgentImportFailed"
_default_config = with_common_config({})
def _setup(self, config):
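A hedged sketch of how the mock classes in this file can be exercised; constructing `_MockTrainer` directly with its config keys is an assumption about test usage, not something this diff shows:

import ray
from ray.rllib.agents.mock import _MockTrainer

ray.init()
# mock_error / persistent_error come from the _default_config above.
trainer = _MockTrainer(config={"mock_error": False})
result = trainer.train()  # returns canned results without real rollouts
print(result)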

View file

@ -1,3 +1,6 @@
from ray.rllib.agents.pg.pg import PGAgent, DEFAULT_CONFIG
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
__all__ = ["PGAgent", "DEFAULT_CONFIG"]
PGAgent = renamed_class(PGTrainer)
__all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"]

View file

@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer
@ -22,18 +22,18 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class PGAgent(Agent):
class PGTrainer(Trainer):
"""Simple policy gradient agent.
This is an example agent to show how to implement algorithms in RLlib.
In most cases, you will probably want to use the PPO agent instead.
"""
_agent_name = "PG"
_name = "PG"
_default_config = DEFAULT_CONFIG
_policy_graph = PGPolicyGraph
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
if config["use_pytorch"]:
from ray.rllib.agents.pg.torch_pg_policy_graph import \
@ -51,7 +51,7 @@ class PGAgent(Agent):
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, optimizer_config)
@override(Agent)
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
self.optimizer.step()
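Since the docstring above positions PG as the reference implementation, a minimal training loop against the renamed class (env, worker count, and iteration count are illustrative):

import ray
from ray.rllib.agents.pg import PGTrainer

ray.init()
trainer = PGTrainer(env="CartPole-v0", config={"num_workers": 1})
for i in range(10):
    result = trainer.train()
    print("iter {}: reward {}".format(i, result["episode_reward_mean"]))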

View file

@ -1,4 +1,7 @@
from ray.rllib.agents.ppo.ppo import (PPOAgent, DEFAULT_CONFIG)
from ray.rllib.agents.ppo.appo import APPOAgent
from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ppo.appo import APPOTrainer
from ray.rllib.utils import renamed_class
__all__ = ["APPOAgent", "PPOAgent", "DEFAULT_CONFIG"]
PPOAgent = renamed_class(PPOTrainer)
__all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"]

View file

@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph
from ray.rllib.agents.agent import with_base_config
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents import impala
from ray.rllib.utils.annotations import override
@ -52,13 +52,13 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
# yapf: enable
class APPOAgent(impala.ImpalaAgent):
class APPOTrainer(impala.ImpalaTrainer):
"""PPO surrogate loss with IMPALA-architecture."""
_agent_name = "APPO"
_name = "APPO"
_default_config = DEFAULT_CONFIG
_policy_graph = AsyncPPOPolicyGraph
@override(impala.ImpalaAgent)
@override(impala.ImpalaTrainer)
def _get_policy_graph(self):
return AsyncPPOPolicyGraph

View file

@ -4,7 +4,7 @@ from __future__ import print_function
import logging
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
from ray.rllib.utils.annotations import override
@ -63,14 +63,14 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class PPOAgent(Agent):
class PPOTrainer(Trainer):
"""Multi-GPU optimized implementation of PPO in TensorFlow."""
_agent_name = "PPO"
_name = "PPO"
_default_config = DEFAULT_CONFIG
_policy_graph = PPOPolicyGraph
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
self._validate_config()
self.local_evaluator = self.make_local_evaluator(
@ -96,7 +96,7 @@ class PPOAgent(Agent):
"straggler_mitigation": config["straggler_mitigation"],
})
@override(Agent)
@override(Trainer)
def _train(self):
if "observation_filter" not in self.raw_user_config:
# TODO(ekl) remove this message after a few releases
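Given the "multi-GPU optimized" framing in the docstring above, a hedged config sketch; the specific resource numbers are examples only, and which keys pay off depends on the machine:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 8,          # sampling actors
        "num_gpus": 1,             # driver GPU used by the multi-GPU optimizer
        "sample_batch_size": 200,
        "train_batch_size": 4000,  # concatenated from worker batches
    })
print(trainer.train())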

View file

@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG
from ray.rllib.agents.qmix.apex import ApexQMixAgent
from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
from ray.rllib.agents.qmix.apex import ApexQMixTrainer
__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"]
__all__ = ["QMixTrainer", "ApexQMixTrainer", "DEFAULT_CONFIG"]

View file

@ -4,7 +4,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG
from ray.rllib.agents.qmix.qmix import QMixTrainer, \
DEFAULT_CONFIG as QMIX_CONFIG
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
@ -34,17 +35,17 @@ APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
)
class ApexQMixAgent(QMixAgent):
class ApexQMixTrainer(QMixTrainer):
"""QMIX variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_agent_name = "APEX_QMIX"
_name = "APEX_QMIX"
_default_config = APEX_QMIX_DEFAULT_CONFIG
@override(QMixAgent)
@override(QMixTrainer)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \

View file

@ -2,8 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import with_common_config
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
# yapf: disable
@ -90,10 +90,10 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class QMixAgent(DQNAgent):
class QMixTrainer(DQNTrainer):
"""QMix implementation in PyTorch."""
_agent_name = "QMIX"
_name = "QMIX"
_default_config = DEFAULT_CONFIG
_policy_graph = QMixPolicyGraph
_optimizer_shared_configs = [

View file

@ -11,77 +11,77 @@ from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
def _import_appo():
from ray.rllib.agents import ppo
return ppo.APPOAgent
return ppo.APPOTrainer
def _import_qmix():
from ray.rllib.agents import qmix
return qmix.QMixAgent
return qmix.QMixTrainer
def _import_apex_qmix():
from ray.rllib.agents import qmix
return qmix.ApexQMixAgent
return qmix.ApexQMixTrainer
def _import_ddpg():
from ray.rllib.agents import ddpg
return ddpg.DDPGAgent
return ddpg.DDPGTrainer
def _import_apex_ddpg():
from ray.rllib.agents import ddpg
return ddpg.ApexDDPGAgent
return ddpg.ApexDDPGTrainer
def _import_ppo():
from ray.rllib.agents import ppo
return ppo.PPOAgent
return ppo.PPOTrainer
def _import_es():
from ray.rllib.agents import es
return es.ESAgent
return es.ESTrainer
def _import_ars():
from ray.rllib.agents import ars
return ars.ARSAgent
return ars.ARSTrainer
def _import_dqn():
from ray.rllib.agents import dqn
return dqn.DQNAgent
return dqn.DQNTrainer
def _import_apex():
from ray.rllib.agents import dqn
return dqn.ApexAgent
return dqn.ApexTrainer
def _import_a3c():
from ray.rllib.agents import a3c
return a3c.A3CAgent
return a3c.A3CTrainer
def _import_a2c():
from ray.rllib.agents import a3c
return a3c.A2CAgent
return a3c.A2CTrainer
def _import_pg():
from ray.rllib.agents import pg
return pg.PGAgent
return pg.PGTrainer
def _import_impala():
from ray.rllib.agents import impala
return impala.ImpalaAgent
return impala.ImpalaTrainer
def _import_marwil():
from ray.rllib.agents import marwil
return marwil.MARWILAgent
return marwil.MARWILTrainer
ALGORITHMS = {
@ -122,13 +122,13 @@ def _get_agent_class(alg):
from ray.tune import script_runner
return script_runner.ScriptRunner
elif alg == "__fake":
from ray.rllib.agents.mock import _MockAgent
return _MockAgent
from ray.rllib.agents.mock import _MockTrainer
return _MockTrainer
elif alg == "__sigmoid_fake_data":
from ray.rllib.agents.mock import _SigmoidFakeData
return _SigmoidFakeData
elif alg == "__parameter_tuning":
from ray.rllib.agents.mock import _ParameterTuningAgent
return _ParameterTuningAgent
from ray.rllib.agents.mock import _ParameterTuningTrainer
return _ParameterTuningTrainer
else:
raise Exception(("Unknown algorithm {}.").format(alg))
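For reference, a sketch of resolving an algorithm name to its Trainer class through this registry; using `get_agent_class` as the public entry point is an assumption based on the rest of this module:

import ray
from ray.rllib.agents.registry import get_agent_class

ray.init()
cls = get_agent_class("PPO")  # -> ppo.PPOTrainer via _import_ppo()
trainer = cls(env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])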

View file

@ -0,0 +1,817 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import copy
import logging
import os
import pickle
import six
import time
import tempfile
import tensorflow as tf
from types import FunctionType
import ray
from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
logger = logging.getLogger(__name__)
# Max number of times to retry a worker failure. We shouldn't try too many
# times in a row since that would indicate a persistent cluster issue.
MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable
# __sphinx_doc_begin__
COMMON_CONFIG = {
# === Debugging ===
# Whether to write episode stats and videos to the agent log dir
"monitor": False,
# Set the ray.rllib.* log level for the agent process and its evaluators.
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
# periodically print out summaries of relevant internal dataflow (this is
# also printed out once at startup at the INFO level).
"log_level": "INFO",
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
# metrics can be attached to the episode by updating the episode object's
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
"callbacks": {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
"on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...}
},
# Whether to attempt to continue training if a worker crashes.
"ignore_worker_failures": False,
# === Policy ===
# Arguments to pass to model. See models/catalog.py for a full list of the
# available model options.
"model": MODEL_DEFAULTS,
# Arguments to pass to the policy optimizer. These vary by optimizer.
"optimizer": {},
# === Environment ===
# Discount factor of the MDP
"gamma": 0.99,
# Number of steps after which the episode is forced to terminate. Defaults
# to `env.spec.max_episode_steps` (if present) for Gym envs.
"horizon": None,
# Calculate rewards but don't reset the environment when the horizon is
# hit. This allows value estimation and RNN state to span across logical
# episodes denoted by horizon. This only has an effect if horizon != inf.
"soft_horizon": False,
# Arguments to pass to the env creator
"env_config": {},
# Environment name can also be passed via config
"env": None,
# Whether to clip rewards prior to experience postprocessing. Setting to
# None means clip for Atari only.
"clip_rewards": None,
# Whether to np.clip() actions to the action space low/high range spec.
"clip_actions": True,
# Whether to use rllib or deepmind preprocessors by default
"preprocessor_pref": "deepmind",
# === Resources ===
# Number of actors used for parallelism
"num_workers": 2,
# Number of GPUs to allocate to the driver. Note that not all algorithms
# can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs).
"num_gpus": 0,
# Number of CPUs to allocate per worker.
"num_cpus_per_worker": 1,
# Number of GPUs to allocate per worker. This can be fractional.
"num_gpus_per_worker": 0,
# Any custom resources to allocate per worker.
"custom_resources_per_worker": {},
# Number of CPUs to allocate for the driver. Note: this only takes effect
# when running in Tune.
"num_cpus_for_driver": 1,
# === Execution ===
# Number of environments to evaluate vectorwise per worker.
"num_envs_per_worker": 1,
# Default sample batch size (unroll length). Batches of this size are
# collected from workers until train_batch_size is met. When using
# multiple envs per worker, this is multiplied by num_envs_per_worker.
"sample_batch_size": 200,
# Training batch size, if applicable. Should be >= sample_batch_size.
# Samples batches will be concatenated together to this size for training.
"train_batch_size": 200,
# Whether to rollout "complete_episodes" or "truncate_episodes"
"batch_mode": "truncate_episodes",
# (Deprecated) Use a background thread for sampling (slightly off-policy)
"sample_async": False,
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
"observation_filter": "NoFilter",
# Whether to synchronize the statistics of remote filters.
"synchronize_filters": True,
# Configure TF for single-process operation by default
"tf_session_args": {
# note: overridden by `local_evaluator_tf_session_args`
"intra_op_parallelism_threads": 2,
"inter_op_parallelism_threads": 2,
"gpu_options": {
"allow_growth": True,
},
"log_device_placement": False,
"device_count": {
"CPU": 1
},
"allow_soft_placement": True, # required by PPO multi-gpu
},
# Override the following tf session args on the local evaluator
"local_evaluator_tf_session_args": {
# Allow a higher level of parallelism by default, but not unlimited
# since that can cause crashes with many concurrent drivers.
"intra_op_parallelism_threads": 8,
"inter_op_parallelism_threads": 8,
},
# Whether to LZ4 compress individual observations
"compress_observations": False,
# Drop metric batches from unresponsive workers after this many seconds
"collect_metrics_timeout": 180,
# Smooth metrics over this many episodes.
"metrics_smoothing_episodes": 100,
# If using num_envs_per_worker > 1, whether to create those new envs in
# remote processes instead of in the same worker. This adds overheads, but
# can make sense if your envs can take much time to step / reset
# (e.g., for StarCraft). Use this cautiously; overheads are significant.
"remote_worker_envs": False,
# Timeout that remote workers are waiting when polling environments.
# 0 (continue when at least one env is ready) is a reasonable default,
# but optimal value could be obtained by measuring your environment
# step / reset and model inference perf.
"remote_env_batch_wait_ms": 0,
# === Offline Datasets ===
# Specify how to generate experiences:
# - "sampler": generate experiences via online simulation (default)
# - a local directory or file glob expression (e.g., "/tmp/*.json")
# - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
# "s3://bucket/2.json"])
# - a dict with string keys and sampling probabilities as values (e.g.,
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
# - a function that returns a rllib.offline.InputReader
"input": "sampler",
# Specify how to evaluate the current policy. This only has an effect when
# reading offline experiences. Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
# this data for evaluation only and not for learning.
"input_evaluation": ["is", "wis"],
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behaviour* policy, which is typically undesirable for
# on-policy algorithms.
"postprocess_inputs": False,
# If positive, input batches will be shuffled via a sliding window buffer
# of this number of batches. Use this if the input data is not in random
# enough order. Input is delayed until the shuffle buffer is filled.
"shuffle_buffer_size": 0,
# Specify where experiences should be saved:
# - None: don't save any experiences
# - "logdir" to save to the agent log dir
# - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
# - a function that returns a rllib.offline.OutputWriter
"output": None,
# What sample batch columns to LZ4 compress in the output data.
"output_compress_columns": ["obs", "new_obs"],
# Max output file size before rolling over to a new file.
"output_max_file_size": 64 * 1024 * 1024,
# === Multiagent ===
"multiagent": {
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
# act_space, config). See policy_evaluator.py for more info.
"policy_graphs": {},
# Function mapping agent ids to policy ids.
"policy_mapping_fn": None,
# Optional whitelist of policies to train, or None for all policies.
"policies_to_train": None,
},
}
# __sphinx_doc_end__
# yapf: enable
@DeveloperAPI
def with_common_config(extra_config):
"""Returns the given config dict merged with common agent confs."""
return with_base_config(COMMON_CONFIG, extra_config)
def with_base_config(base_config, extra_config):
"""Returns the given config dict merged with a base agent conf."""
config = copy.deepcopy(base_config)
config.update(extra_config)
return config
@PublicAPI
class Trainer(Trainable):
"""A trainer coordinates the optimization of one or more RL policies.
All RLlib trainers extend this base class, e.g., the A3CTrainer implements
the A3C algorithm for single and multi-agent training.
Trainer objects retain internal model state between calls to train(), so
you should create a new trainer instance for each training session.
Attributes:
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
"""
_allow_unknown_configs = False
_allow_unknown_subkeys = [
"tf_session_args", "env_config", "model", "optimizer", "multiagent",
"custom_resources_per_worker"
]
@PublicAPI
def __init__(self, config=None, env=None, logger_creator=None):
"""Initialize an RLLib trainer.
Args:
config (dict): Algorithm-specific configuration data.
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
config = config or {}
# Vars to synchronize to evaluators on each train call
self.global_vars = {"timestep": 0}
# Trainers allow env ids to be passed directly to the constructor.
self._env_id = self._register_if_needed(env or config.get("env"))
# Create a default logger creator if no logger_creator is specified
if logger_creator is None:
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
timestr)
def default_logger_creator(config):
"""Creates a Unified logger with a default logdir prefix
containing the trainer name and the env id
"""
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
logdir = tempfile.mkdtemp(
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
return UnifiedLogger(config, logdir, None)
logger_creator = default_logger_creator
Trainable.__init__(self, config, logger_creator)
@classmethod
@override(Trainable)
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
Trainer._validate_config(cf)
# TODO(ekl): add custom resources here once tune supports them
return Resources(
cpu=cf["num_cpus_for_driver"],
gpu=cf["num_gpus"],
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
@override(Trainable)
@PublicAPI
def train(self):
"""Overrides super.train to synchronize global vars."""
if self._has_policy_optimizer():
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
for ev in self.optimizer.remote_evaluators:
ev.set_global_vars.remote(self.global_vars)
logger.debug("updated global vars: {}".format(self.global_vars))
result = None
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
try:
result = Trainable.train(self)
except RayError as e:
if self.config["ignore_worker_failures"]:
logger.exception(
"Error in train call, attempting to recover")
self._try_recover()
else:
logger.info(
"Worker crashed during call to train(). To attempt to "
"continue training without the failed worker, set "
"`'ignore_worker_failures': True`.")
raise e
except Exception as e:
time.sleep(0.5)  # allow log messages to propagate
raise e
else:
break
if result is None:
raise RuntimeError("Failed to recover from worker crash")
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
and hasattr(self, "local_evaluator")):
FilterManager.synchronize(
self.local_evaluator.filters,
self.remote_evaluators,
update_remote=self.config["synchronize_filters"])
logger.debug("synchronized filters: {}".format(
self.local_evaluator.filters))
if self._has_policy_optimizer():
result["num_healthy_workers"] = len(
self.optimizer.remote_evaluators)
return result
@override(Trainable)
def _log_result(self, result):
if self.config["callbacks"].get("on_train_result"):
self.config["callbacks"]["on_train_result"]({
"trainer": self,
"result": result,
})
# log after the callback is invoked, so that the user has a chance
# to mutate the result
Trainable._log_result(self, result)
@override(Trainable)
def _setup(self, config):
env = self._env_id
if env:
config["env"] = env
if _global_registry.contains(ENV_CREATOR, env):
self.env_creator = _global_registry.get(ENV_CREATOR, env)
else:
import gym # soft dependency
self.env_creator = lambda env_config: gym.make(env)
else:
self.env_creator = lambda env_config: None
# Merge the supplied config with the class default
merged_config = copy.deepcopy(self._default_config)
merged_config = deep_update(merged_config, config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
self.raw_user_config = config
self.config = merged_config
Trainer._validate_config(self.config)
if self.config.get("log_level"):
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
with tf.Graph().as_default():
self._init(self.config, self.env_creator)
@override(Trainable)
def _stop(self):
# Call stop on all evaluators to release resources
if hasattr(self, "local_evaluator"):
self.local_evaluator.stop()
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
ev.stop.remote()
# workaround for https://github.com/ray-project/ray/issues/1516
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
if hasattr(self, "optimizer"):
self.optimizer.stop()
@override(Trainable)
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
return checkpoint_path
@override(Trainable)
def _restore(self, checkpoint_path):
extra_data = pickle.load(open(checkpoint_path, "rb"))
self.__setstate__(extra_data)
@DeveloperAPI
def _init(self, config, env_creator):
"""Subclasses should override this for custom initialization."""
raise NotImplementedError
@PublicAPI
def compute_action(self,
observation,
state=None,
prev_action=None,
prev_reward=None,
info=None,
policy_id=DEFAULT_POLICY_ID,
full_fetch=False):
"""Computes an action for the specified policy.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Arguments:
observation (obj): observation from the environment.
state (list): RNN hidden state, if any. If state is not None,
then all of compute_single_action(...) is returned
(computed action, rnn state, logits dictionary).
Otherwise compute_single_action(...)[0] is
returned (computed action).
prev_action (obj): previous action value, if any
prev_reward (int): previous reward, if any
info (dict): info object, if any
policy_id (str): policy to query (only applies to multi-agent).
full_fetch (bool): whether to return extra action fetch results.
This is always set to true if RNN state is specified.
Returns:
Just the computed action if full_fetch=False, or the full output
of policy.compute_actions() otherwise.
"""
if state is None:
state = []
preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
observation)
filtered_obs = self.local_evaluator.filters[policy_id](
preprocessed, update=False)
if state:
return self.get_policy(policy_id).compute_single_action(
filtered_obs,
state,
prev_action,
prev_reward,
info,
clip_actions=self.config["clip_actions"])
res = self.get_policy(policy_id).compute_single_action(
filtered_obs,
state,
prev_action,
prev_reward,
info,
clip_actions=self.config["clip_actions"])
if full_fetch:
return res
else:
return res[0] # backwards compatibility
@property
def iteration(self):
"""Current training iter, auto-incremented with each train() call."""
return self._iteration
@property
def _name(self):
"""Subclasses should override this to declare their name."""
raise NotImplementedError
@property
def _default_config(self):
"""Subclasses should override this to declare their default config."""
raise NotImplementedError
@PublicAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
Arguments:
policy_id (str): id of policy graph to return.
"""
return self.local_evaluator.get_policy(policy_id)
@PublicAPI
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
Arguments:
policies (list): Optional list of policies to return weights for,
or None for all policies.
"""
return self.local_evaluator.get_weights(policies)
@PublicAPI
def set_weights(self, weights):
"""Set policy weights by policy id.
Arguments:
weights (dict): Map of policy ids to weights to set.
"""
self.local_evaluator.set_weights(weights)
@DeveloperAPI
def make_local_evaluator(self,
env_creator,
policy_graph,
extra_config=None):
"""Convenience method to return configured local evaluator."""
return self._make_evaluator(
PolicyEvaluator,
env_creator,
policy_graph,
0,
merge_dicts(
# important: allow local tf to use more CPUs for optimization
merge_dicts(
self.config, {
"tf_session_args": self.
config["local_evaluator_tf_session_args"]
}),
extra_config or {}))
@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
"""Convenience method to return a number of remote evaluators."""
remote_args = {
"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"],
"resources": self.config["custom_resources_per_worker"],
}
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
self.config) for i in range(count)
]
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
"""Export policy model with given policy_id to local directory.
Arguments:
export_dir (string): Writable local directory.
policy_id (string): Optional policy id to export.
Example:
>>> trainer = MyTrainer()
>>> for _ in range(10):
>>> trainer.train()
>>> trainer.export_policy_model("/tmp/export_dir")
"""
self.local_evaluator.export_policy_model(export_dir, policy_id)
@DeveloperAPI
def export_policy_checkpoint(self,
export_dir,
filename_prefix="model",
policy_id=DEFAULT_POLICY_ID):
"""Export tensorflow policy model checkpoint to local directory.
Arguments:
export_dir (string): Writable local directory.
filename_prefix (string): file name prefix of checkpoint files.
policy_id (string): Optional policy id to export.
Example:
>>> trainer = MyTrainer()
>>> for _ in range(10):
>>> trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
"""
self.local_evaluator.export_policy_checkpoint(
export_dir, filename_prefix, policy_id)
@DeveloperAPI
def collect_metrics(self, selected_evaluators=None):
"""Collects metrics from the remote evaluators of this agent.
This is the same data as returned by a call to train().
"""
return self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"],
min_history=self.config["metrics_smoothing_episodes"],
selected_evaluators=selected_evaluators)
@classmethod
def resource_help(cls, config):
return ("\n\nYou can adjust the resource requests of RLlib agents by "
"setting `num_workers`, `num_gpus`, and other configs. See "
"the DEFAULT_CONFIG defined by each agent for more info.\n\n"
"The config of this agent is: {}".format(config))
@staticmethod
def _validate_config(config):
if "gpu" in config:
raise ValueError(
"The `gpu` config is deprecated, please use `num_gpus=0|1` "
"instead.")
if "gpu_fraction" in config:
raise ValueError(
"The `gpu_fraction` config is deprecated, please use "
"`num_gpus=<fraction>` instead.")
if "use_gpu_for_workers" in config:
raise ValueError(
"The `use_gpu_for_workers` config is deprecated, please use "
"`num_gpus_per_worker=1` instead.")
if type(config["input_evaluation"]) != list:
raise ValueError(
"`input_evaluation` must be a list of strings, got {}".format(
config["input_evaluation"]))
def _try_recover(self):
"""Try to identify and blacklist any unhealthy workers.
This method is called after an unexpected remote error is encountered
from a worker. It issues check requests to all current workers and
blacklists any that respond with error. If no healthy workers remain,
an error is raised.
"""
if not self._has_policy_optimizer():
raise NotImplementedError(
"Recovery is not supported for this algorithm")
logger.info("Health checking all workers...")
checks = []
for ev in self.optimizer.remote_evaluators:
_, obj_id = ev.sample_with_count.remote()
checks.append(obj_id)
healthy_evaluators = []
for i, obj_id in enumerate(checks):
ev = self.optimizer.remote_evaluators[i]
try:
ray.get(obj_id)
healthy_evaluators.append(ev)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
logger.exception("Blacklisting worker {}".format(i + 1))
try:
ev.__ray_terminate__.remote()
except Exception:
logger.exception("Error terminating unhealthy worker")
if len(healthy_evaluators) < 1:
raise RuntimeError(
"Not enough healthy workers remain to continue.")
self.optimizer.reset(healthy_evaluators)
def _has_policy_optimizer(self):
return hasattr(self, "optimizer") and isinstance(
self.optimizer, PolicyOptimizer)
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
return tf.Session(
config=tf.ConfigProto(**config["tf_session_args"]))
if isinstance(config["input"], FunctionType):
input_creator = config["input"]
elif config["input"] == "sampler":
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: ShuffledInput(
MixedInput(config["input"], ioctx), config[
"shuffle_buffer_size"]))
else:
input_creator = (lambda ioctx: ShuffledInput(
JsonReader(config["input"], ioctx), config[
"shuffle_buffer_size"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
elif config["output"] is None:
output_creator = (lambda ioctx: NoopOutput())
elif config["output"] == "logdir":
output_creator = (lambda ioctx: JsonWriter(
ioctx.log_dir,
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
else:
output_creator = (lambda ioctx: JsonWriter(
config["output"],
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
if config["input"] == "sampler":
input_evaluation = []
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy graph if 'None' is specified in multiagent
if self.config["multiagent"]["policy_graphs"]:
tmp = self.config["multiagent"]["policy_graphs"]
_validate_multiagent_config(tmp, allow_none_graph=True)
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy_graph, v[1], v[2], v[3])
policy_graph = tmp
return cls(
env_creator,
policy_graph,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
policies_to_train=self.config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
if config["tf_session_args"] else None),
batch_steps=config["sample_batch_size"],
batch_mode=config["batch_mode"],
episode_horizon=config["horizon"],
preprocessor_pref=config["preprocessor_pref"],
sample_async=config["sample_async"],
compress_observations=config["compress_observations"],
num_envs=config["num_envs_per_worker"],
observation_filter=config["observation_filter"],
clip_rewards=config["clip_rewards"],
clip_actions=config["clip_actions"],
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
monitor_path=self.logdir if config["monitor"] else None,
log_dir=self.logdir,
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
soft_horizon=config["soft_horizon"],
_fake_sampler=config.get("_fake_sampler", False))
@override(Trainable)
def _export_model(self, export_formats, export_dir):
ExportFormat.validate(export_formats)
exported = {}
if ExportFormat.CHECKPOINT in export_formats:
path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
self.export_policy_checkpoint(path)
exported[ExportFormat.CHECKPOINT] = path
if ExportFormat.MODEL in export_formats:
path = os.path.join(export_dir, ExportFormat.MODEL)
self.export_policy_model(path)
exported[ExportFormat.MODEL] = path
return exported
def __getstate__(self):
state = {}
if hasattr(self, "local_evaluator"):
state["evaluator"] = self.local_evaluator.save()
if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
state["optimizer"] = self.optimizer.save()
return state
def __setstate__(self, state):
if "evaluator" in state:
self.local_evaluator.restore(state["evaluator"])
remote_state = ray.put(state["evaluator"])
for r in self.remote_evaluators:
r.restore.remote(remote_state)
if "optimizer" in state:
self.optimizer.restore(state["optimizer"])
def _register_if_needed(self, env_object):
if isinstance(env_object, six.string_types):
return env_object
elif isinstance(env_object, type):
name = env_object.__name__
register_env(name, lambda config: env_object(config))
return name
raise ValueError(
"{} is an invalid env specification. ".format(env_object) +
"You can specify a custom env as either a class "
"(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")

View file

@ -4,25 +4,25 @@ from __future__ import print_function
import numpy as np
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.annotations import override
# yapf: disable
# __sphinx_doc_begin__
class RandomAgent(Agent):
"""Agent that takes random actions and never learns."""
class RandomAgent(Trainer):
"""Policy that takes random actions and never learns."""
_agent_name = "RandomAgent"
_name = "RandomAgent"
_default_config = with_common_config({
"rollouts_per_iteration": 10,
})
@override(Agent)
@override(Trainer)
def _init(self, config, env_creator):
self.env = env_creator(config["env_config"])
@override(Agent)
@override(Trainer)
def _train(self):
rewards = []
steps = 0
@ -45,8 +45,8 @@ class RandomAgent(Agent):
if __name__ == "__main__":
agent = RandomAgent(
trainer = RandomAgent(
env="CartPole-v0", config={"rollouts_per_iteration": 10})
result = agent.train()
result = trainer.train()
assert result["episode_reward_mean"] > 10, result
print("Test: OK")

View file

@ -34,9 +34,9 @@ class ExternalEnv(threading.Thread):
Examples:
>>> register_env("my_env", lambda config: YourExternalEnv(config))
>>> agent = DQNAgent(env="my_env")
>>> trainer = DQNTrainer(env="my_env")
>>> while True:
print(agent.train())
print(trainer.train())
"""
@PublicAPI

View file

@ -35,6 +35,19 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
# Handle to the current evaluator, which will be set to the most recently
# created PolicyEvaluator in this process. This can be helpful to access in
# custom env or policy classes for debugging or advanced use cases.
_global_evaluator = None
@DeveloperAPI
def get_global_evaluator():
"""Returns a handle to the active policy evaluator in this process."""
global _global_evaluator
return _global_evaluator
@DeveloperAPI
class PolicyEvaluator(EvaluatorInterface):
@ -215,6 +228,9 @@ class PolicyEvaluator(EvaluatorInterface):
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
"""
global _global_evaluator
_global_evaluator = self
if log_level:
logging.getLogger("ray.rllib").setLevel(log_level)
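A sketch of the debugging use case the new comment describes: a custom env peeking at the evaluator that created it. The env class itself is hypothetical; only `get_global_evaluator()` comes from this diff:

import gym
from ray.rllib.evaluation.policy_evaluator import get_global_evaluator


class DebuggableEnv(gym.Env):
    """Hypothetical env that inspects its parent evaluator on reset."""

    observation_space = gym.spaces.Discrete(2)
    action_space = gym.spaces.Discrete(2)

    def reset(self):
        evaluator = get_global_evaluator()  # PolicyEvaluator in this process
        if evaluator is not None:
            print("running under policy evaluator:", evaluator)
        return 0

    def step(self, action):
        return 0, 0.0, True, {}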

View file

@ -71,12 +71,13 @@ class MultiAgentSampleBatchBuilder(object):
corresponding policy batch for the agent's policy.
"""
def __init__(self, policy_map, clip_rewards):
def __init__(self, policy_map, clip_rewards, postp_callback):
"""Initialize a MultiAgentSampleBatchBuilder.
Arguments:
policy_map (dict): Maps policy ids to policy graph instances.
clip_rewards (bool): Whether to clip rewards before postprocessing.
postp_callback: function to call on each postprocessed batch.
"""
self.policy_map = policy_map
@ -87,6 +88,7 @@ class MultiAgentSampleBatchBuilder(object):
}
self.agent_builders = {}
self.agent_to_policy = {}
self.postp_callback = postp_callback
self.count = 0 # increment this manually
def total(self):
@ -158,6 +160,8 @@ class MultiAgentSampleBatchBuilder(object):
for agent_id, post_batch in sorted(post_batches.items()):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
if self.postp_callback:
self.postp_callback({"episode": episode, "batch": post_batch})
self.agent_builders.clear()
self.agent_to_policy.clear()

View file

@ -279,7 +279,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
if batch_builder_pool:
return batch_builder_pool.pop()
else:
return MultiAgentSampleBatchBuilder(policies, clip_rewards)
return MultiAgentSampleBatchBuilder(
policies, clip_rewards, callbacks.get("on_postprocess_traj"))
def new_episode():
episode = MultiAgentEpisode(policies, policy_mapping_fn,

View file

@ -38,12 +38,21 @@ def on_sample_end(info):
def on_train_result(info):
print("agent.train() result: {} -> {} episodes".format(
info["agent"], info["result"]["episodes_this_iter"]))
print("trainer.train() result: {} -> {} episodes".format(
info["trainer"], info["result"]["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
info["result"]["callback_ok"] = True
def on_postprocess_traj(info):
episode = info["episode"]
batch = info["batch"]
print("postprocessed {} steps".format(batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=2000)
@ -63,6 +72,7 @@ if __name__ == "__main__":
"on_episode_end": tune.function(on_episode_end),
"on_sample_end": tune.function(on_sample_end),
"on_train_result": tune.function(on_train_result),
"on_postprocess_traj": tune.function(on_postprocess_traj),
},
},
)
@ -73,5 +83,6 @@ if __name__ == "__main__":
assert "pole_angle_mean" in custom_metrics
assert "pole_angle_min" in custom_metrics
assert "pole_angle_max" in custom_metrics
assert "num_batches_mean" in custom_metrics
assert type(custom_metrics["pole_angle_mean"]) is float
assert "callback_ok" in trials[0].last_result

View file

@ -12,12 +12,12 @@ from __future__ import print_function
import ray
from ray import tune
from ray.rllib.agents.ppo import PPOAgent
from ray.rllib.agents.ppo import PPOTrainer
def my_train_fn(config, reporter):
# Train for 100 iterations with high LR
agent1 = PPOAgent(env="CartPole-v0", config=config)
agent1 = PPOTrainer(env="CartPole-v0", config=config)
for _ in range(10):
result = agent1.train()
result["phase"] = 1
@ -28,7 +28,7 @@ def my_train_fn(config, reporter):
# Train for 100 iterations with low LR
config["lr"] = 0.0001
agent2 = PPOAgent(env="CartPole-v0", config=config)
agent2 = PPOTrainer(env="CartPole-v0", config=config)
agent2.restore(state)
for _ in range(10):
result = agent2.train()

View file

@ -15,9 +15,9 @@ import argparse
import gym
import ray
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.ppo.ppo import PPOAgent
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
@ -49,7 +49,7 @@ if __name__ == "__main__":
else:
return "dqn_policy"
ppo_trainer = PPOAgent(
ppo_trainer = PPOTrainer(
env="multi_cartpole",
config={
"multiagent": {
@ -62,7 +62,7 @@ if __name__ == "__main__":
"observation_filter": "NoFilter",
})
dqn_trainer = DQNAgent(
dqn_trainer = DQNTrainer(
env="multi_cartpole",
config={
"multiagent": {

View file

@ -13,7 +13,7 @@ from gym import spaces
import numpy as np
import ray
from ray.rllib.agents.dqn import DQNAgent
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.logger import pretty_print
@ -43,7 +43,7 @@ if __name__ == "__main__":
# We use DQN since it supports off-policy actions, but you can choose and
# configure any agent.
dqn = DQNAgent(
dqn = DQNTrainer(
env="srv",
config={
# Use a single process to avoid needing to set up a load balancer

View file

@ -35,7 +35,7 @@ class MinibatchBuffer(object):
released: True if the item is now removed from the ring buffer.
"""
if self.ttl[self.idx] <= 0:
self.buffers[self.idx] = self.inqueue.get(timeout=60.0)
self.buffers[self.idx] = self.inqueue.get(timeout=300.0)
self.ttl[self.idx] = self.cur_max_ttl
if self.cur_max_ttl < self.max_ttl:
self.cur_max_ttl += 1

View file

@ -7,7 +7,7 @@ from gym.spaces import Tuple, Discrete, Dict, Box
import ray
from ray.tune import register_env
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.agents.qmix import QMixAgent
from ray.rllib.agents.qmix import QMixTrainer
class AvailActionsTestEnv(MultiAgentEnv):
@ -55,7 +55,7 @@ if __name__ == "__main__":
grouping, obs_space=obs_space, act_space=act_space))
ray.init()
agent = QMixAgent(
agent = QMixTrainer(
env="action_mask_test",
config={
"num_envs_per_worker": 5, # test with vectorization on

View file

@ -5,7 +5,7 @@ from __future__ import print_function
import unittest
import ray
from ray.rllib.agents.dqn import DQNAgent
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep
@ -26,7 +26,8 @@ class DQNTest(unittest.TestCase):
def testEvaluationOption(self):
ray.init()
agent = DQNAgent(env="CartPole-v0", config={"evaluation_interval": 2})
agent = DQNTrainer(
env="CartPole-v0", config={"evaluation_interval": 2})
r0 = agent.train()
r1 = agent.train()
r2 = agent.train()

View file

@ -9,8 +9,8 @@ import unittest
import uuid
import ray
from ray.rllib.agents.dqn import DQNAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph,
@ -163,7 +163,7 @@ class TestExternalEnv(unittest.TestCase):
register_env(
"test3", lambda _: PartOffPolicyServing(
gym.make("CartPole-v0"), off_pol_frac=0.2))
dqn = DQNAgent(env="test3", config={"exploration_fraction": 0.001})
dqn = DQNTrainer(env="test3", config={"exploration_fraction": 0.001})
for i in range(100):
result = dqn.train()
print("Iteration {}, reward {}, timesteps {}".format(
@ -174,7 +174,7 @@ class TestExternalEnv(unittest.TestCase):
def testTrainCartpole(self):
register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
pg = PGAgent(env="test", config={"num_workers": 0})
pg = PGTrainer(env="test", config={"num_workers": 0})
for i in range(100):
result = pg.train()
print("Iteration {}, reward {}, timesteps {}".format(
@ -186,7 +186,7 @@ class TestExternalEnv(unittest.TestCase):
def testTrainCartpoleMulti(self):
register_env("test2",
lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
pg = PGAgent(env="test2", config={"num_workers": 0})
pg = PGTrainer(env="test2", config={"num_workers": 0})
for i in range(100):
result = pg.train()
print("Iteration {}, reward {}, timesteps {}".format(

View file

@ -14,7 +14,7 @@ import time
import unittest
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.evaluation import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
@ -44,7 +44,7 @@ class AgentIOTest(unittest.TestCase):
shutil.rmtree(self.test_dir)
def writeOutputs(self, output):
agent = PGAgent(
agent = PGTrainer(
env="CartPole-v0",
config={
"output": output,
@ -65,7 +65,7 @@ class AgentIOTest(unittest.TestCase):
def testAgentInputDir(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
agent = PGTrainer(
env="CartPole-v0",
config={
"input": self.test_dir,
@ -97,7 +97,7 @@ class AgentIOTest(unittest.TestCase):
for data in out:
f.write(json.dumps(data))
agent = PGAgent(
agent = PGTrainer(
env="CartPole-v0",
config={
"input": self.test_dir,
@ -111,7 +111,7 @@ class AgentIOTest(unittest.TestCase):
def testAgentInputEvalSim(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
agent = PGTrainer(
env="CartPole-v0",
config={
"input": self.test_dir,
@ -126,7 +126,7 @@ class AgentIOTest(unittest.TestCase):
def testAgentInputList(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
agent = PGTrainer(
env="CartPole-v0",
config={
"input": glob.glob(self.test_dir + "/*.json"),
@ -139,7 +139,7 @@ class AgentIOTest(unittest.TestCase):
def testAgentInputDict(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
agent = PGTrainer(
env="CartPole-v0",
config={
"input": {
@ -161,7 +161,7 @@ class AgentIOTest(unittest.TestCase):
act_space = single_env.action_space
return (PGPolicyGraph, obs_space, act_space, {})
pg = PGAgent(
pg = PGTrainer(
env="multi_cartpole",
config={
"num_workers": 0,
@ -180,7 +180,7 @@ class AgentIOTest(unittest.TestCase):
self.assertEqual(len(os.listdir(self.test_dir)), 1)
pg.stop()
pg = PGAgent(
pg = PGTrainer(
env="multi_cartpole",
config={
"num_workers": 0,

View file

@ -4,7 +4,7 @@ from __future__ import print_function
import unittest
from ray.rllib.agents.ppo import PPOAgent, DEFAULT_CONFIG
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
import ray
@ -12,7 +12,7 @@ class LocalModeTest(unittest.TestCase):
def testLocal(self):
ray.init(local_mode=True)
cf = DEFAULT_CONFIG.copy()
agent = PPOAgent(cf, "CartPole-v0")
agent = PPOTrainer(cf, "CartPole-v0")
print(agent.train())
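
``local_mode=True`` runs all Ray tasks serially in the driver process, which is why the test above can exercise the trainer end to end in a single process; the same setup is convenient for stepping through RLlib code in a debugger. A minimal sketch:

.. code-block:: python

    import ray
    from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG

    if __name__ == "__main__":
        ray.init(local_mode=True)  # everything runs serially in this process
        config = DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # keep sampling in the driver as well
        trainer = PPOTrainer(config, "CartPole-v0")
        print(trainer.train())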

View file

@ -10,7 +10,7 @@ import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import ray
from ray.rllib.agents.ppo import PPOAgent
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models import ModelCatalog
from ray.rllib.models.lstm import add_time_dimension, chop_into_sequences
from ray.rllib.models.misc import linear, normc_initializer
@ -149,7 +149,7 @@ class RNNSequencing(unittest.TestCase):
def testSimpleOptimizerSequencing(self):
ModelCatalog.register_custom_model("rnn", RNNSpyModel)
register_env("counter", lambda _: DebugCounterEnv())
ppo = PPOAgent(
ppo = PPOTrainer(
env="counter",
config={
"num_workers": 0,
@ -205,7 +205,7 @@ class RNNSequencing(unittest.TestCase):
def testMinibatchSequencing(self):
ModelCatalog.register_custom_model("rnn", RNNSpyModel)
register_env("counter", lambda _: DebugCounterEnv())
ppo = PPOAgent(
ppo = PPOTrainer(
env="counter",
config={
"num_workers": 0,

View file

@ -7,7 +7,7 @@ import random
import unittest
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
@ -519,7 +519,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testTrainMultiCartpoleSinglePolicy(self):
n = 10
register_env("multi_cartpole", lambda _: MultiCartpole(n))
pg = PGAgent(env="multi_cartpole", config={"num_workers": 0})
pg = PGTrainer(env="multi_cartpole", config={"num_workers": 0})
for i in range(100):
result = pg.train()
print("Iteration {}, reward {}, timesteps {}".format(
@ -542,7 +542,7 @@ class TestMultiAgentEnv(unittest.TestCase):
act_space = single_env.action_space
return (None, obs_space, act_space, config)
pg = PGAgent(
pg = PGTrainer(
env="multi_cartpole",
config={
"num_workers": 0,

View file

@ -12,8 +12,8 @@ import tensorflow as tf
import unittest
import ray
from ray.rllib.agents.a3c import A2CAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
@ -215,7 +215,7 @@ class TupleSpyModel(Model):
class NestedSpacesTest(unittest.TestCase):
def testInvalidModel(self):
ModelCatalog.register_custom_model("invalid", InvalidModel)
self.assertRaises(ValueError, lambda: PGAgent(
self.assertRaises(ValueError, lambda: PGTrainer(
env="CartPole-v0", config={
"model": {
"custom_model": "invalid",
@ -226,7 +226,7 @@ class NestedSpacesTest(unittest.TestCase):
ModelCatalog.register_custom_model("invalid2", InvalidModel2)
self.assertRaisesRegexp(
ValueError, "Expected output.*",
lambda: PGAgent(
lambda: PGTrainer(
env="CartPole-v0", config={
"model": {
"custom_model": "invalid2",
@ -236,7 +236,7 @@ class NestedSpacesTest(unittest.TestCase):
def doTestNestedDict(self, make_env, test_lstm=False):
ModelCatalog.register_custom_model("composite", DictSpyModel)
register_env("nested", make_env)
pg = PGAgent(
pg = PGTrainer(
env="nested",
config={
"num_workers": 0,
@ -265,7 +265,7 @@ class NestedSpacesTest(unittest.TestCase):
def doTestNestedTuple(self, make_env):
ModelCatalog.register_custom_model("composite2", TupleSpyModel)
register_env("nested2", make_env)
pg = PGAgent(
pg = PGTrainer(
env="nested2",
config={
"num_workers": 0,
@ -323,7 +323,7 @@ class NestedSpacesTest(unittest.TestCase):
ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
register_env("nested_ma", lambda _: NestedMultiAgentEnv())
act_space = spaces.Discrete(2)
pg = PGAgent(
pg = PGTrainer(
env="nested_ma",
config={
"num_workers": 0,
@ -370,13 +370,13 @@ class NestedSpacesTest(unittest.TestCase):
def testRolloutDictSpace(self):
register_env("nested", lambda _: NestedDictEnv())
agent = PGAgent(env="nested")
agent = PGTrainer(env="nested")
agent.train()
path = agent.save()
agent.stop()
# Test train works on restore
agent2 = PGAgent(env="nested")
agent2 = PGTrainer(env="nested")
agent2.restore(path)
agent2.train()
@ -386,7 +386,7 @@ class NestedSpacesTest(unittest.TestCase):
def testPyTorchModel(self):
ModelCatalog.register_custom_model("composite", TorchSpyModel)
register_env("nested", lambda _: NestedDictEnv())
a2c = A2CAgent(
a2c = A2CTrainer(
env="nested",
config={
"num_workers": 0,

View file

@ -9,7 +9,7 @@ import time
import unittest
import ray
from ray.rllib.agents.ppo import PPOAgent
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.evaluation import SampleBatch
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
@ -41,7 +41,7 @@ class PPOCollectTest(unittest.TestCase):
ray.init(num_cpus=4)
# Check we at least collect the initial wave of samples
ppo = PPOAgent(
ppo = PPOTrainer(
env="CartPole-v0",
config={
"sample_batch_size": 200,
@ -53,7 +53,7 @@ class PPOCollectTest(unittest.TestCase):
ppo.stop()
# Check we collect at least the specified amount of samples
ppo = PPOAgent(
ppo = PPOTrainer(
env="CartPole-v0",
config={
"sample_batch_size": 200,
@ -65,7 +65,7 @@ class PPOCollectTest(unittest.TestCase):
ppo.stop()
# Check in vectorized mode
ppo = PPOAgent(
ppo = PPOTrainer(
env="CartPole-v0",
config={
"sample_batch_size": 200,
@ -78,7 +78,7 @@ class PPOCollectTest(unittest.TestCase):
ppo.stop()
# Check legacy mode
ppo = PPOAgent(
ppo = PPOTrainer(
env="CartPole-v0",
config={
"sample_batch_size": 200,

View file

@ -10,8 +10,8 @@ import unittest
from collections import Counter
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.a3c import A2CAgent
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.policy_graph import PolicyGraph
@ -170,7 +170,7 @@ class TestPolicyEvaluator(unittest.TestCase):
print()
def testGlobalVarsUpdate(self):
agent = A2CAgent(
agent = A2CTrainer(
env="CartPole-v0",
config={
"lr_schedule": [[0, 0.1], [400, 0.000001]],
@ -182,12 +182,12 @@ class TestPolicyEvaluator(unittest.TestCase):
def testNoStepOnInit(self):
register_env("fail", lambda _: FailOnStepEnv())
pg = PGAgent(env="fail", config={"num_workers": 1})
pg = PGTrainer(env="fail", config={"num_workers": 1})
self.assertRaises(Exception, lambda: pg.train())
def testCallbacks(self):
counts = Counter()
pg = PGAgent(
pg = PGTrainer(
env="CartPole-v0", config={
"num_workers": 0,
"sample_batch_size": 50,
@ -211,7 +211,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testQueryEvaluators(self):
register_env("test", lambda _: gym.make("CartPole-v0"))
pg = PGAgent(
pg = PGTrainer(
env="test",
config={
"num_workers": 2,

View file

@ -1,10 +1,33 @@
import logging
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.policy_client import PolicyClient
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.util import merge_dicts, deep_update
logger = logging.getLogger(__name__)
def renamed_class(cls):
class DeprecationWrapper(cls):
def __init__(self, *args, **kwargs):
old_name = cls.__name__.replace("Trainer", "Agent")
new_name = cls.__name__
logger.warn("DeprecationWarning: {} has been renamed to {}. ".
format(old_name, new_name) +
"This will raise an error in the future.")
cls.__init__(self, *args, **kwargs)
return DeprecationWrapper
__all__ = [
"Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts",
"deep_update"
"Filter",
"FilterManager",
"PolicyClient",
"PolicyServer",
"merge_dicts",
"deep_update",
"renamed_class",
]
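
The new ``renamed_class`` helper keeps the old ``*Agent`` names importable while steering users toward the ``*Trainer`` names. A sketch of the intended usage follows; the ``PPOAgent`` alias shown here is illustrative.

.. code-block:: python

    import ray
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.utils import renamed_class

    # The wrapper subclasses PPOTrainer, so the alias is a fully functional
    # trainer; constructing it just logs the deprecation warning first.
    PPOAgent = renamed_class(PPOTrainer)

    if __name__ == "__main__":
        ray.init()
        agent = PPOAgent(env="CartPole-v0", config={"num_workers": 0})
        print(agent.train())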

View file

@ -27,10 +27,10 @@ def PublicAPI(obj):
can expect these APIs to remain stable across RLlib releases.
Subclasses that inherit from a ``@PublicAPI`` base class can be
assumed part of the RLlib public API as well (e.g., all agent classes
are in public API because Agent is ``@PublicAPI``).
assumed part of the RLlib public API as well (e.g., all trainer classes
are in public API because Trainer is ``@PublicAPI``).
In addition, you can assume all agent configurations are part of their
In addition, you can assume all trainer configurations are part of their
public API as well.
"""

View file

@ -39,7 +39,7 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
server = PolicyServer(self, "localhost", 8900)
server.serve_forever()
>>> register_env("srv", lambda _: CartpoleServing())
>>> pg = PGAgent(env="srv", config={"num_workers": 0})
>>> pg = PGTrainer(env="srv", config={"num_workers": 0})
>>> while True:
pg.train()