mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Rename Agent to Trainer (#4556)
This commit is contained in:
parent
820c71b7d0
commit
37208216ae
63 changed files with 1212 additions and 1092 deletions
@ -93,8 +93,8 @@ Ray comes with libraries that accelerate deep learning and reinforcement learning

   rllib.rst
   rllib-training.rst
   rllib-env.rst
   rllib-algorithms.rst
   rllib-models.rst
   rllib-algorithms.rst
   rllib-offline.rst
   rllib-dev.rst
   rllib-concepts.rst
File diff suppressed because one or more lines are too long
Before Size: 72 KiB | After Size: 76 KiB
File diff suppressed because one or more lines are too long
Before Size: 75 KiB | After Size: 75 KiB
|
@ -31,13 +31,13 @@ Contributing Algorithms
|
|||
These are the guidelines for merging new algorithms into RLlib:
|
||||
|
||||
* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__):
|
||||
- must subclass Agent and implement the ``_train()`` method
|
||||
- must subclass Trainer and implement the ``_train()`` method
|
||||
- must include a lightweight test (`example <https://github.com/ray-project/ray/blob/6bb110393008c9800177490688c6ed38b2da52a9/test/jenkins_tests/run_multi_node_tests.sh#L45>`__) to ensure the algorithm runs
|
||||
- should include tuned hyperparameter examples and documentation
|
||||
- should offer functionality not present in existing algorithms
|
||||
|
||||
* Fully integrated algorithms (`rllib/agents <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents>`__) have the following additional requirements:
|
||||
- must fully implement the Agent API
|
||||
- must fully implement the Trainer API
|
||||
- must offer substantial new functionality not possible to add to other algorithms
|
||||
- should support custom models and preprocessors
|
||||
- should use RLlib abstractions and support distributed execution
|
||||
|
@ -46,14 +46,14 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a
|
|||
|
||||
How to add an algorithm to ``contrib``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Agent <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:
|
||||
It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Trainer <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:
|
||||
|
||||
.. literalinclude:: ../../python/ray/rllib/contrib/random_agent/random_agent.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
Second, register the agent with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/contrib/registry.py>`__.
|
||||
Second, register the trainer with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/contrib/registry.py>`__.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -66,12 +66,12 @@ Second, register the agent with a name in `contrib/registry.py <https://github.c
|
|||
return RandomAgent2
|
||||
|
||||
CONTRIBUTED_ALGORITHMS = {
|
||||
"contrib/RandomAgent": _import_random_agent,
|
||||
"contrib/RandomAgent2": _import_random_agent_2,
|
||||
"contrib/RandomAgent": _import_random_trainer,
|
||||
"contrib/RandomAgent2": _import_random_trainer_2,
|
||||
# ...
|
||||
}
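For context, the ``_import_random_trainer`` helpers referenced in this dict are plain import functions defined earlier in ``contrib/registry.py``; they defer importing the algorithm until it is actually requested. A minimal sketch of one such helper (the names are illustrative, not part of this diff):

.. code-block:: python

    def _import_random_trainer():
        # Import lazily so that merely listing the algorithm does not pull
        # in its dependencies.
        from ray.rllib.contrib.random_agent.random_agent import RandomAgent
        return RandomAgent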
|
||||
|
||||
After registration, you can run and visualize agent progress using ``rllib train``:
|
||||
After registration, you can run and visualize training progress using ``rllib train``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ MARWIL **Yes** `+parametric`_ **Yes** **Yes** **Yes**
|
|||
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
|
||||
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__. Custom env classes passed directly to the agent must take a single ``env_config`` parameter in their constructor:
|
||||
You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -43,7 +43,7 @@ You can pass either a string name or a Python class to specify an environment. B
|
|||
return <obs>, <reward: float>, <done: bool>, <info: dict>
|
||||
|
||||
ray.init()
|
||||
trainer = ppo.PPOAgent(env=MyEnv, config={
|
||||
trainer = ppo.PPOTrainer(env=MyEnv, config={
|
||||
"env_config": {}, # config to pass to env class
|
||||
})
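Putting the pieces together, a complete (if trivial) environment and trainer might look like the following sketch; the spaces and reward logic are placeholders chosen only for illustration:

.. code-block:: python

    import gym
    import ray
    from gym.spaces import Discrete
    from ray.rllib.agents import ppo

    class MyEnv(gym.Env):
        def __init__(self, env_config):
            self.action_space = Discrete(2)
            self.observation_space = Discrete(2)

        def reset(self):
            return 0

        def step(self, action):
            # Single-step episode: reward 1 for action 0, else 0.
            return 0, 1.0 if action == 0 else 0.0, True, {}

    ray.init()
    trainer = ppo.PPOTrainer(env=MyEnv, config={"env_config": {}})
    print(trainer.train())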
|
||||
|
||||
|
@ -60,7 +60,7 @@ You can also register a custom env creator function with a string name. This fun
|
|||
return MyEnv(...) # return an env instance
|
||||
|
||||
register_env("my_env", env_creator)
|
||||
trainer = ppo.PPOAgent(env="my_env")
|
||||
trainer = ppo.PPOTrainer(env="my_env")
|
||||
|
||||
For a full runnable code example using the custom environment API, see `custom_env.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__.
|
||||
|
||||
|
@ -71,7 +71,7 @@ For a full runnable code example using the custom environment API, see `custom_e
|
|||
Configuring Environments
|
||||
------------------------
|
||||
|
||||
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your agent. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
|
||||
In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -128,7 +128,7 @@ Multi-Agent and Hierarchical
|
|||
|
||||
.. note::
|
||||
|
||||
Learn more about multi-agent reinforcement learning in RLlib by reading the `blog post <https://bair.berkeley.edu/blog/2018/12/12/rllib/>`__.
|
||||
Learn more about multi-agent reinforcement learning in RLlib by checking out some of the `code examples <rllib-examples.html#multi-agent-and-hierarchical>`__ or reading the `blog post <https://bair.berkeley.edu/blog/2018/12/12/rllib/>`__.
|
||||
|
||||
A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib is as follows: (1) as a user, you define the number of policies available up front, and (2) you provide a function that maps agent ids to policy ids. This is summarized by the figure below; a configuration sketch of the mapping follows.
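As a concrete (if minimal) illustration of that mapping, the sketch below fills in the ``multiagent`` section of the config using the ``policy_graphs`` and ``policy_mapping_fn`` keys; the ``PPOPolicyGraph`` import path, the spaces, and the env name are assumptions made for the example:

.. code-block:: python

    import ray
    from gym.spaces import Box, Discrete
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph

    # Illustrative spaces; a real multi-agent env defines these itself.
    obs_space = Box(-1.0, 1.0, shape=(4,))
    act_space = Discrete(2)

    ray.init()
    trainer = PPOTrainer(env="my_multiagent_env", config={  # env name is a placeholder
        "multiagent": {
            # policy id -> (policy_graph_cls, obs_space, act_space, extra config)
            "policy_graphs": {
                "car": (PPOPolicyGraph, obs_space, act_space, {}),
                "traffic_light": (PPOPolicyGraph, obs_space, act_space, {}),
            },
            # map each agent id in the env to one of the policy ids above
            "policy_mapping_fn": lambda agent_id: (
                "traffic_light" if agent_id.startswith("light_") else "car"),
        },
    })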
|
||||
|
||||
|
|
|
@ -127,7 +127,7 @@ Custom TF models should subclass the common RLlib `model class <https://github.c
|
|||
ModelCatalog.register_custom_model("my_model", MyModelClass)
|
||||
|
||||
ray.init()
|
||||
agent = ppo.PPOAgent(env="CartPole-v0", config={
|
||||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
"custom_options": {}, # extra options to pass to your model
|
||||
|
@ -220,7 +220,7 @@ Similarly, you can create and register custom PyTorch models for use with PyTorc
|
|||
ModelCatalog.register_custom_model("my_model", CustomTorchModel)
|
||||
|
||||
ray.init()
|
||||
agent = a3c.A2CAgent(env="CartPole-v0", config={
|
||||
trainer = a3c.A2CTrainer(env="CartPole-v0", config={
|
||||
"use_pytorch": True,
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
|
@ -249,7 +249,7 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
|
|||
ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
|
||||
|
||||
ray.init()
|
||||
agent = ppo.PPOAgent(env="CartPole-v0", config={
|
||||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_preprocessor": "my_prep",
|
||||
"custom_options": {}, # extra options to pass to your preprocessor
|
||||
|
@ -315,7 +315,7 @@ Depending on your use case it may make sense to use just the masking, just actio
|
|||
Customizing Policy Graphs
|
||||
-------------------------
|
||||
|
||||
For deeper customization of algorithms, you can modify the policy graphs of the agent classes. Here's an example of extending the DDPG policy graph to specify custom sub-network modules:
|
||||
For deeper customization of algorithms, you can modify the policy graphs of the trainer classes. Here's an example of extending the DDPG policy graph to specify custom sub-network modules:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -349,15 +349,15 @@ For deeper customization of algorithms, you can modify the policy graphs of the
|
|||
self.config["critic_hiddens"],
|
||||
self.config["critic_hidden_activation"]).value
|
||||
|
||||
Then, you can create an agent with your custom policy graph by:
|
||||
Then, you can create a trainer with your custom policy graph by:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGAgent
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer
|
||||
from custom_policy_graph import CustomDDPGPolicyGraph
|
||||
|
||||
DDPGAgent._policy_graph = CustomDDPGPolicyGraph
|
||||
agent = DDPGAgent(...)
|
||||
DDPGTrainer._policy_graph = CustomDDPGPolicyGraph
|
||||
trainer = DDPGTrainer(...)
|
||||
|
||||
In this example we overrode existing methods of the DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, and `_build_actor_critic_loss`, but you can also replace the graph class entirely.
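As a rough skeleton of such a subclass (the module path and the builder signature are assumptions here; the real builder methods take DDPG-specific arguments), the pattern is to override a builder and defer everything else to the parent:

.. code-block:: python

    from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph

    class CustomDDPGPolicyGraph(DDPGPolicyGraph):
        """Hypothetical skeleton: override one builder, inherit the rest."""

        def _build_q_network(self, *args, **kwargs):
            # Insert custom layers here; delegating back to the parent keeps
            # the rest of the graph construction unchanged.
            return DDPGPolicyGraph._build_q_network(self, *args, **kwargs)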
|
||||
|
||||
|
|
|
@ -69,13 +69,13 @@ This example plot shows the Q-value metric in addition to importance sampling (I
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
agent = DQNAgent(...)
|
||||
... # train agent offline
|
||||
trainer = DQNTrainer(...)
|
||||
... # train policy offline
|
||||
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
|
||||
estimator = WeightedImportanceSamplingEstimator(agent.get_policy(), gamma=0.99)
|
||||
estimator = WeightedImportanceSamplingEstimator(trainer.get_policy(), gamma=0.99)
|
||||
reader = JsonReader("/path/to/data")
|
||||
for _ in range(1000):
|
||||
batch = reader.next()
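# (editorial sketch) each batch read from disk can then be scored by the
# estimator; assuming it exposes an ``estimate()`` method, e.g.:
#     print(estimator.estimate(batch))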
|
||||
|
@ -155,7 +155,7 @@ Input API
|
|||
|
||||
You can configure experience input for an agent using the following options:
|
||||
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/agent.py
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/trainer.py
|
||||
:language: python
|
||||
:start-after: === Offline Datasets ===
|
||||
:end-before: Specify where experiences should be saved
|
||||
|
@ -170,7 +170,7 @@ Output API
|
|||
|
||||
You can configure experience output for an agent using the following options:
|
||||
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/agent.py
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/trainer.py
|
||||
:language: python
|
||||
:start-after: shuffle_buffer_size
|
||||
:end-before: === Multiagent ===
|
||||
|
|
|
@ -7,16 +7,19 @@ ray.rllib.agents
|
|||
.. automodule:: ray.rllib.agents
|
||||
:members:
|
||||
|
||||
.. autoclass:: ray.rllib.agents.a3c.A2CAgent
|
||||
.. autoclass:: ray.rllib.agents.a3c.A3CAgent
|
||||
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGAgent
|
||||
.. autoclass:: ray.rllib.agents.ddpg.DDPGAgent
|
||||
.. autoclass:: ray.rllib.agents.dqn.ApexAgent
|
||||
.. autoclass:: ray.rllib.agents.dqn.DQNAgent
|
||||
.. autoclass:: ray.rllib.agents.es.ESAgent
|
||||
.. autoclass:: ray.rllib.agents.pg.PGAgent
|
||||
.. autoclass:: ray.rllib.agents.impala.ImpalaAgent
|
||||
.. autoclass:: ray.rllib.agents.ppo.PPOAgent
|
||||
.. autoclass:: ray.rllib.agents.a3c.A2CTrainer
|
||||
.. autoclass:: ray.rllib.agents.a3c.A3CTrainer
|
||||
.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGTrainer
|
||||
.. autoclass:: ray.rllib.agents.ddpg.DDPGTrainer
|
||||
.. autoclass:: ray.rllib.agents.dqn.ApexTrainer
|
||||
.. autoclass:: ray.rllib.agents.dqn.DQNTrainer
|
||||
.. autoclass:: ray.rllib.agents.es.ESTrainer
|
||||
.. autoclass:: ray.rllib.agents.pg.PGTrainer
|
||||
.. autoclass:: ray.rllib.agents.impala.ImpalaTrainer
|
||||
.. autoclass:: ray.rllib.agents.ppo.APPOTrainer
|
||||
.. autoclass:: ray.rllib.agents.ppo.PPOTrainer
|
||||
.. autoclass:: ray.rllib.agents.marwil.MARWILTrainer
|
||||
|
||||
|
||||
ray.rllib.env
|
||||
-------------
|
||||
|
|
|
@ -4,13 +4,13 @@ RLlib Training APIs
|
|||
Getting Started
|
||||
---------------
|
||||
|
||||
At a high level, RLlib provides an ``Agent`` class which
|
||||
holds a policy for environment interaction. Through the agent interface, the policy can
|
||||
be trained, checkpointed, or an action computed.
|
||||
At a high level, RLlib provides a ``Trainer`` class which
|
||||
holds a policy for environment interaction. Through the trainer interface, the policy can
|
||||
be trained, checkpointed, or used to compute actions. In multi-agent training, the trainer manages the querying and optimization of multiple policies at once.
|
||||
|
||||
.. image:: rllib-api.svg
|
||||
|
||||
You can train a simple DQN agent with the following command:
|
||||
You can train a simple DQN trainer with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
@ -39,15 +39,15 @@ with ``--env`` (any OpenAI gym environment including ones registered by the user
|
|||
can be used) and for choosing the algorithm with ``--run``
|
||||
(available options are ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
|
||||
|
||||
Evaluating Trained Agents
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Evaluating Trained Policies
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In order to save checkpoints from which to evaluate agents,
|
||||
In order to save checkpoints from which to evaluate policies,
|
||||
set ``--checkpoint-freq`` (number of training iterations between checkpoints)
|
||||
when running ``rllib train``.
|
||||
|
||||
|
||||
An example of evaluating a previously trained DQN agent is as follows:
|
||||
An example of evaluating a previously trained DQN policy is as follows:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
@ -55,7 +55,7 @@ An example of evaluating a previously trained DQN agent is as follows:
|
|||
~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
|
||||
--run DQN --env CartPole-v0 --steps 10000
|
||||
|
||||
The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint
|
||||
The ``rollout.py`` helper script reconstructs a DQN policy from the checkpoint
|
||||
located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1``
|
||||
and renders its behavior in the environment specified by ``--env``.
|
||||
|
||||
|
@ -65,7 +65,7 @@ Configuration
|
|||
Specifying Parameters
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/agent.py>`__. See the
|
||||
Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/trainer.py>`__. See the
|
||||
`algorithms documentation <rllib-algorithms.html>`__ for more information.
|
||||
|
||||
In an example below, we train A2C by specifying 8 workers through the config flag.
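A command of roughly the following shape accomplishes that (a sketch; adjust the environment and the JSON quoting to your shell):

.. code-block:: bash

    rllib train --env=CartPole-v0 --run=A2C --config '{"num_workers": 8}'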
|
||||
|
@ -77,16 +77,16 @@ In an example below, we train A2C by specifying 8 workers through the config fla
|
|||
Specifying Resources
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``.
|
||||
You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most algorithms. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting ``num_gpus: 0.2``.
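For reference, these keys all live in the trainer config; a sketch with illustrative values:

.. code-block:: python

    config = {
        "num_workers": 4,                  # degree of sampling parallelism
        "num_gpus": 0.2,                   # fractional GPU for the driver
        "num_cpus_per_worker": 1,
        "num_gpus_per_worker": 0,
        "custom_resources_per_worker": {},
    }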
|
||||
|
||||
.. image:: rllib-config.svg
|
||||
|
||||
Common Parameters
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
The following is a list of the common agent hyperparameters:
|
||||
The following is a list of the common algorithm hyperparameters:
|
||||
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/agent.py
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/trainer.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
@ -122,25 +122,25 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
|||
config = ppo.DEFAULT_CONFIG.copy()
|
||||
config["num_gpus"] = 0
|
||||
config["num_workers"] = 1
|
||||
agent = ppo.PPOAgent(config=config, env="CartPole-v0")
|
||||
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
|
||||
|
||||
# Can optionally call agent.restore(path) to load a checkpoint.
|
||||
# Can optionally call trainer.restore(path) to load a checkpoint.
|
||||
|
||||
for i in range(1000):
|
||||
# Perform one iteration of training the policy with PPO
|
||||
result = agent.train()
|
||||
result = trainer.train()
|
||||
print(pretty_print(result))
|
||||
|
||||
if i % 100 == 0:
|
||||
checkpoint = agent.save()
|
||||
checkpoint = trainer.save()
|
||||
print("checkpoint saved at", checkpoint)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
It's recommended that you run RLlib agents with `Tune <tune.html>`__, for easy experiment management and visualization of results. Just set ``"run": AGENT_NAME, "env": ENV_NAME`` in the experiment config.
|
||||
It's recommended that you run RLlib trainers with `Tune <tune.html>`__, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config.
|
||||
|
||||
All RLlib agents are compatible with the `Tune API <tune-usage.html>`__. This enables them to be easily used in experiments with `Tune <tune.html>`__. For example, the following code performs a simple hyperparam sweep of PPO:
|
||||
All RLlib trainers are compatible with the `Tune API <tune-usage.html>`__. This enables them to be easily used in experiments with `Tune <tune.html>`__. For example, the following code performs a simple hyperparam sweep of PPO:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -176,27 +176,27 @@ Tune will schedule the trials to run in parallel on your Ray cluster:
|
|||
Custom Training Workflows
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your agent once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__.
|
||||
In the `basic training example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py>`__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions <tune-usage.html#training-api>`__ that can be used to implement `custom training workflows (example) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_train_fn.py>`__.
|
||||
|
||||
Accessing Policy State
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list.
|
||||
It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib, trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda that is applied with the evaluator as its argument, and any values returned from the lambda are collected into a list.
|
||||
|
||||
You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default_policy"].get_weights()``:
|
||||
You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.local_evaluator.policy_map["default_policy"].get_weights()``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Get weights of the default local policy
|
||||
agent.get_policy().get_weights()
|
||||
trainer.get_policy().get_weights()
|
||||
|
||||
# Same as above
|
||||
agent.local_evaluator.policy_map["default_policy"].get_weights()
|
||||
trainer.local_evaluator.policy_map["default_policy"].get_weights()
|
||||
|
||||
# Get list of weights of each evaluator, including remote replicas
|
||||
agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights())
|
||||
trainer.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights())
|
||||
|
||||
# Same as above
|
||||
agent.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights())
|
||||
trainer.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights())
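The same pattern works in the other direction, e.g. to push the master copy of the weights out to all replicas (this assumes the policy graph exposes ``set_weights()``, the counterpart of ``get_weights()``):

.. code-block:: python

    # Broadcast the "master" weights to every evaluator, including remote replicas.
    weights = trainer.get_policy().get_weights()
    trainer.optimizer.foreach_evaluator(
        lambda ev: ev.get_policy().set_weights(weights))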
|
||||
|
||||
Global Coordination
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -252,8 +252,8 @@ You can provide callback functions to be called at points during policy evaluati
|
|||
episode.custom_metrics["pole_angle"] = pole_angle
|
||||
|
||||
def on_train_result(info):
|
||||
print("agent.train() result: {} -> {} episodes".format(
|
||||
info["agent"].__name__, info["result"]["episodes_this_iter"]))
|
||||
print("trainer.train() result: {} -> {} episodes".format(
|
||||
info["trainer"].__name__, info["result"]["episodes_this_iter"]))
|
||||
|
||||
ray.init()
|
||||
trials = tune.run(
|
||||
|
@ -278,18 +278,18 @@ Example: Curriculum Learning
|
|||
|
||||
Let's look at two ways to use the above APIs to implement `curriculum learning <https://bair.berkeley.edu/blog/2017/12/20/reverse-curriculum/>`__. In curriculum learning, the agent task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time:
|
||||
|
||||
Approach 1: Use the Agent API and update the environment between calls to ``train()``. This example shows the agent being run inside a Tune function:
|
||||
Approach 1: Use the Trainer API and update the environment between calls to ``train()``. This example shows the trainer being run inside a Tune function:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.ppo import PPOAgent
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
|
||||
def train(config, reporter):
|
||||
agent = PPOAgent(config=config, env=YourEnv)
|
||||
trainer = PPOTrainer(config=config, env=YourEnv)
|
||||
while True:
|
||||
result = agent.train()
|
||||
result = trainer.train()
|
||||
reporter(**result)
|
||||
if result["episode_reward_mean"] > 200:
|
||||
phase = 2
|
||||
|
@ -297,7 +297,7 @@ Approach 1: Use the Agent API and update the environment between calls to ``trai
|
|||
phase = 1
|
||||
else:
|
||||
phase = 0
|
||||
agent.optimizer.foreach_evaluator(
|
||||
trainer.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.foreach_env(
|
||||
lambda env: env.set_phase(phase)))
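This trainable function can then be handed to Tune like any other; a minimal sketch (the config values are placeholders):

.. code-block:: python

    ray.init()
    tune.run(
        train,                      # the trainable function defined above
        config={"num_workers": 2},  # placeholder config
    )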
|
||||
|
||||
|
@ -330,8 +330,8 @@ Approach 2: Use the callbacks API to update the environment on new training resu
|
|||
phase = 1
|
||||
else:
|
||||
phase = 0
|
||||
agent = info["agent"]
|
||||
agent.optimizer.foreach_evaluator(
|
||||
trainer = info["trainer"]
|
||||
trainer.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.foreach_env(
|
||||
lambda env: env.set_phase(phase)))
|
||||
|
||||
|
@ -382,7 +382,7 @@ You can use the `data output API <rllib-offline.html>`__ to save episode traces
|
|||
Log Verbosity
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example:
|
||||
You can control the trainer log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
|
|
@ -42,6 +42,16 @@ Environments
|
|||
* `Interfacing with External Agents <rllib-env.html#interfacing-with-external-agents>`__
|
||||
* `Advanced Integrations <rllib-env.html#advanced-integrations>`__
|
||||
|
||||
Models and Preprocessors
|
||||
------------------------
|
||||
* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
|
||||
* `Custom Models (TensorFlow) <rllib-models.html#custom-models-tensorflow>`__
|
||||
* `Custom Models (PyTorch) <rllib-models.html#custom-models-pytorch>`__
|
||||
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
|
||||
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
|
||||
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
|
||||
* `Customizing Policy Graphs <rllib-models.html#customizing-policy-graphs>`__
|
||||
|
||||
Algorithms
|
||||
----------
|
||||
|
||||
|
@ -79,16 +89,6 @@ Algorithms
|
|||
|
||||
- `Advantage Re-Weighted Imitation Learning (MARWIL) <rllib-algorithms.html#advantage-re-weighted-imitation-learning-marwil>`__
|
||||
|
||||
Models and Preprocessors
|
||||
------------------------
|
||||
* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
|
||||
* `Custom Models (TensorFlow) <rllib-models.html#custom-models-tensorflow>`__
|
||||
* `Custom Models (PyTorch) <rllib-models.html#custom-models-pytorch>`__
|
||||
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
|
||||
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
|
||||
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
|
||||
* `Customizing Policy Graphs <rllib-models.html#customizing-policy-graphs>`__
|
||||
|
||||
Offline Datasets
|
||||
----------------
|
||||
* `Working with Offline Datasets <rllib-offline.html>`__
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.agent import Agent
|
||||
|
||||
__all__ = ["Agent", "with_common_config"]
|
||||
__all__ = ["Agent", "Trainer", "with_common_config"]
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.a3c.a2c import A2CAgent
|
||||
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.a3c.a2c import A2CTrainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["A2CAgent", "A3CAgent", "DEFAULT_CONFIG"]
|
||||
A2CAgent = renamed_class(A2CTrainer)
|
||||
A3CAgent = renamed_class(A3CTrainer)
|
||||
|
||||
__all__ = [
|
||||
"A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG"
|
||||
]
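# Editorial note (hypothetical sketch, not part of this file): ``renamed_class``
# from ``ray.rllib.utils`` can be thought of as a thin deprecation shim that
# returns a subclass of the new Trainer class and warns when the old Agent
# alias is instantiated, roughly:
#
#     import warnings
#
#     def renamed_class(cls):
#         class DeprecationWrapper(cls):
#             def __init__(self, *args, **kwargs):
#                 warnings.warn(
#                     "This name is deprecated; use {} instead".format(
#                         cls.__name__), DeprecationWarning)
#                 cls.__init__(self, *args, **kwargs)
#         DeprecationWrapper.__name__ = cls.__name__
#         return DeprecationWrapper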
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG as A3C_CONFIG
|
||||
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
@ -17,13 +17,13 @@ A2C_DEFAULT_CONFIG = merge_dicts(
|
|||
)
|
||||
|
||||
|
||||
class A2CAgent(A3CAgent):
|
||||
"""Synchronous variant of the A3CAgent."""
|
||||
class A2CTrainer(A3CTrainer):
|
||||
"""Synchronous variant of the A3CTrainer."""
|
||||
|
||||
_agent_name = "A2C"
|
||||
_name = "A2C"
|
||||
_default_config = A2C_DEFAULT_CONFIG
|
||||
|
||||
@override(A3CAgent)
|
||||
@override(A3CTrainer)
|
||||
def _make_optimizer(self):
|
||||
return SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import print_function
|
|||
import time
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
@ -38,14 +38,14 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class A3CAgent(Agent):
|
||||
class A3CTrainer(Trainer):
|
||||
"""A3C implementations in TensorFlow and PyTorch."""
|
||||
|
||||
_agent_name = "A3C"
|
||||
_name = "A3C"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = A3CPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
|
||||
|
@ -63,7 +63,7 @@ class A3CAgent(Agent):
|
|||
env_creator, policy_cls, config["num_workers"])
|
||||
self.optimizer = self._make_optimizer()
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
start = time.time()
|
||||
|
|
|
@ -2,797 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import six
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
from types import FunctionType
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayError
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources, ExportFormat
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max number of times to retry a worker failure. We shouldn't try too many
|
||||
# times in a row since that would indicate a persistent cluster issue.
|
||||
MAX_WORKER_FAILURE_RETRIES = 3
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
COMMON_CONFIG = {
|
||||
# === Debugging ===
|
||||
# Whether to write episode stats and videos to the agent log dir
|
||||
"monitor": False,
|
||||
# Set the ray.rllib.* log level for the agent process and its evaluators.
|
||||
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
|
||||
# periodically print out summaries of relevant internal dataflow (this is
|
||||
# also printed out once at startup at the INFO level).
|
||||
"log_level": "INFO",
|
||||
# Callbacks that will be run during various phases of training. These all
|
||||
# take a single "info" dict as an argument. For episode callbacks, custom
|
||||
# metrics can be attached to the episode by updating the episode object's
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
|
||||
"callbacks": {
|
||||
"on_episode_start": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_train_result": None, # arg: {"agent": ..., "result": ...}
|
||||
},
|
||||
# Whether to attempt to continue training if a worker crashes.
|
||||
"ignore_worker_failures": False,
|
||||
|
||||
# === Policy ===
|
||||
# Arguments to pass to model. See models/catalog.py for a full list of the
|
||||
# available model options.
|
||||
"model": MODEL_DEFAULTS,
|
||||
# Arguments to pass to the policy optimizer. These vary by optimizer.
|
||||
"optimizer": {},
|
||||
|
||||
# === Environment ===
|
||||
# Discount factor of the MDP
|
||||
"gamma": 0.99,
|
||||
# Number of steps after which the episode is forced to terminate. Defaults
|
||||
# to `env.spec.max_episode_steps` (if present) for Gym envs.
|
||||
"horizon": None,
|
||||
# Calculate rewards but don't reset the environment when the horizon is
|
||||
# hit. This allows value estimation and RNN state to span across logical
|
||||
# episodes denoted by horizon. This only has an effect if horizon != inf.
|
||||
"soft_horizon": False,
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
# Environment name can also be passed via config
|
||||
"env": None,
|
||||
# Whether to clip rewards prior to experience postprocessing. Setting to
|
||||
# None means clip for Atari only.
|
||||
"clip_rewards": None,
|
||||
# Whether to np.clip() actions to the action space low/high range spec.
|
||||
"clip_actions": True,
|
||||
# Whether to use rllib or deepmind preprocessors by default
|
||||
"preprocessor_pref": "deepmind",
|
||||
|
||||
# === Resources ===
|
||||
# Number of actors used for parallelism
|
||||
"num_workers": 2,
|
||||
# Number of GPUs to allocate to the driver. Note that not all algorithms
|
||||
# can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs).
|
||||
"num_gpus": 0,
|
||||
# Number of CPUs to allocate per worker.
|
||||
"num_cpus_per_worker": 1,
|
||||
# Number of GPUs to allocate per worker. This can be fractional.
|
||||
"num_gpus_per_worker": 0,
|
||||
# Any custom resources to allocate per worker.
|
||||
"custom_resources_per_worker": {},
|
||||
# Number of CPUs to allocate for the driver. Note: this only takes effect
|
||||
# when running in Tune.
|
||||
"num_cpus_for_driver": 1,
|
||||
|
||||
# === Execution ===
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs_per_worker": 1,
|
||||
# Default sample batch size (unroll length). Batches of this size are
|
||||
# collected from workers until train_batch_size is met. When using
|
||||
# multiple envs per worker, this is multiplied by num_envs_per_worker.
|
||||
"sample_batch_size": 200,
|
||||
# Training batch size, if applicable. Should be >= sample_batch_size.
|
||||
# Samples batches will be concatenated together to this size for training.
|
||||
"train_batch_size": 200,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "truncate_episodes",
|
||||
# (Deprecated) Use a background thread for sampling (slightly off-policy)
|
||||
"sample_async": False,
|
||||
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
|
||||
"observation_filter": "NoFilter",
|
||||
# Whether to synchronize the statistics of remote filters.
|
||||
"synchronize_filters": True,
|
||||
# Configure TF for single-process operation by default
|
||||
"tf_session_args": {
|
||||
# note: overridden by `local_evaluator_tf_session_args`
|
||||
"intra_op_parallelism_threads": 2,
|
||||
"inter_op_parallelism_threads": 2,
|
||||
"gpu_options": {
|
||||
"allow_growth": True,
|
||||
},
|
||||
"log_device_placement": False,
|
||||
"device_count": {
|
||||
"CPU": 1
|
||||
},
|
||||
"allow_soft_placement": True, # required by PPO multi-gpu
|
||||
},
|
||||
# Override the following tf session args on the local evaluator
|
||||
"local_evaluator_tf_session_args": {
|
||||
# Allow a higher level of parallelism by default, but not unlimited
|
||||
# since that can cause crashes with many concurrent drivers.
|
||||
"intra_op_parallelism_threads": 8,
|
||||
"inter_op_parallelism_threads": 8,
|
||||
},
|
||||
# Whether to LZ4 compress individual observations
|
||||
"compress_observations": False,
|
||||
# Drop metric batches from unresponsive workers after this many seconds
|
||||
"collect_metrics_timeout": 180,
|
||||
# Smooth metrics over this many episodes.
|
||||
"metrics_smoothing_episodes": 100,
|
||||
# If using num_envs_per_worker > 1, whether to create those new envs in
|
||||
# remote processes instead of in the same worker. This adds overheads, but
|
||||
# can make sense if your envs can take much time to step / reset
|
||||
# (e.g., for StarCraft). Use this cautiously; overheads are significant.
|
||||
"remote_worker_envs": False,
|
||||
# Timeout that remote workers are waiting when polling environments.
|
||||
# 0 (continue when at least one env is ready) is a reasonable default,
|
||||
# but optimal value could be obtained by measuring your environment
|
||||
# step / reset and model inference perf.
|
||||
"remote_env_batch_wait_ms": 0,
|
||||
|
||||
# === Offline Datasets ===
|
||||
# Specify how to generate experiences:
|
||||
# - "sampler": generate experiences via online simulation (default)
|
||||
# - a local directory or file glob expression (e.g., "/tmp/*.json")
|
||||
# - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
|
||||
# "s3://bucket/2.json"])
|
||||
# - a dict with string keys and sampling probabilities as values (e.g.,
|
||||
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
|
||||
# - a function that returns a rllib.offline.InputReader
|
||||
"input": "sampler",
|
||||
# Specify how to evaluate the current policy. This only has an effect when
|
||||
# reading offline experiences. Available options:
|
||||
# - "wis": the weighted step-wise importance sampling estimator.
|
||||
# - "is": the step-wise importance sampling estimator.
|
||||
# - "simulation": run the environment in the background, but use
|
||||
# this data for evaluation only and not for learning.
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
# policy, not the *behaviour* policy, which is typically undesirable for
|
||||
# on-policy algorithms.
|
||||
"postprocess_inputs": False,
|
||||
# If positive, input batches will be shuffled via a sliding window buffer
|
||||
# of this number of batches. Use this if the input data is not in random
|
||||
# enough order. Input is delayed until the shuffle buffer is filled.
|
||||
"shuffle_buffer_size": 0,
|
||||
# Specify where experiences should be saved:
|
||||
# - None: don't save any experiences
|
||||
# - "logdir" to save to the agent log dir
|
||||
# - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
|
||||
# - a function that returns a rllib.offline.OutputWriter
|
||||
"output": None,
|
||||
# What sample batch columns to LZ4 compress in the output data.
|
||||
"output_compress_columns": ["obs", "new_obs"],
|
||||
# Max output file size before rolling over to a new file.
|
||||
"output_max_file_size": 64 * 1024 * 1024,
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
|
||||
# act_space, config). See policy_evaluator.py for more info.
|
||||
"policy_graphs": {},
|
||||
# Function mapping agent ids to policy ids.
|
||||
"policy_mapping_fn": None,
|
||||
# Optional whitelist of policies to train, or None for all policies.
|
||||
"policies_to_train": None,
|
||||
},
|
||||
}
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def with_common_config(extra_config):
|
||||
"""Returns the given config dict merged with common agent confs."""
|
||||
|
||||
return with_base_config(COMMON_CONFIG, extra_config)
|
||||
|
||||
|
||||
def with_base_config(base_config, extra_config):
|
||||
"""Returns the given config dict merged with a base agent conf."""
|
||||
|
||||
config = copy.deepcopy(base_config)
|
||||
config.update(extra_config)
|
||||
return config
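# Example (editorial sketch): merging algorithm-specific keys into a copy of
# COMMON_CONFIG; values supplied here override common entries of the same name.
#
#     MY_DEFAULT_CONFIG = with_common_config({
#         "lr": 5e-4,          # hypothetical algorithm-specific key
#         "num_workers": 4,    # overrides the common default of 2
#     })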
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class Agent(Trainable):
|
||||
"""All RLlib agents extend this base class.
|
||||
|
||||
Agent objects retain internal model state between calls to train(), so
|
||||
you should create a new agent instance for each training session.
|
||||
|
||||
Attributes:
|
||||
env_creator (func): Function that creates a new training env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
"""
|
||||
|
||||
_allow_unknown_configs = False
|
||||
_allow_unknown_subkeys = [
|
||||
"tf_session_args", "env_config", "model", "optimizer", "multiagent",
|
||||
"custom_resources_per_worker"
|
||||
]
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
config (dict): Algorithm-specific configuration data.
|
||||
env (str): Name of the environment to use. Note that this can also
|
||||
be specified as the `env` key in config.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
|
||||
config = config or {}
|
||||
|
||||
# Vars to synchronize to evaluators on each train call
|
||||
self.global_vars = {"timestep": 0}
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = self._register_if_needed(env or config.get("env"))
|
||||
|
||||
# Create a default logger creator if no logger_creator is specified
|
||||
if logger_creator is None:
|
||||
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id,
|
||||
timestr)
|
||||
|
||||
def default_logger_creator(config):
|
||||
"""Creates a Unified logger with a default logdir prefix
|
||||
containing the agent name and the env id
|
||||
"""
|
||||
if not os.path.exists(DEFAULT_RESULTS_DIR):
|
||||
os.makedirs(DEFAULT_RESULTS_DIR)
|
||||
logdir = tempfile.mkdtemp(
|
||||
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
|
||||
return UnifiedLogger(config, logdir, None)
|
||||
|
||||
logger_creator = default_logger_creator
|
||||
|
||||
Trainable.__init__(self, config, logger_creator)
|
||||
|
||||
@classmethod
|
||||
@override(Trainable)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls._default_config, **config)
|
||||
Agent._validate_config(cf)
|
||||
# TODO(ekl): add custom resources here once tune supports them
|
||||
return Resources(
|
||||
cpu=cf["num_cpus_for_driver"],
|
||||
gpu=cf["num_gpus"],
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@override(Trainable)
|
||||
@PublicAPI
|
||||
def train(self):
|
||||
"""Overrides super.train to synchronize global vars."""
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
|
||||
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
ev.set_global_vars.remote(self.global_vars)
|
||||
logger.debug("updated global vars: {}".format(self.global_vars))
|
||||
|
||||
result = None
|
||||
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
|
||||
try:
|
||||
result = Trainable.train(self)
|
||||
except RayError as e:
|
||||
if self.config["ignore_worker_failures"]:
|
||||
logger.exception(
|
||||
"Error in train call, attempting to recover")
|
||||
self._try_recover()
|
||||
else:
|
||||
logger.info(
|
||||
"Worker crashed during call to train(). To attempt to "
|
||||
"continue training without the failed worker, set "
|
||||
"`'ignore_worker_failures': True`.")
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
if result is None:
|
||||
raise RuntimeError("Failed to recover from worker crash")
|
||||
|
||||
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
|
||||
and hasattr(self, "local_evaluator")):
|
||||
FilterManager.synchronize(
|
||||
self.local_evaluator.filters,
|
||||
self.remote_evaluators,
|
||||
update_remote=self.config["synchronize_filters"])
|
||||
logger.debug("synchronized filters: {}".format(
|
||||
self.local_evaluator.filters))
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
result["num_healthy_workers"] = len(
|
||||
self.optimizer.remote_evaluators)
|
||||
return result
|
||||
|
||||
@override(Trainable)
|
||||
def _log_result(self, result):
|
||||
if self.config["callbacks"].get("on_train_result"):
|
||||
self.config["callbacks"]["on_train_result"]({
|
||||
"agent": self,
|
||||
"result": result,
|
||||
})
|
||||
# log after the callback is invoked, so that the user has a chance
|
||||
# to mutate the result
|
||||
Trainable._log_result(self, result)
|
||||
|
||||
@override(Trainable)
|
||||
def _setup(self, config):
|
||||
env = self._env_id
|
||||
if env:
|
||||
config["env"] = env
|
||||
if _global_registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = _global_registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda env_config: gym.make(env)
|
||||
else:
|
||||
self.env_creator = lambda env_config: None
|
||||
|
||||
# Merge the supplied config with the class default
|
||||
merged_config = copy.deepcopy(self._default_config)
|
||||
merged_config = deep_update(merged_config, config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
self.raw_user_config = config
|
||||
self.config = merged_config
|
||||
Agent._validate_config(self.config)
|
||||
if self.config.get("log_level"):
|
||||
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
|
||||
|
||||
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
|
||||
with tf.Graph().as_default():
|
||||
self._init(self.config, self.env_creator)
|
||||
|
||||
@override(Trainable)
|
||||
def _stop(self):
|
||||
# Call stop on all evaluators to release resources
|
||||
if hasattr(self, "local_evaluator"):
|
||||
self.local_evaluator.stop()
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.stop.remote()
|
||||
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
if hasattr(self, "optimizer"):
|
||||
self.optimizer.stop()
|
||||
|
||||
@override(Trainable)
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
@override(Trainable)
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self, config, env_creator):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def compute_action(self,
|
||||
observation,
|
||||
state=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
policy_id=DEFAULT_POLICY_ID):
|
||||
"""Computes an action for the specified policy.
|
||||
|
||||
Note that you can also access the policy object through
|
||||
self.get_policy(policy_id) and call compute_actions() on it directly.
|
||||
|
||||
Arguments:
|
||||
observation (obj): observation from the environment.
|
||||
state (list): RNN hidden state, if any. If state is not None,
|
||||
then all of compute_single_action(...) is returned
|
||||
(computed action, rnn state, logits dictionary).
|
||||
Otherwise compute_single_action(...)[0] is
|
||||
returned (computed action).
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (int): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
policy_id (str): policy to query (only applies to multi-agent).
|
||||
"""
|
||||
|
||||
if state is None:
|
||||
state = []
|
||||
preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
|
||||
observation)
|
||||
filtered_obs = self.local_evaluator.filters[policy_id](
|
||||
preprocessed, update=False)
|
||||
if state:
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs,
|
||||
state,
|
||||
prev_action,
|
||||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"])
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs,
|
||||
state,
|
||||
prev_action,
|
||||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"])[0]
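# Example usage (editorial sketch): given a concrete trainer built for a gym
# env, e.g. a hypothetical ``trainer = DQNTrainer(env="CartPole-v0")``:
#
#     import gym
#     obs = gym.make("CartPole-v0").reset()
#     action = trainer.compute_action(obs)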
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
"""Current training iter, auto-incremented with each train() call."""
|
||||
|
||||
return self._iteration
|
||||
|
||||
@property
|
||||
def _agent_name(self):
|
||||
"""Subclasses should override this to declare their name."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _default_config(self):
|
||||
"""Subclasses should override this to declare their default config."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
Arguments:
|
||||
policy_id (str): id of policy graph to return.
|
||||
"""
|
||||
|
||||
return self.local_evaluator.get_policy(policy_id)
|
||||
|
||||
@PublicAPI
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
Arguments:
|
||||
policies (list): Optional list of policies to return weights for,
|
||||
or None for all policies.
|
||||
"""
|
||||
return self.local_evaluator.get_weights(policies)
|
||||
|
||||
@PublicAPI
|
||||
def set_weights(self, weights):
|
||||
"""Set policy weights by policy id.
|
||||
|
||||
Arguments:
|
||||
weights (dict): Map of policy ids to weights to set.
|
||||
"""
|
||||
self.local_evaluator.set_weights(weights)
|
||||
|
||||
@DeveloperAPI
|
||||
def make_local_evaluator(self,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
extra_config=None):
|
||||
"""Convenience method to return configured local evaluator."""
|
||||
|
||||
return self._make_evaluator(
|
||||
PolicyEvaluator,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
0,
|
||||
merge_dicts(
|
||||
# important: allow local tf to use more CPUs for optimization
|
||||
merge_dicts(
|
||||
self.config, {
|
||||
"tf_session_args": self.
|
||||
config["local_evaluator_tf_session_args"]
|
||||
}),
|
||||
extra_config or {}))
|
||||
|
||||
@DeveloperAPI
|
||||
def make_remote_evaluators(self, env_creator, policy_graph, count):
|
||||
"""Convenience method to return a number of remote evaluators."""
|
||||
|
||||
remote_args = {
|
||||
"num_cpus": self.config["num_cpus_per_worker"],
|
||||
"num_gpus": self.config["num_gpus_per_worker"],
|
||||
"resources": self.config["custom_resources_per_worker"],
|
||||
}
|
||||
|
||||
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
||||
|
||||
return [
|
||||
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
|
||||
self.config) for i in range(count)
|
||||
]
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Export policy model with given policy_id to local directory.
|
||||
|
||||
Arguments:
|
||||
export_dir (string): Writable local directory.
|
||||
policy_id (string): Optional policy id to export.
|
||||
|
||||
Example:
|
||||
>>> agent = MyAgent()
|
||||
>>> for _ in range(10):
|
||||
>>> agent.train()
|
||||
>>> agent.export_policy_model("/tmp/export_dir")
|
||||
"""
|
||||
self.local_evaluator.export_policy_model(export_dir, policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
|
||||
export_dir,
|
||||
filename_prefix="model",
|
||||
policy_id=DEFAULT_POLICY_ID):
|
||||
"""Export tensorflow policy model checkpoint to local directory.
|
||||
|
||||
Arguments:
|
||||
export_dir (string): Writable local directory.
|
||||
filename_prefix (string): file name prefix of checkpoint files.
|
||||
policy_id (string): Optional policy id to export.
|
||||
|
||||
Example:
|
||||
>>> agent = MyAgent()
|
||||
>>> for _ in range(10):
|
||||
>>> agent.train()
|
||||
>>> agent.export_policy_checkpoint("/tmp/export_dir")
|
||||
"""
|
||||
self.local_evaluator.export_policy_checkpoint(
|
||||
export_dir, filename_prefix, policy_id)
|
||||
|
||||
    @DeveloperAPI
    def collect_metrics(self, selected_evaluators=None):
        """Collects metrics from the remote evaluators of this agent.

        This is the same data as returned by a call to train().
        """
        return self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"],
            min_history=self.config["metrics_smoothing_episodes"],
            selected_evaluators=selected_evaluators)

    @classmethod
    def resource_help(cls, config):
        return ("\n\nYou can adjust the resource requests of RLlib agents by "
                "setting `num_workers`, `num_gpus`, and other configs. See "
                "the DEFAULT_CONFIG defined by each agent for more info.\n\n"
                "The config of this agent is: {}".format(config))

    @staticmethod
    def _validate_config(config):
        if "gpu" in config:
            raise ValueError(
                "The `gpu` config is deprecated, please use `num_gpus=0|1` "
                "instead.")
        if "gpu_fraction" in config:
            raise ValueError(
                "The `gpu_fraction` config is deprecated, please use "
                "`num_gpus=<fraction>` instead.")
        if "use_gpu_for_workers" in config:
            raise ValueError(
                "The `use_gpu_for_workers` config is deprecated, please use "
                "`num_gpus_per_worker=1` instead.")
        if type(config["input_evaluation"]) != list:
            raise ValueError(
                "`input_evaluation` must be a list of strings, got {}".format(
                    config["input_evaluation"]))
|
||||
|
||||
def _try_recover(self):
|
||||
"""Try to identify and blacklist any unhealthy workers.
|
||||
|
||||
This method is called after an unexpected remote error is encountered
|
||||
from a worker. It issues check requests to all current workers and
|
||||
blacklists any that respond with error. If no healthy workers remain,
|
||||
an error is raised.
|
||||
"""
|
||||
|
||||
if not self._has_policy_optimizer():
|
||||
raise NotImplementedError(
|
||||
"Recovery is not supported for this algorithm")
|
||||
|
||||
logger.info("Health checking all workers...")
|
||||
checks = []
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
_, obj_id = ev.sample_with_count.remote()
|
||||
checks.append(obj_id)
|
||||
|
||||
healthy_evaluators = []
|
||||
for i, obj_id in enumerate(checks):
|
||||
ev = self.optimizer.remote_evaluators[i]
|
||||
try:
|
||||
ray.get(obj_id)
|
||||
healthy_evaluators.append(ev)
|
||||
logger.info("Worker {} looks healthy".format(i + 1))
|
||||
except RayError:
|
||||
logger.exception("Blacklisting worker {}".format(i + 1))
|
||||
try:
|
||||
ev.__ray_terminate__.remote()
|
||||
except Exception:
|
||||
logger.exception("Error terminating unhealthy worker")
|
||||
|
||||
if len(healthy_evaluators) < 1:
|
||||
raise RuntimeError(
|
||||
"Not enough healthy workers remain to continue.")
|
||||
|
||||
self.optimizer.reset(healthy_evaluators)
|
||||
|
||||
def _has_policy_optimizer(self):
|
||||
return hasattr(self, "optimizer") and isinstance(
|
||||
self.optimizer, PolicyOptimizer)
|
||||
|
||||
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
|
||||
config):
|
||||
def session_creator():
|
||||
logger.debug("Creating TF session {}".format(
|
||||
config["tf_session_args"]))
|
||||
return tf.Session(
|
||||
config=tf.ConfigProto(**config["tf_session_args"]))
|
||||
|
||||
if isinstance(config["input"], FunctionType):
|
||||
input_creator = config["input"]
|
||||
elif config["input"] == "sampler":
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
MixedInput(config["input"], ioctx), config[
|
||||
"shuffle_buffer_size"]))
|
||||
else:
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
JsonReader(config["input"], ioctx), config[
|
||||
"shuffle_buffer_size"]))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
elif config["output"] is None:
|
||||
output_creator = (lambda ioctx: NoopOutput())
|
||||
elif config["output"] == "logdir":
|
||||
output_creator = (lambda ioctx: JsonWriter(
|
||||
ioctx.log_dir,
|
||||
ioctx,
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
else:
|
||||
output_creator = (lambda ioctx: JsonWriter(
|
||||
config["output"],
|
||||
ioctx,
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
|
||||
if config["input"] == "sampler":
|
||||
input_evaluation = []
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
# Fill in the default policy graph if 'None' is specified in multiagent
|
||||
if self.config["multiagent"]["policy_graphs"]:
|
||||
tmp = self.config["multiagent"]["policy_graphs"]
|
||||
_validate_multiagent_config(tmp, allow_none_graph=True)
|
||||
for k, v in tmp.items():
|
||||
if v[0] is None:
|
||||
tmp[k] = (policy_graph, v[1], v[2], v[3])
|
||||
policy_graph = tmp
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
policy_graph,
|
||||
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
|
||||
policies_to_train=self.config["multiagent"]["policies_to_train"],
|
||||
tf_session_creator=(session_creator
|
||||
if config["tf_session_args"] else None),
|
||||
batch_steps=config["sample_batch_size"],
|
||||
batch_mode=config["batch_mode"],
|
||||
episode_horizon=config["horizon"],
|
||||
preprocessor_pref=config["preprocessor_pref"],
|
||||
sample_async=config["sample_async"],
|
||||
compress_observations=config["compress_observations"],
|
||||
num_envs=config["num_envs_per_worker"],
|
||||
observation_filter=config["observation_filter"],
|
||||
clip_rewards=config["clip_rewards"],
|
||||
clip_actions=config["clip_actions"],
|
||||
env_config=config["env_config"],
|
||||
model_config=config["model"],
|
||||
policy_config=config,
|
||||
worker_index=worker_index,
|
||||
monitor_path=self.logdir if config["monitor"] else None,
|
||||
log_dir=self.logdir,
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"],
|
||||
input_creator=input_creator,
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=config["remote_worker_envs"],
|
||||
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
|
||||
soft_horizon=config["soft_horizon"],
|
||||
_fake_sampler=config.get("_fake_sampler", False))
|
||||
|
||||
    @override(Trainable)
    def _export_model(self, export_formats, export_dir):
        ExportFormat.validate(export_formats)
        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path
        return exported

    def __getstate__(self):
        state = {}
        if hasattr(self, "local_evaluator"):
            state["evaluator"] = self.local_evaluator.save()
        if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
            state["optimizer"] = self.optimizer.save()
        return state

    def __setstate__(self, state):
        if "evaluator" in state:
            self.local_evaluator.restore(state["evaluator"])
            remote_state = ray.put(state["evaluator"])
            for r in self.remote_evaluators:
                r.restore.remote(remote_state)
        if "optimizer" in state:
            self.optimizer.restore(state["optimizer"])

    def _register_if_needed(self, env_object):
        if isinstance(env_object, six.string_types):
            return env_object
        elif isinstance(env_object, type):
            name = env_object.__name__
            register_env(name, lambda config: env_object(config))
            return name
        raise ValueError(
            "{} is an invalid env specification. ".format(env_object) +
            "You can specify a custom env as either a class "
            "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")


Agent = renamed_class(Trainer)
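# Rough sketch of the kind of deprecation shim `renamed_class` provides. This
# is an illustration under assumptions, not the actual implementation, which
# lives in ray.rllib.utils and is not shown in this diff.
#
#     def renamed_class(cls):
#         class DeprecationWrapper(cls):
#             def __init__(self, *args, **kwargs):
#                 old_name = cls.__name__.replace("Trainer", "Agent")
#                 logger.warning("{} has been renamed to {}".format(
#                     old_name, cls.__name__))
#                 cls.__init__(self, *args, **kwargs)
#
#         DeprecationWrapper.__name__ = cls.__name__
#         return DeprecationWrapper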
|
||||
|
|
|
@@ -1,3 +1,6 @@
-from ray.rllib.agents.ars.ars import (ARSAgent, DEFAULT_CONFIG)
+from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG)
+from ray.rllib.utils import renamed_class
 
-__all__ = ["ARSAgent", "DEFAULT_CONFIG"]
+ARSAgent = renamed_class(ARSTrainer)
+
+__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -12,7 +12,7 @@ import numpy as np
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents import Trainer, with_common_config
|
||||
|
||||
from ray.rllib.agents.ars import optimizers
|
||||
from ray.rllib.agents.ars import policies
|
||||
|
@ -157,13 +157,13 @@ class Worker(object):
|
|||
eval_lengths=eval_lengths)
|
||||
|
||||
|
||||
class ARSAgent(Agent):
|
||||
class ARSTrainer(Trainer):
|
||||
"""Large-scale implementation of Augmented Random Search in Ray."""
|
||||
|
||||
_agent_name = "ARS"
|
||||
_name = "ARS"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
env = env_creator(config["env_config"])
|
||||
from ray.rllib import models
|
||||
|
@ -195,7 +195,7 @@ class ARSAgent(Agent):
|
|||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
||||
|
@ -291,13 +291,13 @@ class ARSAgent(Agent):
|
|||
|
||||
return result
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for w in self.workers:
|
||||
w.__ray_terminate__.remote()
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def compute_action(self, observation):
|
||||
return self.policy.compute(observation, update=True)[0]
|
||||
|
||||
|
|
|
@ -2,7 +2,14 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ddpg.apex import ApexDDPGAgent
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["DDPGAgent", "ApexDDPGAgent", "DEFAULT_CONFIG"]
|
||||
ApexDDPGAgent = renamed_class(ApexDDPGTrainer)
|
||||
DDPGAgent = renamed_class(DDPGTrainer)
|
||||
|
||||
__all__ = [
|
||||
"DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer",
|
||||
"DEFAULT_CONFIG"
|
||||
]
|
||||
|
|
|
@ -2,7 +2,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
|
||||
DEFAULT_CONFIG as DDPG_CONFIG
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
|
@ -32,17 +33,17 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
|
|||
)
|
||||
|
||||
|
||||
class ApexDDPGAgent(DDPGAgent):
|
||||
class ApexDDPGTrainer(DDPGTrainer):
|
||||
"""DDPG variant that uses the Ape-X distributed policy optimizer.
|
||||
|
||||
By default, this is configured for a large single node (32 cores). For
|
||||
running in a large cluster, increase the `num_workers` config var.
|
||||
"""
|
||||
|
||||
_agent_name = "APEX_DDPG"
|
||||
_name = "APEX_DDPG"
|
||||
_default_config = APEX_DDPG_DEFAULT_CONFIG
|
||||
|
||||
@override(DDPGAgent)
|
||||
@override(DDPGTrainer)
|
||||
def update_target_if_needed(self):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
|
|
|
@ -2,8 +2,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.agent import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNAgent
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
@ -132,13 +132,13 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class DDPGAgent(DQNAgent):
|
||||
class DDPGTrainer(DQNTrainer):
|
||||
"""DDPG implementation in TensorFlow."""
|
||||
_agent_name = "DDPG"
|
||||
_name = "DDPG"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DDPGPolicyGraph
|
||||
|
||||
@override(DQNAgent)
|
||||
@override(DQNTrainer)
|
||||
def _make_exploration_schedule(self, worker_index):
|
||||
# Override DQN's schedule to take into account `noise_scale`
|
||||
if self.config["per_worker_exploration"]:
|
||||
|
|
|
@ -2,7 +2,13 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.dqn.apex import ApexAgent
|
||||
from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.dqn.apex import ApexTrainer
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["ApexAgent", "DQNAgent", "DEFAULT_CONFIG"]
|
||||
DQNAgent = renamed_class(DQNTrainer)
|
||||
ApexAgent = renamed_class(ApexTrainer)
|
||||
|
||||
__all__ = [
|
||||
"DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG"
|
||||
]
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
@ -36,17 +36,17 @@ APEX_DEFAULT_CONFIG = merge_dicts(
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class ApexAgent(DQNAgent):
|
||||
class ApexTrainer(DQNTrainer):
|
||||
"""DQN variant that uses the Ape-X distributed policy optimizer.
|
||||
|
||||
By default, this is configured for a large single node (32 cores). For
|
||||
running in a large cluster, increase the `num_workers` config var.
|
||||
"""
|
||||
|
||||
_agent_name = "APEX"
|
||||
_name = "APEX"
|
||||
_default_config = APEX_DEFAULT_CONFIG
|
||||
|
||||
@override(DQNAgent)
|
||||
@override(DQNTrainer)
|
||||
def update_target_if_needed(self):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
|
|
|
@ -7,7 +7,7 @@ import time
|
|||
|
||||
from ray import tune
|
||||
from ray.rllib import optimizers
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
|
@ -137,15 +137,15 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class DQNAgent(Agent):
|
||||
class DQNTrainer(Trainer):
|
||||
"""DQN implementation in TensorFlow."""
|
||||
|
||||
_agent_name = "DQN"
|
||||
_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DQNPolicyGraph
|
||||
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self._validate_config()
|
||||
|
||||
|
@ -161,7 +161,7 @@ class DQNAgent(Agent):
|
|||
]
|
||||
|
||||
for k in self._optimizer_shared_configs:
|
||||
if self._agent_name != "DQN" and k in [
|
||||
if self._name != "DQN" and k in [
|
||||
"schedule_max_timesteps", "beta_annealing_fraction",
|
||||
"final_prioritized_replay_beta"
|
||||
]:
|
||||
|
@ -238,7 +238,7 @@ class DQNAgent(Agent):
|
|||
self.last_target_update_ts = 0
|
||||
self.num_target_updates = 0
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
start_timestep = self.global_timestep
|
||||
|
||||
|
@ -326,7 +326,7 @@ class DQNAgent(Agent):
|
|||
final_p=self.config["exploration_final_eps"])
|
||||
|
||||
def __getstate__(self):
|
||||
state = Agent.__getstate__(self)
|
||||
state = Trainer.__getstate__(self)
|
||||
state.update({
|
||||
"num_target_updates": self.num_target_updates,
|
||||
"last_target_update_ts": self.last_target_update_ts,
|
||||
|
@ -334,7 +334,7 @@ class DQNAgent(Agent):
|
|||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
Agent.__setstate__(self, state)
|
||||
Trainer.__setstate__(self, state)
|
||||
self.num_target_updates = state["num_target_updates"]
|
||||
self.last_target_update_ts = state["last_target_update_ts"]
|
||||
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from ray.rllib.agents.es.es import (ESAgent, DEFAULT_CONFIG)
|
||||
from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["ESAgent", "DEFAULT_CONFIG"]
|
||||
ESAgent = renamed_class(ESTrainer)
|
||||
|
||||
__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents import Trainer, with_common_config
|
||||
|
||||
from ray.rllib.agents.es import optimizers
|
||||
from ray.rllib.agents.es import policies
|
||||
|
@ -163,13 +163,13 @@ class Worker(object):
|
|||
eval_lengths=eval_lengths)
|
||||
|
||||
|
||||
class ESAgent(Agent):
|
||||
class ESTrainer(Trainer):
|
||||
"""Large-scale implementation of Evolution Strategies in Ray."""
|
||||
|
||||
_agent_name = "ES"
|
||||
_name = "ES"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
policy_params = {"action_noise_std": 0.01}
|
||||
|
||||
|
@ -200,7 +200,7 @@ class ESAgent(Agent):
|
|||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
||||
|
@ -288,11 +288,11 @@ class ESAgent(Agent):
|
|||
|
||||
return result
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def compute_action(self, observation):
|
||||
return self.policy.compute(observation, update=False)[0]
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for w in self.workers:
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from ray.rllib.agents.impala.impala import ImpalaAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["ImpalaAgent", "DEFAULT_CONFIG"]
|
||||
ImpalaAgent = renamed_class(ImpalaTrainer)
|
||||
|
||||
__all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -6,7 +6,7 @@ import time
|
|||
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@ -100,14 +100,14 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class ImpalaAgent(Agent):
|
||||
class ImpalaTrainer(Trainer):
|
||||
"""IMPALA implementation using DeepMind's V-trace."""
|
||||
|
||||
_agent_name = "IMPALA"
|
||||
_name = "IMPALA"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = VTracePolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
for k in OPTIMIZER_SHARED_CONFIGS:
|
||||
if k not in config["optimizer"]:
|
||||
|
@ -136,7 +136,7 @@ class ImpalaAgent(Agent):
|
|||
@override(Trainable)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls._default_config, **config)
|
||||
Agent._validate_config(cf)
|
||||
Trainer._validate_config(cf)
|
||||
return Resources(
|
||||
cpu=cf["num_cpus_for_driver"],
|
||||
gpu=cf["num_gpus"],
|
||||
|
@ -144,7 +144,7 @@ class ImpalaAgent(Agent):
|
|||
cf["num_aggregation_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
start = time.time()
|
||||
|
|
|
@ -2,6 +2,6 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.marwil.marwil import MARWILAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG
|
||||
|
||||
__all__ = ["MARWILAgent", "DEFAULT_CONFIG"]
|
||||
__all__ = ["MARWILTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@ -39,14 +39,14 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class MARWILAgent(Agent):
|
||||
class MARWILTrainer(Trainer):
|
||||
"""MARWIL implementation in TensorFlow."""
|
||||
|
||||
_agent_name = "MARWIL"
|
||||
_name = "MARWIL"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = MARWILPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy_graph)
|
||||
|
@ -59,7 +59,7 @@ class MARWILAgent(Agent):
|
|||
"train_batch_size": config["train_batch_size"],
|
||||
})
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
fetches = self.optimizer.step()
|
||||
|
|
|
@ -6,13 +6,13 @@ import os
|
|||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
|
||||
|
||||
class _MockAgent(Agent):
|
||||
"""Mock agent for use in tests"""
|
||||
class _MockTrainer(Trainer):
|
||||
"""Mock trainer for use in tests"""
|
||||
|
||||
_agent_name = "MockAgent"
|
||||
_name = "MockTrainer"
|
||||
_default_config = with_common_config({
|
||||
"mock_error": False,
|
||||
"persistent_error": False,
|
||||
|
@ -57,12 +57,12 @@ class _MockAgent(Agent):
|
|||
return self.info
|
||||
|
||||
|
||||
class _SigmoidFakeData(_MockAgent):
|
||||
"""Agent that returns sigmoid learning curves.
|
||||
class _SigmoidFakeData(_MockTrainer):
|
||||
"""Trainer that returns sigmoid learning curves.
|
||||
|
||||
This can be helpful for evaluating early stopping algorithms."""
|
||||
|
||||
_agent_name = "SigmoidFakeData"
|
||||
_name = "SigmoidFakeData"
|
||||
_default_config = with_common_config({
|
||||
"width": 100,
|
||||
"height": 100,
|
||||
|
@ -84,9 +84,9 @@ class _SigmoidFakeData(_MockAgent):
|
|||
info={})
|
||||
|
||||
|
||||
class _ParameterTuningAgent(_MockAgent):
|
||||
class _ParameterTuningTrainer(_MockTrainer):
|
||||
|
||||
_agent_name = "ParameterTuningAgent"
|
||||
_name = "ParameterTuningTrainer"
|
||||
_default_config = with_common_config({
|
||||
"reward_amt": 10,
|
||||
"dummy_param": 10,
|
||||
|
@ -108,8 +108,8 @@ class _ParameterTuningAgent(_MockAgent):
|
|||
def _agent_import_failed(trace):
|
||||
"""Returns dummy agent class for if PyTorch etc. is not installed."""
|
||||
|
||||
class _AgentImportFailed(Agent):
|
||||
_agent_name = "AgentImportFailed"
|
||||
class _AgentImportFailed(Trainer):
|
||||
_name = "AgentImportFailed"
|
||||
_default_config = with_common_config({})
|
||||
|
||||
def _setup(self, config):
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from ray.rllib.agents.pg.pg import PGAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["PGAgent", "DEFAULT_CONFIG"]
|
||||
PGAgent = renamed_class(PGTrainer)
|
||||
|
||||
__all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
|
@ -22,18 +22,18 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class PGAgent(Agent):
|
||||
class PGTrainer(Trainer):
|
||||
"""Simple policy gradient agent.
|
||||
|
||||
This is an example agent to show how to implement algorithms in RLlib.
|
||||
In most cases, you will probably want to use the PPO agent instead.
|
||||
"""
|
||||
|
||||
_agent_name = "PG"
|
||||
_name = "PG"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = PGPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.pg.torch_pg_policy_graph import \
|
||||
|
@ -51,7 +51,7 @@ class PGAgent(Agent):
|
|||
self.optimizer = SyncSamplesOptimizer(
|
||||
self.local_evaluator, self.remote_evaluators, optimizer_config)
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
self.optimizer.step()
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from ray.rllib.agents.ppo.ppo import (PPOAgent, DEFAULT_CONFIG)
|
||||
from ray.rllib.agents.ppo.appo import APPOAgent
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.ppo.appo import APPOTrainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
__all__ = ["APPOAgent", "PPOAgent", "DEFAULT_CONFIG"]
|
||||
PPOAgent = renamed_class(PPOTrainer)
|
||||
|
||||
__all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph
|
||||
from ray.rllib.agents.agent import with_base_config
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
@ -52,13 +52,13 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class APPOAgent(impala.ImpalaAgent):
|
||||
class APPOTrainer(impala.ImpalaTrainer):
|
||||
"""PPO surrogate loss with IMPALA-architecture."""
|
||||
|
||||
_agent_name = "APPO"
|
||||
_name = "APPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = AsyncPPOPolicyGraph
|
||||
|
||||
@override(impala.ImpalaAgent)
|
||||
@override(impala.ImpalaTrainer)
|
||||
def _get_policy_graph(self):
|
||||
return AsyncPPOPolicyGraph
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import print_function
|
|||
|
||||
import logging
|
||||
|
||||
from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents import Trainer, with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@ -63,14 +63,14 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class PPOAgent(Agent):
|
||||
class PPOTrainer(Trainer):
|
||||
"""Multi-GPU optimized implementation of PPO in TensorFlow."""
|
||||
|
||||
_agent_name = "PPO"
|
||||
_name = "PPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = PPOPolicyGraph
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self._validate_config()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
|
@ -96,7 +96,7 @@ class PPOAgent(Agent):
|
|||
"straggler_mitigation": config["straggler_mitigation"],
|
||||
})
|
||||
|
||||
@override(Agent)
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
if "observation_filter" not in self.raw_user_config:
|
||||
# TODO(ekl) remove this message after a few releases
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.qmix.apex import ApexQMixAgent
|
||||
from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.qmix.apex import ApexQMixTrainer
|
||||
|
||||
__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"]
|
||||
__all__ = ["QMixTrainer", "ApexQMixTrainer", "DEFAULT_CONFIG"]
|
||||
|
|
|
@ -4,7 +4,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG
|
||||
from ray.rllib.agents.qmix.qmix import QMixTrainer, \
|
||||
DEFAULT_CONFIG as QMIX_CONFIG
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
|
@ -34,17 +35,17 @@ APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
|
|||
)
|
||||
|
||||
|
||||
class ApexQMixAgent(QMixAgent):
|
||||
class ApexQMixTrainer(QMixTrainer):
|
||||
"""QMIX variant that uses the Ape-X distributed policy optimizer.
|
||||
|
||||
By default, this is configured for a large single node (32 cores). For
|
||||
running in a large cluster, increase the `num_workers` config var.
|
||||
"""
|
||||
|
||||
_agent_name = "APEX_QMIX"
|
||||
_name = "APEX_QMIX"
|
||||
_default_config = APEX_QMIX_DEFAULT_CONFIG
|
||||
|
||||
@override(QMixAgent)
|
||||
@override(QMixTrainer)
|
||||
def update_target_if_needed(self):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
|
|
|
@ -2,8 +2,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.agent import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNAgent
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
|
||||
|
||||
# yapf: disable
|
||||
|
@ -90,10 +90,10 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# yapf: enable
|
||||
|
||||
|
||||
class QMixAgent(DQNAgent):
|
||||
class QMixTrainer(DQNTrainer):
|
||||
"""QMix implementation in PyTorch."""
|
||||
|
||||
_agent_name = "QMIX"
|
||||
_name = "QMIX"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = QMixPolicyGraph
|
||||
_optimizer_shared_configs = [
|
||||
|
|
|
@ -11,77 +11,77 @@ from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
|||
|
||||
def _import_appo():
|
||||
from ray.rllib.agents import ppo
|
||||
return ppo.APPOAgent
|
||||
return ppo.APPOTrainer
|
||||
|
||||
|
||||
def _import_qmix():
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.QMixAgent
|
||||
return qmix.QMixTrainer
|
||||
|
||||
|
||||
def _import_apex_qmix():
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.ApexQMixAgent
|
||||
return qmix.ApexQMixTrainer
|
||||
|
||||
|
||||
def _import_ddpg():
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.DDPGAgent
|
||||
return ddpg.DDPGTrainer
|
||||
|
||||
|
||||
def _import_apex_ddpg():
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.ApexDDPGAgent
|
||||
return ddpg.ApexDDPGTrainer
|
||||
|
||||
|
||||
def _import_ppo():
|
||||
from ray.rllib.agents import ppo
|
||||
return ppo.PPOAgent
|
||||
return ppo.PPOTrainer
|
||||
|
||||
|
||||
def _import_es():
|
||||
from ray.rllib.agents import es
|
||||
return es.ESAgent
|
||||
return es.ESTrainer
|
||||
|
||||
|
||||
def _import_ars():
|
||||
from ray.rllib.agents import ars
|
||||
return ars.ARSAgent
|
||||
return ars.ARSTrainer
|
||||
|
||||
|
||||
def _import_dqn():
|
||||
from ray.rllib.agents import dqn
|
||||
return dqn.DQNAgent
|
||||
return dqn.DQNTrainer
|
||||
|
||||
|
||||
def _import_apex():
|
||||
from ray.rllib.agents import dqn
|
||||
return dqn.ApexAgent
|
||||
return dqn.ApexTrainer
|
||||
|
||||
|
||||
def _import_a3c():
|
||||
from ray.rllib.agents import a3c
|
||||
return a3c.A3CAgent
|
||||
return a3c.A3CTrainer
|
||||
|
||||
|
||||
def _import_a2c():
|
||||
from ray.rllib.agents import a3c
|
||||
return a3c.A2CAgent
|
||||
return a3c.A2CTrainer
|
||||
|
||||
|
||||
def _import_pg():
|
||||
from ray.rllib.agents import pg
|
||||
return pg.PGAgent
|
||||
return pg.PGTrainer
|
||||
|
||||
|
||||
def _import_impala():
|
||||
from ray.rllib.agents import impala
|
||||
return impala.ImpalaAgent
|
||||
return impala.ImpalaTrainer
|
||||
|
||||
|
||||
def _import_marwil():
|
||||
from ray.rllib.agents import marwil
|
||||
return marwil.MARWILAgent
|
||||
return marwil.MARWILTrainer
|
||||
|
||||
|
||||
ALGORITHMS = {
|
||||
|
@@ -122,13 +122,13 @@ def _get_agent_class(alg):
         from ray.tune import script_runner
         return script_runner.ScriptRunner
     elif alg == "__fake":
-        from ray.rllib.agents.mock import _MockAgent
-        return _MockAgent
+        from ray.rllib.agents.mock import _MockTrainer
+        return _MockTrainer
     elif alg == "__sigmoid_fake_data":
         from ray.rllib.agents.mock import _SigmoidFakeData
         return _SigmoidFakeData
     elif alg == "__parameter_tuning":
-        from ray.rllib.agents.mock import _ParameterTuningAgent
-        return _ParameterTuningAgent
+        from ray.rllib.agents.mock import _ParameterTuningTrainer
+        return _ParameterTuningTrainer
     else:
         raise Exception(("Unknown algorithm {}.").format(alg))
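# Illustrative sketch (assumed registry keys such as "PPO"): the ALGORITHMS
# dict above maps algorithm names to import functions, so resolving and
# constructing a trainer by name looks roughly like:
#
#     cls = ALGORITHMS["PPO"]()        # -> ppo.PPOTrainer
#     trainer = cls(config={"num_workers": 2}, env="CartPole-v0")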
|
||||
|
|
python/ray/rllib/agents/trainer.py (new file, 817 lines)
@@ -0,0 +1,817 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import six
|
||||
import time
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayError
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources, ExportFormat
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Max number of times to retry a worker failure. We shouldn't try too many
|
||||
# times in a row since that would indicate a persistent cluster issue.
|
||||
MAX_WORKER_FAILURE_RETRIES = 3
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
COMMON_CONFIG = {
|
||||
# === Debugging ===
|
||||
# Whether to write episode stats and videos to the agent log dir
|
||||
"monitor": False,
|
||||
# Set the ray.rllib.* log level for the agent process and its evaluators.
|
||||
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
|
||||
# periodically print out summaries of relevant internal dataflow (this is
|
||||
# also printed out once at startup at the INFO level).
|
||||
"log_level": "INFO",
|
||||
# Callbacks that will be run during various phases of training. These all
|
||||
# take a single "info" dict as an argument. For episode callbacks, custom
|
||||
# metrics can be attached to the episode by updating the episode object's
|
||||
# custom metrics dict (see examples/custom_metrics_and_callbacks.py).
|
||||
"callbacks": {
|
||||
"on_episode_start": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
|
||||
"on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...}
|
||||
},
|
||||
# Whether to attempt to continue training if a worker crashes.
|
||||
"ignore_worker_failures": False,
|
||||
|
||||
# === Policy ===
|
||||
# Arguments to pass to model. See models/catalog.py for a full list of the
|
||||
# available model options.
|
||||
"model": MODEL_DEFAULTS,
|
||||
# Arguments to pass to the policy optimizer. These vary by optimizer.
|
||||
"optimizer": {},
|
||||
|
||||
# === Environment ===
|
||||
# Discount factor of the MDP
|
||||
"gamma": 0.99,
|
||||
# Number of steps after which the episode is forced to terminate. Defaults
|
||||
# to `env.spec.max_episode_steps` (if present) for Gym envs.
|
||||
"horizon": None,
|
||||
# Calculate rewards but don't reset the environment when the horizon is
|
||||
# hit. This allows value estimation and RNN state to span across logical
|
||||
# episodes denoted by horizon. This only has an effect if horizon != inf.
|
||||
"soft_horizon": False,
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
# Environment name can also be passed via config
|
||||
"env": None,
|
||||
# Whether to clip rewards prior to experience postprocessing. Setting to
|
||||
# None means clip for Atari only.
|
||||
"clip_rewards": None,
|
||||
# Whether to np.clip() actions to the action space low/high range spec.
|
||||
"clip_actions": True,
|
||||
# Whether to use rllib or deepmind preprocessors by default
|
||||
"preprocessor_pref": "deepmind",
|
||||
|
||||
# === Resources ===
|
||||
# Number of actors used for parallelism
|
||||
"num_workers": 2,
|
||||
# Number of GPUs to allocate to the driver. Note that not all algorithms
|
||||
# can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs).
|
||||
"num_gpus": 0,
|
||||
# Number of CPUs to allocate per worker.
|
||||
"num_cpus_per_worker": 1,
|
||||
# Number of GPUs to allocate per worker. This can be fractional.
|
||||
"num_gpus_per_worker": 0,
|
||||
# Any custom resources to allocate per worker.
|
||||
"custom_resources_per_worker": {},
|
||||
# Number of CPUs to allocate for the driver. Note: this only takes effect
|
||||
# when running in Tune.
|
||||
"num_cpus_for_driver": 1,
|
||||
|
||||
# === Execution ===
|
||||
# Number of environments to evaluate vectorwise per worker.
|
||||
"num_envs_per_worker": 1,
|
||||
# Default sample batch size (unroll length). Batches of this size are
|
||||
# collected from workers until train_batch_size is met. When using
|
||||
# multiple envs per worker, this is multiplied by num_envs_per_worker.
|
||||
"sample_batch_size": 200,
|
||||
# Training batch size, if applicable. Should be >= sample_batch_size.
|
||||
# Samples batches will be concatenated together to this size for training.
|
||||
"train_batch_size": 200,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes"
|
||||
"batch_mode": "truncate_episodes",
|
||||
# (Deprecated) Use a background thread for sampling (slightly off-policy)
|
||||
"sample_async": False,
|
||||
# Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
|
||||
"observation_filter": "NoFilter",
|
||||
# Whether to synchronize the statistics of remote filters.
|
||||
"synchronize_filters": True,
|
||||
# Configure TF for single-process operation by default
|
||||
"tf_session_args": {
|
||||
# note: overriden by `local_evaluator_tf_session_args`
|
||||
"intra_op_parallelism_threads": 2,
|
||||
"inter_op_parallelism_threads": 2,
|
||||
"gpu_options": {
|
||||
"allow_growth": True,
|
||||
},
|
||||
"log_device_placement": False,
|
||||
"device_count": {
|
||||
"CPU": 1
|
||||
},
|
||||
"allow_soft_placement": True, # required by PPO multi-gpu
|
||||
},
|
||||
# Override the following tf session args on the local evaluator
|
||||
"local_evaluator_tf_session_args": {
|
||||
# Allow a higher level of parallelism by default, but not unlimited
|
||||
# since that can cause crashes with many concurrent drivers.
|
||||
"intra_op_parallelism_threads": 8,
|
||||
"inter_op_parallelism_threads": 8,
|
||||
},
|
||||
# Whether to LZ4 compress individual observations
|
||||
"compress_observations": False,
|
||||
# Drop metric batches from unresponsive workers after this many seconds
|
||||
"collect_metrics_timeout": 180,
|
||||
# Smooth metrics over this many episodes.
|
||||
"metrics_smoothing_episodes": 100,
|
||||
# If using num_envs_per_worker > 1, whether to create those new envs in
|
||||
# remote processes instead of in the same worker. This adds overheads, but
|
||||
# can make sense if your envs can take much time to step / reset
|
||||
# (e.g., for StarCraft). Use this cautiously; overheads are significant.
|
||||
"remote_worker_envs": False,
|
||||
# Timeout that remote workers are waiting when polling environments.
|
||||
# 0 (continue when at least one env is ready) is a reasonable default,
|
||||
# but optimal value could be obtained by measuring your environment
|
||||
# step / reset and model inference perf.
|
||||
"remote_env_batch_wait_ms": 0,
|
||||
|
||||
# === Offline Datasets ===
|
||||
# Specify how to generate experiences:
|
||||
# - "sampler": generate experiences via online simulation (default)
|
||||
# - a local directory or file glob expression (e.g., "/tmp/*.json")
|
||||
# - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
|
||||
# "s3://bucket/2.json"])
|
||||
# - a dict with string keys and sampling probabilities as values (e.g.,
|
||||
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
|
||||
# - a function that returns a rllib.offline.InputReader
|
||||
"input": "sampler",
|
||||
# Specify how to evaluate the current policy. This only has an effect when
|
||||
# reading offline experiences. Available options:
|
||||
# - "wis": the weighted step-wise importance sampling estimator.
|
||||
# - "is": the step-wise importance sampling estimator.
|
||||
# - "simulation": run the environment in the background, but use
|
||||
# this data for evaluation only and not for learning.
|
||||
"input_evaluation": ["is", "wis"],
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
# policy, not the *behaviour* policy, which is typically undesirable for
|
||||
# on-policy algorithms.
|
||||
"postprocess_inputs": False,
|
||||
# If positive, input batches will be shuffled via a sliding window buffer
|
||||
# of this number of batches. Use this if the input data is not in random
|
||||
# enough order. Input is delayed until the shuffle buffer is filled.
|
||||
"shuffle_buffer_size": 0,
|
||||
# Specify where experiences should be saved:
|
||||
# - None: don't save any experiences
|
||||
# - "logdir" to save to the agent log dir
|
||||
# - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
|
||||
# - a function that returns a rllib.offline.OutputWriter
|
||||
"output": None,
|
||||
# What sample batch columns to LZ4 compress in the output data.
|
||||
"output_compress_columns": ["obs", "new_obs"],
|
||||
# Max output file size before rolling over to a new file.
|
||||
"output_max_file_size": 64 * 1024 * 1024,
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
|
||||
# act_space, config). See policy_evaluator.py for more info.
|
||||
"policy_graphs": {},
|
||||
# Function mapping agent ids to policy ids.
|
||||
"policy_mapping_fn": None,
|
||||
# Optional whitelist of policies to train, or None for all policies.
|
||||
"policies_to_train": None,
|
||||
},
|
||||
}
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@DeveloperAPI
def with_common_config(extra_config):
    """Returns the given config dict merged with common agent confs."""

    return with_base_config(COMMON_CONFIG, extra_config)


def with_base_config(base_config, extra_config):
    """Returns the given config dict merged with a base agent conf."""

    config = copy.deepcopy(base_config)
    config.update(extra_config)
    return config
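# Illustrative sketch (hypothetical algorithm key): concrete algorithms build
# their DEFAULT_CONFIG on top of COMMON_CONFIG via with_common_config(), and
# users then override individual keys when constructing a trainer.
#
#     DEFAULT_CONFIG = with_common_config({
#         "lr": 5e-5,   # algorithm-specific key, example value only
#     })
#     # DEFAULT_CONFIG now also carries "num_workers", "sample_batch_size",
#     # "model", etc. from COMMON_CONFIG above.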
|
||||
|
||||
|
||||
@PublicAPI
class Trainer(Trainable):
    """A trainer coordinates the optimization of one or more RL policies.

    All RLlib trainers extend this base class, e.g., the A3CTrainer implements
    the A3C algorithm for single and multi-agent training.

    Trainer objects retain internal model state between calls to train(), so
    you should create a new trainer instance for each training session.

    Attributes:
        env_creator (func): Function that creates a new training env.
        config (obj): Algorithm-specific configuration data.
        logdir (str): Directory in which training outputs should be placed.
    """

    _allow_unknown_configs = False
    _allow_unknown_subkeys = [
        "tf_session_args", "env_config", "model", "optimizer", "multiagent",
        "custom_resources_per_worker"
    ]

    @PublicAPI
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLlib trainer.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """

        config = config or {}

        # Vars to synchronize to evaluators on each train call
        self.global_vars = {"timestep": 0}

        # Trainers allow env ids to be passed directly to the constructor.
        self._env_id = self._register_if_needed(env or config.get("env"))

        # Create a default logger creator if no logger_creator is specified
        if logger_creator is None:
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
                                              timestr)

            def default_logger_creator(config):
                """Creates a Unified logger with a default logdir prefix
                containing the agent name and the env id.
                """
                if not os.path.exists(DEFAULT_RESULTS_DIR):
                    os.makedirs(DEFAULT_RESULTS_DIR)
                logdir = tempfile.mkdtemp(
                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
                return UnifiedLogger(config, logdir, None)

            logger_creator = default_logger_creator

        Trainable.__init__(self, config, logger_creator)
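    # Illustrative usage sketch (assumes ray has been initialized and that a
    # registered or Gym env id such as "CartPole-v0" is available). Any
    # Trainer subclass from this commit (PPOTrainer, DQNTrainer, ...) follows
    # the same pattern:
    #
    #     trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 2})
    #     for _ in range(10):
    #         result = trainer.train()   # one training iteration
    #     checkpoint = trainer.save()    # checkpointing via the Trainable API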
|
||||
|
||||
@classmethod
|
||||
@override(Trainable)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls._default_config, **config)
|
||||
Trainer._validate_config(cf)
|
||||
# TODO(ekl): add custom resources here once tune supports them
|
||||
return Resources(
|
||||
cpu=cf["num_cpus_for_driver"],
|
||||
gpu=cf["num_gpus"],
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@override(Trainable)
|
||||
@PublicAPI
|
||||
def train(self):
|
||||
"""Overrides super.train to synchronize global vars."""
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
|
||||
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
ev.set_global_vars.remote(self.global_vars)
|
||||
logger.debug("updated global vars: {}".format(self.global_vars))
|
||||
|
||||
result = None
|
||||
for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
|
||||
try:
|
||||
result = Trainable.train(self)
|
||||
except RayError as e:
|
||||
if self.config["ignore_worker_failures"]:
|
||||
logger.exception(
|
||||
"Error in train call, attempting to recover")
|
||||
self._try_recover()
|
||||
else:
|
||||
logger.info(
|
||||
"Worker crashed during call to train(). To attempt to "
|
||||
"continue training without the failed worker, set "
|
||||
"`'ignore_worker_failures': True`.")
|
||||
raise e
|
||||
except Exception as e:
|
||||
time.sleep(0.5) # allow logs messages to propagate
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
if result is None:
|
||||
raise RuntimeError("Failed to recover from worker crash")
|
||||
|
||||
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
|
||||
and hasattr(self, "local_evaluator")):
|
||||
FilterManager.synchronize(
|
||||
self.local_evaluator.filters,
|
||||
self.remote_evaluators,
|
||||
update_remote=self.config["synchronize_filters"])
|
||||
logger.debug("synchronized filters: {}".format(
|
||||
self.local_evaluator.filters))
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
result["num_healthy_workers"] = len(
|
||||
self.optimizer.remote_evaluators)
|
||||
return result
|
||||
|
||||
@override(Trainable)
|
||||
def _log_result(self, result):
|
||||
if self.config["callbacks"].get("on_train_result"):
|
||||
self.config["callbacks"]["on_train_result"]({
|
||||
"trainer": self,
|
||||
"result": result,
|
||||
})
|
||||
# log after the callback is invoked, so that the user has a chance
|
||||
# to mutate the result
|
||||
Trainable._log_result(self, result)
|
||||
|
||||
@override(Trainable)
|
||||
def _setup(self, config):
|
||||
env = self._env_id
|
||||
if env:
|
||||
config["env"] = env
|
||||
if _global_registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = _global_registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda env_config: gym.make(env)
|
||||
else:
|
||||
self.env_creator = lambda env_config: None
|
||||
|
||||
# Merge the supplied config with the class default
|
||||
merged_config = copy.deepcopy(self._default_config)
|
||||
merged_config = deep_update(merged_config, config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
self.raw_user_config = config
|
||||
self.config = merged_config
|
||||
Trainer._validate_config(self.config)
|
||||
if self.config.get("log_level"):
|
||||
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
|
||||
|
||||
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
|
||||
with tf.Graph().as_default():
|
||||
self._init(self.config, self.env_creator)
|
||||
|
||||
@override(Trainable)
|
||||
def _stop(self):
|
||||
# Call stop on all evaluators to release resources
|
||||
if hasattr(self, "local_evaluator"):
|
||||
self.local_evaluator.stop()
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.stop.remote()
|
||||
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
if hasattr(self, "optimizer"):
|
||||
self.optimizer.stop()
|
||||
|
||||
@override(Trainable)
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
@override(Trainable)
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self, config, env_creator):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def compute_action(self,
|
||||
observation,
|
||||
state=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
policy_id=DEFAULT_POLICY_ID,
|
||||
full_fetch=False):
|
||||
"""Computes an action for the specified policy.
|
||||
|
||||
Note that you can also access the policy object through
|
||||
self.get_policy(policy_id) and call compute_actions() on it directly.
|
||||
|
||||
Arguments:
|
||||
observation (obj): observation from the environment.
|
||||
state (list): RNN hidden state, if any. If state is not None,
|
||||
then all of compute_single_action(...) is returned
|
||||
(computed action, rnn state, logits dictionary).
|
||||
Otherwise compute_single_action(...)[0] is
|
||||
returned (computed action).
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (int): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
policy_id (str): policy to query (only applies to multi-agent).
|
||||
full_fetch (bool): whether to return extra action fetch results.
|
||||
This is always set to true if RNN state is specified.
|
||||
|
||||
Returns:
|
||||
Just the computed action if full_fetch=False, or the full output
|
||||
of policy.compute_actions() otherwise.
|
||||
"""
|
||||
|
||||
if state is None:
|
||||
state = []
|
||||
preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
|
||||
observation)
|
||||
filtered_obs = self.local_evaluator.filters[policy_id](
|
||||
preprocessed, update=False)
|
||||
if state:
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs,
|
||||
state,
|
||||
prev_action,
|
||||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"])
|
||||
res = self.get_policy(policy_id).compute_single_action(
|
||||
filtered_obs,
|
||||
state,
|
||||
prev_action,
|
||||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"])
|
||||
if full_fetch:
|
||||
return res
|
||||
else:
|
||||
return res[0] # backwards compatibility
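    # Illustrative sketch (assumes a `trainer` and a Gym-style `env` object
    # exist): compute_action() is the single-observation inference entry
    # point, e.g. for rolling out a trained policy.
    #
    #     obs = env.reset()
    #     done = False
    #     while not done:
    #         action = trainer.compute_action(obs)
    #         obs, reward, done, info = env.step(action)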
    @property
    def iteration(self):
        """Current training iter, auto-incremented with each train() call."""

        return self._iteration

    @property
    def _name(self):
        """Subclasses should override this to declare their name."""

        raise NotImplementedError

    @property
    def _default_config(self):
        """Subclasses should override this to declare their default config."""

        raise NotImplementedError

    @PublicAPI
    def get_policy(self, policy_id=DEFAULT_POLICY_ID):
        """Return policy graph for the specified id, or None.

        Arguments:
            policy_id (str): id of policy graph to return.
        """

        return self.local_evaluator.get_policy(policy_id)

    @PublicAPI
    def get_weights(self, policies=None):
        """Return a dictionary of policy ids to weights.

        Arguments:
            policies (list): Optional list of policies to return weights for,
                or None for all policies.
        """
        return self.local_evaluator.get_weights(policies)

    @PublicAPI
    def set_weights(self, weights):
        """Set policy weights by policy id.

        Arguments:
            weights (dict): Map of policy ids to weights to set.
        """
        self.local_evaluator.set_weights(weights)
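An illustrative sketch (assumed usage, not from the diff) of pairing get_weights() and set_weights() to copy policies from one trainer into another, for example to warm-start a second run; src and dst are hypothetical trainer instances with compatible model configurations:

    weights = src.get_weights()   # dict mapping policy ids to weights
    dst.set_weights(weights)      # overwrite dst's policies in place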
    @DeveloperAPI
    def make_local_evaluator(self,
                             env_creator,
                             policy_graph,
                             extra_config=None):
        """Convenience method to return configured local evaluator."""

        return self._make_evaluator(
            PolicyEvaluator,
            env_creator,
            policy_graph,
            0,
            merge_dicts(
                # important: allow local tf to use more CPUs for optimization
                merge_dicts(
                    self.config, {
                        "tf_session_args": self.
                        config["local_evaluator_tf_session_args"]
                    }),
                extra_config or {}))

    @DeveloperAPI
    def make_remote_evaluators(self, env_creator, policy_graph, count):
        """Convenience method to return a number of remote evaluators."""

        remote_args = {
            "num_cpus": self.config["num_cpus_per_worker"],
            "num_gpus": self.config["num_gpus_per_worker"],
            "resources": self.config["custom_resources_per_worker"],
        }

        cls = PolicyEvaluator.as_remote(**remote_args).remote

        return [
            self._make_evaluator(cls, env_creator, policy_graph, i + 1,
                                 self.config) for i in range(count)
        ]

    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        """Export policy model with given policy_id to local directory.

        Arguments:
            export_dir (string): Writable local directory.
            policy_id (string): Optional policy id to export.

        Example:
            >>> trainer = MyTrainer()
            >>> for _ in range(10):
            >>>     trainer.train()
            >>> trainer.export_policy_model("/tmp/export_dir")
        """
        self.local_evaluator.export_policy_model(export_dir, policy_id)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        """Export tensorflow policy model checkpoint to local directory.

        Arguments:
            export_dir (string): Writable local directory.
            filename_prefix (string): file name prefix of checkpoint files.
            policy_id (string): Optional policy id to export.

        Example:
            >>> trainer = MyTrainer()
            >>> for _ in range(10):
            >>>     trainer.train()
            >>> trainer.export_policy_checkpoint("/tmp/export_dir")
        """
        self.local_evaluator.export_policy_checkpoint(
            export_dir, filename_prefix, policy_id)

    @DeveloperAPI
    def collect_metrics(self, selected_evaluators=None):
        """Collects metrics from the remote evaluators of this agent.

        This is the same data as returned by a call to train().
        """
        return self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"],
            min_history=self.config["metrics_smoothing_episodes"],
            selected_evaluators=selected_evaluators)

    @classmethod
    def resource_help(cls, config):
        return ("\n\nYou can adjust the resource requests of RLlib agents by "
                "setting `num_workers`, `num_gpus`, and other configs. See "
                "the DEFAULT_CONFIG defined by each agent for more info.\n\n"
                "The config of this agent is: {}".format(config))

    @staticmethod
    def _validate_config(config):
        if "gpu" in config:
            raise ValueError(
                "The `gpu` config is deprecated, please use `num_gpus=0|1` "
                "instead.")
        if "gpu_fraction" in config:
            raise ValueError(
                "The `gpu_fraction` config is deprecated, please use "
                "`num_gpus=<fraction>` instead.")
        if "use_gpu_for_workers" in config:
            raise ValueError(
                "The `use_gpu_for_workers` config is deprecated, please use "
                "`num_gpus_per_worker=1` instead.")
        if type(config["input_evaluation"]) != list:
            raise ValueError(
                "`input_evaluation` must be a list of strings, got {}".format(
                    config["input_evaluation"]))
    def _try_recover(self):
        """Try to identify and blacklist any unhealthy workers.

        This method is called after an unexpected remote error is encountered
        from a worker. It issues check requests to all current workers and
        blacklists any that respond with error. If no healthy workers remain,
        an error is raised.
        """

        if not self._has_policy_optimizer():
            raise NotImplementedError(
                "Recovery is not supported for this algorithm")

        logger.info("Health checking all workers...")
        checks = []
        for ev in self.optimizer.remote_evaluators:
            _, obj_id = ev.sample_with_count.remote()
            checks.append(obj_id)

        healthy_evaluators = []
        for i, obj_id in enumerate(checks):
            ev = self.optimizer.remote_evaluators[i]
            try:
                ray.get(obj_id)
                healthy_evaluators.append(ev)
                logger.info("Worker {} looks healthy".format(i + 1))
            except RayError:
                logger.exception("Blacklisting worker {}".format(i + 1))
                try:
                    ev.__ray_terminate__.remote()
                except Exception:
                    logger.exception("Error terminating unhealthy worker")

        if len(healthy_evaluators) < 1:
            raise RuntimeError(
                "Not enough healthy workers remain to continue.")

        self.optimizer.reset(healthy_evaluators)

    def _has_policy_optimizer(self):
        return hasattr(self, "optimizer") and isinstance(
            self.optimizer, PolicyOptimizer)

    def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
                        config):
        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf.Session(
                config=tf.ConfigProto(**config["tf_session_args"]))

        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (lambda ioctx: ShuffledInput(
                MixedInput(config["input"], ioctx), config[
                    "shuffle_buffer_size"]))
        else:
            input_creator = (lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx), config[
                    "shuffle_buffer_size"]))

        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Fill in the default policy graph if 'None' is specified in multiagent
        if self.config["multiagent"]["policy_graphs"]:
            tmp = self.config["multiagent"]["policy_graphs"]
            _validate_multiagent_config(tmp, allow_none_graph=True)
            for k, v in tmp.items():
                if v[0] is None:
                    tmp[k] = (policy_graph, v[1], v[2], v[3])
            policy_graph = tmp

        return cls(
            env_creator,
            policy_graph,
            policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
            policies_to_train=self.config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            batch_steps=config["sample_batch_size"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            model_config=config["model"],
            policy_config=config,
            worker_index=worker_index,
            monitor_path=self.logdir if config["monitor"] else None,
            log_dir=self.logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            _fake_sampler=config.get("_fake_sampler", False))
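To connect the input/output branches above to user-facing settings, a hedged config sketch (the paths and proportions are made up): a plain path or list of paths is read back as JSON batches, a dict mixes sources by probability, and None / "logdir" / an explicit path select the output writer:

    config = {
        # Mix live sampling with replay of logged experiences 40/60
        # (the dict branch builds a MixedInput wrapped in a ShuffledInput).
        "input": {"sampler": 0.4, "/tmp/cartpole-out": 0.6},
        "shuffle_buffer_size": 1000,
        # Off-policy estimators to run on the logged data.
        "input_evaluation": ["is", "wis"],
        # Write newly collected experiences as JSON into the trainer logdir.
        "output": "logdir",
        "output_max_file_size": 64 * 1024 * 1024,
    }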
    @override(Trainable)
    def _export_model(self, export_formats, export_dir):
        ExportFormat.validate(export_formats)
        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path
        return exported

    def __getstate__(self):
        state = {}
        if hasattr(self, "local_evaluator"):
            state["evaluator"] = self.local_evaluator.save()
        if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
            state["optimizer"] = self.optimizer.save()
        return state

    def __setstate__(self, state):
        if "evaluator" in state:
            self.local_evaluator.restore(state["evaluator"])
            remote_state = ray.put(state["evaluator"])
            for r in self.remote_evaluators:
                r.restore.remote(remote_state)
        if "optimizer" in state:
            self.optimizer.restore(state["optimizer"])

    def _register_if_needed(self, env_object):
        if isinstance(env_object, six.string_types):
            return env_object
        elif isinstance(env_object, type):
            name = env_object.__name__
            register_env(name, lambda config: env_object(config))
            return name
        raise ValueError(
            "{} is an invalid env specification. ".format(env_object) +
            "You can specify a custom env as either a class "
            "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")
@@ -4,25 +4,25 @@ from __future__ import print_function

import numpy as np

-from ray.rllib.agents.agent import Agent, with_common_config
+from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.annotations import override


# yapf: disable
# __sphinx_doc_begin__
-class RandomAgent(Agent):
-    """Agent that takes random actions and never learns."""
+class RandomAgent(Trainer):
+    """Policy that takes random actions and never learns."""

-    _agent_name = "RandomAgent"
+    _name = "RandomAgent"
    _default_config = with_common_config({
        "rollouts_per_iteration": 10,
    })

-    @override(Agent)
+    @override(Trainer)
    def _init(self, config, env_creator):
        self.env = env_creator(config["env_config"])

-    @override(Agent)
+    @override(Trainer)
    def _train(self):
        rewards = []
        steps = 0

@@ -45,8 +45,8 @@ class RandomAgent(Agent):


if __name__ == "__main__":
-    agent = RandomAgent(
+    trainer = RandomAgent(
        env="CartPole-v0", config={"rollouts_per_iteration": 10})
-    result = agent.train()
+    result = trainer.train()
    assert result["episode_reward_mean"] > 10, result
    print("Test: OK")
python/ray/rllib/env/external_env.py (4 changed lines)

@@ -34,9 +34,9 @@ class ExternalEnv(threading.Thread):

    Examples:
        >>> register_env("my_env", lambda config: YourExternalEnv(config))
-        >>> agent = DQNAgent(env="my_env")
+        >>> trainer = DQNTrainer(env="my_env")
        >>> while True:
-              print(agent.train())
+              print(trainer.train())
    """

    @PublicAPI

@@ -35,6 +35,19 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder

logger = logging.getLogger(__name__)

+# Handle to the current evaluator, which will be set to the most recently
+# created PolicyEvaluator in this process. This can be helpful to access in
+# custom env or policy classes for debugging or advanced use cases.
+_global_evaluator = None
+
+
+@DeveloperAPI
+def get_global_evaluator():
+    """Returns a handle to the active policy evaluator in this process."""
+
+    global _global_evaluator
+    return _global_evaluator
+

@DeveloperAPI
class PolicyEvaluator(EvaluatorInterface):
@@ -215,6 +228,9 @@ class PolicyEvaluator(EvaluatorInterface):
            _fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        """

+        global _global_evaluator
+        _global_evaluator = self
+
        if log_level:
            logging.getLogger("ray.rllib").setLevel(log_level)
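A small sketch (not from the diff) of what the new handle enables, for example inspecting the evaluator that built the current worker from inside custom env or policy code while debugging:

    from ray.rllib.evaluation.policy_evaluator import get_global_evaluator

    def debug_hook():
        # Runs inside a worker process after its PolicyEvaluator is built.
        ev = get_global_evaluator()
        if ev is not None:
            print("created by PolicyEvaluator:", ev)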
@@ -71,12 +71,13 @@ class MultiAgentSampleBatchBuilder(object):
    corresponding policy batch for the agent's policy.
    """

-    def __init__(self, policy_map, clip_rewards):
+    def __init__(self, policy_map, clip_rewards, postp_callback):
        """Initialize a MultiAgentSampleBatchBuilder.

        Arguments:
            policy_map (dict): Maps policy ids to policy graph instances.
            clip_rewards (bool): Whether to clip rewards before postprocessing.
+            postp_callback: function to call on each postprocessed batch.
        """

        self.policy_map = policy_map
@@ -87,6 +88,7 @@ class MultiAgentSampleBatchBuilder(object):
        }
        self.agent_builders = {}
        self.agent_to_policy = {}
+        self.postp_callback = postp_callback
        self.count = 0  # increment this manually

    def total(self):
@@ -158,6 +160,8 @@ class MultiAgentSampleBatchBuilder(object):
        for agent_id, post_batch in sorted(post_batches.items()):
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)
+            if self.postp_callback:
+                self.postp_callback({"episode": episode, "batch": post_batch})

        self.agent_builders.clear()
        self.agent_to_policy.clear()

@@ -279,7 +279,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
-            return MultiAgentSampleBatchBuilder(policies, clip_rewards)
+            return MultiAgentSampleBatchBuilder(
+                policies, clip_rewards, callbacks.get("on_postprocess_traj"))

    def new_episode():
        episode = MultiAgentEpisode(policies, policy_mapping_fn,

@@ -38,12 +38,21 @@ def on_sample_end(info):


def on_train_result(info):
-    print("agent.train() result: {} -> {} episodes".format(
-        info["agent"], info["result"]["episodes_this_iter"]))
+    print("trainer.train() result: {} -> {} episodes".format(
+        info["trainer"], info["result"]["episodes_this_iter"]))
    # you can mutate the result dict to add new fields to return
    info["result"]["callback_ok"] = True


+def on_postprocess_traj(info):
+    episode = info["episode"]
+    batch = info["batch"]
+    print("postprocessed {} steps".format(batch.count))
+    if "num_batches" not in episode.custom_metrics:
+        episode.custom_metrics["num_batches"] = 0
+    episode.custom_metrics["num_batches"] += 1
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-iters", type=int, default=2000)
@@ -63,6 +72,7 @@ if __name__ == "__main__":
                "on_episode_end": tune.function(on_episode_end),
                "on_sample_end": tune.function(on_sample_end),
                "on_train_result": tune.function(on_train_result),
+                "on_postprocess_traj": tune.function(on_postprocess_traj),
            },
        },
    )
@@ -73,5 +83,6 @@ if __name__ == "__main__":
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
+    assert "num_batches_mean" in custom_metrics
    assert type(custom_metrics["pole_angle_mean"]) is float
    assert "callback_ok" in trials[0].last_result

@@ -12,12 +12,12 @@ from __future__ import print_function

import ray
from ray import tune
-from ray.rllib.agents.ppo import PPOAgent
+from ray.rllib.agents.ppo import PPOTrainer


def my_train_fn(config, reporter):
    # Train for 100 iterations with high LR
-    agent1 = PPOAgent(env="CartPole-v0", config=config)
+    agent1 = PPOTrainer(env="CartPole-v0", config=config)
    for _ in range(10):
        result = agent1.train()
        result["phase"] = 1
@@ -28,7 +28,7 @@ def my_train_fn(config, reporter):

    # Train for 100 iterations with low LR
    config["lr"] = 0.0001
-    agent2 = PPOAgent(env="CartPole-v0", config=config)
+    agent2 = PPOTrainer(env="CartPole-v0", config=config)
    agent2.restore(state)
    for _ in range(10):
        result = agent2.train()

@@ -15,9 +15,9 @@ import argparse
import gym

import ray
-from ray.rllib.agents.dqn.dqn import DQNAgent
+from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
-from ray.rllib.agents.ppo.ppo import PPOAgent
+from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
@@ -49,7 +49,7 @@ if __name__ == "__main__":
        else:
            return "dqn_policy"

-    ppo_trainer = PPOAgent(
+    ppo_trainer = PPOTrainer(
        env="multi_cartpole",
        config={
            "multiagent": {
@@ -62,7 +62,7 @@ if __name__ == "__main__":
            "observation_filter": "NoFilter",
        })

-    dqn_trainer = DQNAgent(
+    dqn_trainer = DQNTrainer(
        env="multi_cartpole",
        config={
            "multiagent": {

@@ -13,7 +13,7 @@ from gym import spaces
import numpy as np

import ray
-from ray.rllib.agents.dqn import DQNAgent
+from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.logger import pretty_print
@@ -43,7 +43,7 @@ if __name__ == "__main__":

    # We use DQN since it supports off-policy actions, but you can choose and
    # configure any agent.
-    dqn = DQNAgent(
+    dqn = DQNTrainer(
        env="srv",
        config={
            # Use a single process to avoid needing to set up a load balancer

@@ -35,7 +35,7 @@ class MinibatchBuffer(object):
           released: True if the item is now removed from the ring buffer.
        """
        if self.ttl[self.idx] <= 0:
-            self.buffers[self.idx] = self.inqueue.get(timeout=60.0)
+            self.buffers[self.idx] = self.inqueue.get(timeout=300.0)
            self.ttl[self.idx] = self.cur_max_ttl
            if self.cur_max_ttl < self.max_ttl:
                self.cur_max_ttl += 1

@@ -7,7 +7,7 @@ from gym.spaces import Tuple, Discrete, Dict, Box
import ray
from ray.tune import register_env
from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from ray.rllib.agents.qmix import QMixAgent
+from ray.rllib.agents.qmix import QMixTrainer


class AvailActionsTestEnv(MultiAgentEnv):
@@ -55,7 +55,7 @@ if __name__ == "__main__":
            grouping, obs_space=obs_space, act_space=act_space))

    ray.init()
-    agent = QMixAgent(
+    agent = QMixTrainer(
        env="action_mask_test",
        config={
            "num_envs_per_worker": 5,  # test with vectorization on

@@ -5,7 +5,7 @@ from __future__ import print_function
import unittest

import ray
-from ray.rllib.agents.dqn import DQNAgent
+from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep


@@ -26,7 +26,8 @@ class DQNTest(unittest.TestCase):

    def testEvaluationOption(self):
        ray.init()
-        agent = DQNAgent(env="CartPole-v0", config={"evaluation_interval": 2})
+        agent = DQNTrainer(
+            env="CartPole-v0", config={"evaluation_interval": 2})
        r0 = agent.train()
        r1 = agent.train()
        r2 = agent.train()

@@ -9,8 +9,8 @@ import unittest
import uuid

import ray
-from ray.rllib.agents.dqn import DQNAgent
-from ray.rllib.agents.pg import PGAgent
+from ray.rllib.agents.dqn import DQNTrainer
+from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph,
@@ -163,7 +163,7 @@ class TestExternalEnv(unittest.TestCase):
        register_env(
            "test3", lambda _: PartOffPolicyServing(
                gym.make("CartPole-v0"), off_pol_frac=0.2))
-        dqn = DQNAgent(env="test3", config={"exploration_fraction": 0.001})
+        dqn = DQNTrainer(env="test3", config={"exploration_fraction": 0.001})
        for i in range(100):
            result = dqn.train()
            print("Iteration {}, reward {}, timesteps {}".format(
@@ -174,7 +174,7 @@ class TestExternalEnv(unittest.TestCase):

    def testTrainCartpole(self):
        register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
-        pg = PGAgent(env="test", config={"num_workers": 0})
+        pg = PGTrainer(env="test", config={"num_workers": 0})
        for i in range(100):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
@@ -186,7 +186,7 @@ class TestExternalEnv(unittest.TestCase):
    def testTrainCartpoleMulti(self):
        register_env("test2",
                     lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
-        pg = PGAgent(env="test2", config={"num_workers": 0})
+        pg = PGTrainer(env="test2", config={"num_workers": 0})
        for i in range(100):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(

@@ -14,7 +14,7 @@ import time
import unittest

import ray
-from ray.rllib.agents.pg import PGAgent
+from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.evaluation import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
@@ -44,7 +44,7 @@ class AgentIOTest(unittest.TestCase):
        shutil.rmtree(self.test_dir)

    def writeOutputs(self, output):
-        agent = PGAgent(
+        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "output": output,
@@ -65,7 +65,7 @@ class AgentIOTest(unittest.TestCase):

    def testAgentInputDir(self):
        self.writeOutputs(self.test_dir)
-        agent = PGAgent(
+        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "input": self.test_dir,
@@ -97,7 +97,7 @@ class AgentIOTest(unittest.TestCase):
            for data in out:
                f.write(json.dumps(data))

-        agent = PGAgent(
+        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "input": self.test_dir,
@@ -111,7 +111,7 @@ class AgentIOTest(unittest.TestCase):

    def testAgentInputEvalSim(self):
        self.writeOutputs(self.test_dir)
-        agent = PGAgent(
+        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "input": self.test_dir,
@@ -126,7 +126,7 @@ class AgentIOTest(unittest.TestCase):

    def testAgentInputList(self):
        self.writeOutputs(self.test_dir)
-        agent = PGAgent(
+        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "input": glob.glob(self.test_dir + "/*.json"),
@@ -139,7 +139,7 @@ class AgentIOTest(unittest.TestCase):

    def testAgentInputDict(self):
        self.writeOutputs(self.test_dir)
-        agent = PGAgent(
+        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "input": {
@@ -161,7 +161,7 @@ class AgentIOTest(unittest.TestCase):
            act_space = single_env.action_space
            return (PGPolicyGraph, obs_space, act_space, {})

-        pg = PGAgent(
+        pg = PGTrainer(
            env="multi_cartpole",
            config={
                "num_workers": 0,
@@ -180,7 +180,7 @@ class AgentIOTest(unittest.TestCase):
        self.assertEqual(len(os.listdir(self.test_dir)), 1)

        pg.stop()
-        pg = PGAgent(
+        pg = PGTrainer(
            env="multi_cartpole",
            config={
                "num_workers": 0,

@@ -4,7 +4,7 @@ from __future__ import print_function

import unittest

-from ray.rllib.agents.ppo import PPOAgent, DEFAULT_CONFIG
+from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
import ray


@@ -12,7 +12,7 @@ class LocalModeTest(unittest.TestCase):
    def testLocal(self):
        ray.init(local_mode=True)
        cf = DEFAULT_CONFIG.copy()
-        agent = PPOAgent(cf, "CartPole-v0")
+        agent = PPOTrainer(cf, "CartPole-v0")
        print(agent.train())

@@ -10,7 +10,7 @@ import tensorflow as tf
import tensorflow.contrib.rnn as rnn

import ray
-from ray.rllib.agents.ppo import PPOAgent
+from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models import ModelCatalog
from ray.rllib.models.lstm import add_time_dimension, chop_into_sequences
from ray.rllib.models.misc import linear, normc_initializer
@@ -149,7 +149,7 @@ class RNNSequencing(unittest.TestCase):
    def testSimpleOptimizerSequencing(self):
        ModelCatalog.register_custom_model("rnn", RNNSpyModel)
        register_env("counter", lambda _: DebugCounterEnv())
-        ppo = PPOAgent(
+        ppo = PPOTrainer(
            env="counter",
            config={
                "num_workers": 0,
@@ -205,7 +205,7 @@ class RNNSequencing(unittest.TestCase):
    def testMinibatchSequencing(self):
        ModelCatalog.register_custom_model("rnn", RNNSpyModel)
        register_env("counter", lambda _: DebugCounterEnv())
-        ppo = PPOAgent(
+        ppo = PPOTrainer(
            env="counter",
            config={
                "num_workers": 0,

@@ -7,7 +7,7 @@ import random
import unittest

import ray
-from ray.rllib.agents.pg import PGAgent
+from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
@@ -519,7 +519,7 @@ class TestMultiAgentEnv(unittest.TestCase):
    def testTrainMultiCartpoleSinglePolicy(self):
        n = 10
        register_env("multi_cartpole", lambda _: MultiCartpole(n))
-        pg = PGAgent(env="multi_cartpole", config={"num_workers": 0})
+        pg = PGTrainer(env="multi_cartpole", config={"num_workers": 0})
        for i in range(100):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
@@ -542,7 +542,7 @@ class TestMultiAgentEnv(unittest.TestCase):
            act_space = single_env.action_space
            return (None, obs_space, act_space, config)

-        pg = PGAgent(
+        pg = PGTrainer(
            env="multi_cartpole",
            config={
                "num_workers": 0,

@@ -12,8 +12,8 @@ import tensorflow as tf
import unittest

import ray
-from ray.rllib.agents.a3c import A2CAgent
-from ray.rllib.agents.pg import PGAgent
+from ray.rllib.agents.a3c import A2CTrainer
+from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
@@ -215,7 +215,7 @@ class TupleSpyModel(Model):
class NestedSpacesTest(unittest.TestCase):
    def testInvalidModel(self):
        ModelCatalog.register_custom_model("invalid", InvalidModel)
-        self.assertRaises(ValueError, lambda: PGAgent(
+        self.assertRaises(ValueError, lambda: PGTrainer(
            env="CartPole-v0", config={
                "model": {
                    "custom_model": "invalid",
@@ -226,7 +226,7 @@ class NestedSpacesTest(unittest.TestCase):
        ModelCatalog.register_custom_model("invalid2", InvalidModel2)
        self.assertRaisesRegexp(
            ValueError, "Expected output.*",
-            lambda: PGAgent(
+            lambda: PGTrainer(
                env="CartPole-v0", config={
                    "model": {
                        "custom_model": "invalid2",
@@ -236,7 +236,7 @@ class NestedSpacesTest(unittest.TestCase):
    def doTestNestedDict(self, make_env, test_lstm=False):
        ModelCatalog.register_custom_model("composite", DictSpyModel)
        register_env("nested", make_env)
-        pg = PGAgent(
+        pg = PGTrainer(
            env="nested",
            config={
                "num_workers": 0,
@@ -265,7 +265,7 @@ class NestedSpacesTest(unittest.TestCase):
    def doTestNestedTuple(self, make_env):
        ModelCatalog.register_custom_model("composite2", TupleSpyModel)
        register_env("nested2", make_env)
-        pg = PGAgent(
+        pg = PGTrainer(
            env="nested2",
            config={
                "num_workers": 0,
@@ -323,7 +323,7 @@ class NestedSpacesTest(unittest.TestCase):
        ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
        register_env("nested_ma", lambda _: NestedMultiAgentEnv())
        act_space = spaces.Discrete(2)
-        pg = PGAgent(
+        pg = PGTrainer(
            env="nested_ma",
            config={
                "num_workers": 0,
@@ -370,13 +370,13 @@ class NestedSpacesTest(unittest.TestCase):

    def testRolloutDictSpace(self):
        register_env("nested", lambda _: NestedDictEnv())
-        agent = PGAgent(env="nested")
+        agent = PGTrainer(env="nested")
        agent.train()
        path = agent.save()
        agent.stop()

        # Test train works on restore
-        agent2 = PGAgent(env="nested")
+        agent2 = PGTrainer(env="nested")
        agent2.restore(path)
        agent2.train()

@@ -386,7 +386,7 @@ class NestedSpacesTest(unittest.TestCase):
    def testPyTorchModel(self):
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
-        a2c = A2CAgent(
+        a2c = A2CTrainer(
            env="nested",
            config={
                "num_workers": 0,

@@ -9,7 +9,7 @@ import time
import unittest

import ray
-from ray.rllib.agents.ppo import PPOAgent
+from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.evaluation import SampleBatch
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
@@ -41,7 +41,7 @@ class PPOCollectTest(unittest.TestCase):
        ray.init(num_cpus=4)

        # Check we at least collect the initial wave of samples
-        ppo = PPOAgent(
+        ppo = PPOTrainer(
            env="CartPole-v0",
            config={
                "sample_batch_size": 200,
@@ -53,7 +53,7 @@ class PPOCollectTest(unittest.TestCase):
        ppo.stop()

        # Check we collect at least the specified amount of samples
-        ppo = PPOAgent(
+        ppo = PPOTrainer(
            env="CartPole-v0",
            config={
                "sample_batch_size": 200,
@@ -65,7 +65,7 @@ class PPOCollectTest(unittest.TestCase):
        ppo.stop()

        # Check in vectorized mode
-        ppo = PPOAgent(
+        ppo = PPOTrainer(
            env="CartPole-v0",
            config={
                "sample_batch_size": 200,
@@ -78,7 +78,7 @@ class PPOCollectTest(unittest.TestCase):
        ppo.stop()

        # Check legacy mode
-        ppo = PPOAgent(
+        ppo = PPOTrainer(
            env="CartPole-v0",
            config={
                "sample_batch_size": 200,

@@ -10,8 +10,8 @@ import unittest
from collections import Counter

import ray
-from ray.rllib.agents.pg import PGAgent
-from ray.rllib.agents.a3c import A2CAgent
+from ray.rllib.agents.pg import PGTrainer
+from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -170,7 +170,7 @@ class TestPolicyEvaluator(unittest.TestCase):
            print()

    def testGlobalVarsUpdate(self):
-        agent = A2CAgent(
+        agent = A2CTrainer(
            env="CartPole-v0",
            config={
                "lr_schedule": [[0, 0.1], [400, 0.000001]],
@@ -182,12 +182,12 @@ class TestPolicyEvaluator(unittest.TestCase):

    def testNoStepOnInit(self):
        register_env("fail", lambda _: FailOnStepEnv())
-        pg = PGAgent(env="fail", config={"num_workers": 1})
+        pg = PGTrainer(env="fail", config={"num_workers": 1})
        self.assertRaises(Exception, lambda: pg.train())

    def testCallbacks(self):
        counts = Counter()
-        pg = PGAgent(
+        pg = PGTrainer(
            env="CartPole-v0", config={
                "num_workers": 0,
                "sample_batch_size": 50,
@@ -211,7 +211,7 @@ class TestPolicyEvaluator(unittest.TestCase):

    def testQueryEvaluators(self):
        register_env("test", lambda _: gym.make("CartPole-v0"))
-        pg = PGAgent(
+        pg = PGTrainer(
            env="test",
            config={
                "num_workers": 2,

@@ -1,10 +1,33 @@
+import logging
+
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.policy_client import PolicyClient
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.util import merge_dicts, deep_update

+logger = logging.getLogger(__name__)
+
+
+def renamed_class(cls):
+    class DeprecationWrapper(cls):
+        def __init__(self, *args, **kwargs):
+            old_name = cls.__name__.replace("Trainer", "Agent")
+            new_name = cls.__name__
+            logger.warn("DeprecationWarning: {} has been renamed to {}. ".
+                        format(old_name, new_name) +
+                        "This will raise an error in the future.")
+            cls.__init__(self, *args, **kwargs)
+
+    return DeprecationWrapper
+
+
__all__ = [
-    "Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts",
-    "deep_update"
+    "Filter",
+    "FilterManager",
+    "PolicyClient",
+    "PolicyServer",
+    "merge_dicts",
+    "deep_update",
+    "renamed_class",
]

@@ -27,10 +27,10 @@ def PublicAPI(obj):
    can expect these APIs to remain stable across RLlib releases.

    Subclasses that inherit from a ``@PublicAPI`` base class can be
-    assumed part of the RLlib public API as well (e.g., all agent classes
-    are in public API because Agent is ``@PublicAPI``).
+    assumed part of the RLlib public API as well (e.g., all trainer classes
+    are in public API because Trainer is ``@PublicAPI``).

-    In addition, you can assume all agent configurations are part of their
+    In addition, you can assume all trainer configurations are part of their
    public API as well.
    """

@@ -39,7 +39,7 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
                server = PolicyServer(self, "localhost", 8900)
                server.serve_forever()
        >>> register_env("srv", lambda _: CartpoleServing())
-        >>> pg = PGAgent(env="srv", config={"num_workers": 0})
+        >>> pg = PGTrainer(env="srv", config={"num_workers": 0})
        >>> while True:
                pg.train()
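For context, a minimal sketch of the client side that such a serving setup assumes (hedged: the PolicyClient calls below follow the documented start_episode/get_action/log_returns/end_episode flow and are illustrative, not taken from this diff):

    import gym
    from ray.rllib.utils.policy_client import PolicyClient

    env = gym.make("CartPole-v0")
    client = PolicyClient("http://localhost:8900")

    obs = env.reset()
    eid = client.start_episode()
    done = False
    while not done:
        action = client.get_action(eid, obs)
        obs, reward, done, _ = env.step(action)
        client.log_returns(eid, reward)
    client.end_episode(eid, obs)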