From 37208216ae28687ff1c738a518642944592e5472 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Sun, 7 Apr 2019 00:36:18 -0700
Subject: [PATCH] [rllib] Rename Agent to Trainer (#4556)

---
 doc/source/index.rst | 2 +-
 doc/source/rllib-api.svg | 5 +-
 doc/source/rllib-config.svg | 2 +-
 doc/source/rllib-dev.rst | 14 +-
 doc/source/rllib-env.rst | 10 +-
 doc/source/rllib-models.rst | 16 +-
 doc/source/rllib-offline.rst | 10 +-
 doc/source/rllib-package-ref.rst | 23 +-
 doc/source/rllib-training.rst | 72 +-
 doc/source/rllib.rst | 20 +-
 python/ray/rllib/agents/__init__.py | 5 +-
 python/ray/rllib/agents/a3c/__init__.py | 12 +-
 python/ray/rllib/agents/a3c/a2c.py | 10 +-
 python/ray/rllib/agents/a3c/a3c.py | 10 +-
 python/ray/rllib/agents/agent.py | 796 +----------------
 python/ray/rllib/agents/ars/__init__.py | 7 +-
 python/ray/rllib/agents/ars/ars.py | 14 +-
 python/ray/rllib/agents/ddpg/__init__.py | 13 +-
 python/ray/rllib/agents/ddpg/apex.py | 9 +-
 python/ray/rllib/agents/ddpg/ddpg.py | 10 +-
 python/ray/rllib/agents/dqn/__init__.py | 12 +-
 python/ray/rllib/agents/dqn/apex.py | 8 +-
 python/ray/rllib/agents/dqn/dqn.py | 16 +-
 python/ray/rllib/agents/es/__init__.py | 7 +-
 python/ray/rllib/agents/es/es.py | 14 +-
 python/ray/rllib/agents/impala/__init__.py | 7 +-
 python/ray/rllib/agents/impala/impala.py | 12 +-
 python/ray/rllib/agents/marwil/__init__.py | 4 +-
 python/ray/rllib/agents/marwil/marwil.py | 10 +-
 python/ray/rllib/agents/mock.py | 22 +-
 python/ray/rllib/agents/pg/__init__.py | 7 +-
 python/ray/rllib/agents/pg/pg.py | 10 +-
 python/ray/rllib/agents/ppo/__init__.py | 9 +-
 python/ray/rllib/agents/ppo/appo.py | 8 +-
 python/ray/rllib/agents/ppo/ppo.py | 10 +-
 python/ray/rllib/agents/qmix/__init__.py | 6 +-
 python/ray/rllib/agents/qmix/apex.py | 9 +-
 python/ray/rllib/agents/qmix/qmix.py | 8 +-
 python/ray/rllib/agents/registry.py | 38 +-
 python/ray/rllib/agents/trainer.py | 817 ++++++++++++++++++
 .../contrib/random_agent/random_agent.py | 16 +-
 python/ray/rllib/env/external_env.py | 4 +-
 .../ray/rllib/evaluation/policy_evaluator.py | 16 +
 .../rllib/evaluation/sample_batch_builder.py | 6 +-
 python/ray/rllib/evaluation/sampler.py | 3 +-
 .../examples/custom_metrics_and_callbacks.py | 15 +-
 python/ray/rllib/examples/custom_train_fn.py | 6 +-
 .../rllib/examples/multiagent_two_trainers.py | 8 +-
 .../rllib/examples/serving/cartpole_server.py | 4 +-
 .../rllib/optimizers/aso_minibatch_buffer.py | 2 +-
 .../rllib/tests/test_avail_actions_qmix.py | 4 +-
 python/ray/rllib/tests/test_evaluators.py | 5 +-
 python/ray/rllib/tests/test_external_env.py | 10 +-
 python/ray/rllib/tests/test_io.py | 18 +-
 python/ray/rllib/tests/test_local.py | 4 +-
 python/ray/rllib/tests/test_lstm.py | 6 +-
 .../ray/rllib/tests/test_multi_agent_env.py | 6 +-
 python/ray/rllib/tests/test_nested_spaces.py | 20 +-
 python/ray/rllib/tests/test_optimizers.py | 10 +-
 .../ray/rllib/tests/test_policy_evaluator.py | 12 +-
 python/ray/rllib/utils/__init__.py | 27 +-
 python/ray/rllib/utils/annotations.py | 6 +-
 python/ray/rllib/utils/policy_server.py | 2 +-
 63 files changed, 1212 insertions(+), 1092 deletions(-)
 create mode 100644 python/ray/rllib/agents/trainer.py

diff --git a/doc/source/index.rst b/doc/source/index.rst
index b04009cf8..d47dcbacc 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -93,8 +93,8 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin
    rllib.rst
    rllib-training.rst
    rllib-env.rst
-   rllib-algorithms.rst
    rllib-models.rst
+   rllib-algorithms.rst
    rllib-offline.rst
    rllib-dev.rst
    rllib-concepts.rst
diff --git a/doc/source/rllib-api.svg b/doc/source/rllib-api.svg
index e357b0f3e..e157e106f 100644
--- a/doc/source/rllib-api.svg
+++ b/doc/source/rllib-api.svg
@@ -1,4 +1 @@
-
-
-
-
+
\ No newline at end of file
diff --git a/doc/source/rllib-config.svg b/doc/source/rllib-config.svg
index 6bc412a60..04331f5f3 100644
--- a/doc/source/rllib-config.svg
+++ b/doc/source/rllib-config.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/doc/source/rllib-dev.rst b/doc/source/rllib-dev.rst
index feb30ef58..1445e043f 100644
--- a/doc/source/rllib-dev.rst
+++ b/doc/source/rllib-dev.rst
@@ -31,13 +31,13 @@ Contributing Algorithms
 These are the guidelines for merging new algorithms into RLlib:

 * Contributed algorithms (`rllib/contrib `__):
-   - must subclass Agent and implement the ``_train()`` method
+   - must subclass Trainer and implement the ``_train()`` method
    - must include a lightweight test (`example `__) to ensure the algorithm runs
    - should include tuned hyperparameter examples and documentation
    - should offer functionality not present in existing algorithms

 * Fully integrated algorithms (`rllib/agents `__) have the following additional requirements:
-   - must fully implement the Agent API
+   - must fully implement the Trainer API
    - must offer substantial new functionality not possible to add to other algorithms
    - should support custom models and preprocessors
    - should use RLlib abstractions and support distributed execution
@@ -46,14 +46,14 @@ Both integrated and contributed algorithms ship with the ``ray`` PyPI package, a
 How to add an algorithm to ``contrib``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-It takes just two changes to add an algorithm to `contrib `__. A minimal example can be found `here `__. First, subclass `Agent `__ and implement the ``_init`` and ``_train`` methods:
+It takes just two changes to add an algorithm to `contrib `__. A minimal example can be found `here `__. First, subclass `Trainer `__ and implement the ``_init`` and ``_train`` methods:

 .. literalinclude:: ../../python/ray/rllib/contrib/random_agent/random_agent.py
    :language: python
    :start-after: __sphinx_doc_begin__
    :end-before: __sphinx_doc_end__

-Second, register the agent with a name in `contrib/registry.py `__.
+Second, register the trainer with a name in `contrib/registry.py `__.

 .. code-block:: python

@@ -66,12 +66,12 @@ Second, register the agent with a name in `contrib/registry.py `__.
 Custom env classes passed directly to the agent must take a single ``env_config`` parameter in their constructor:
+You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name `__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:

 .. code-block:: python

@@ -43,7 +43,7 @@ You can pass either a string name or a Python class to specify an environment. B
         return <obs>, <reward>, <done>, <info>

     ray.init()
-    trainer = ppo.PPOAgent(env=MyEnv, config={
+    trainer = ppo.PPOTrainer(env=MyEnv, config={
         "env_config": {},  # config to pass to env class
     })

@@ -60,7 +60,7 @@ You can also register a custom env creator function with a string name. This fun
         return MyEnv(...)  # return an env instance

     register_env("my_env", env_creator)
-    trainer = ppo.PPOAgent(env="my_env")
+    trainer = ppo.PPOTrainer(env="my_env")

 For a full runnable code example using the custom environment API, see `custom_env.py `__.
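The two rllib-env.rst hunks above show the renamed entry point (``ppo.PPOTrainer`` in place of ``ppo.PPOAgent``) for both ways of pointing a trainer at an environment: passing the env class directly, or registering a creator function under a string name. The sketch below stitches those fragments into one runnable script against the post-rename API this patch introduces; ``MyEnv``'s spaces and reward logic are illustrative placeholders, not something taken from the patch.

.. code-block:: python

    import gym
    import ray
    from ray.rllib.agents import ppo
    from ray.tune.registry import register_env


    class MyEnv(gym.Env):
        """Toy one-step env; the spaces and reward here are placeholders."""

        def __init__(self, env_config):
            # env_config is the dict passed via the trainer's "env_config" key.
            self.action_space = gym.spaces.Discrete(2)
            self.observation_space = gym.spaces.Discrete(1)

        def reset(self):
            return 0

        def step(self, action):
            # (obs, reward, done, info)
            return 0, float(action), True, {}


    def env_creator(env_config):
        return MyEnv(env_config)  # return an env instance


    if __name__ == "__main__":
        ray.init()
        register_env("my_env", env_creator)
        trainer = ppo.PPOTrainer(env="my_env", config={
            "env_config": {},  # forwarded to MyEnv.__init__
            "num_workers": 1,
        })
        for _ in range(3):
            print(trainer.train()["episode_reward_mean"])
        checkpoint = trainer.save()  # reload later with trainer.restore(checkpoint)

The old ``*Agent`` names stay importable as deprecation aliases built with ``renamed_class``, as the ``__init__.py`` hunks later in this patch show, so existing scripts keep working during the transition.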
@@ -71,7 +71,7 @@ For a full runnable code example using the custom environment API, see `custom_e Configuring Environments ------------------------ -In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your agent. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example: +In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example: .. code-block:: python @@ -128,7 +128,7 @@ Multi-Agent and Hierarchical .. note:: - Learn more about multi-agent reinforcement learning in RLlib by reading the `blog post `__. + Learn more about multi-agent reinforcement learning in RLlib by checking out some of the `code examples `__ or reading the `blog post `__. A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib as follows: (1) as a user you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure: diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index d0c1a1082..7fd860a65 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -127,7 +127,7 @@ Custom TF models should subclass the common RLlib `model class `__. See the +Each algorithm has specific hyperparameters that can be set with ``--config``, in addition to a number of `common hyperparameters `__. See the `algorithms documentation `__ for more information. In an example below, we train A2C by specifying 8 workers through the config flag. @@ -77,16 +77,16 @@ In an example below, we train A2C by specifying 8 workers through the config fla Specifying Resources ~~~~~~~~~~~~~~~~~~~~ -You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``. +You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most algorithms. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting ``num_gpus: 0.2``. .. 
image:: rllib-config.svg Common Parameters ~~~~~~~~~~~~~~~~~ -The following is a list of the common agent hyperparameters: +The following is a list of the common algorithm hyperparameters: -.. literalinclude:: ../../python/ray/rllib/agents/agent.py +.. literalinclude:: ../../python/ray/rllib/agents/trainer.py :language: python :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ @@ -122,25 +122,25 @@ Here is an example of the basic usage (for a more complete example, see `custom_ config = ppo.DEFAULT_CONFIG.copy() config["num_gpus"] = 0 config["num_workers"] = 1 - agent = ppo.PPOAgent(config=config, env="CartPole-v0") + trainer = ppo.PPOTrainer(config=config, env="CartPole-v0") - # Can optionally call agent.restore(path) to load a checkpoint. + # Can optionally call trainer.restore(path) to load a checkpoint. for i in range(1000): # Perform one iteration of training the policy with PPO - result = agent.train() + result = trainer.train() print(pretty_print(result)) if i % 100 == 0: - checkpoint = agent.save() + checkpoint = trainer.save() print("checkpoint saved at", checkpoint) .. note:: - It's recommended that you run RLlib agents with `Tune `__, for easy experiment management and visualization of results. Just set ``"run": AGENT_NAME, "env": ENV_NAME`` in the experiment config. + It's recommended that you run RLlib trainers with `Tune `__, for easy experiment management and visualization of results. Just set ``"run": ALG_NAME, "env": ENV_NAME`` in the experiment config. -All RLlib agents are compatible with the `Tune API `__. This enables them to be easily used in experiments with `Tune `__. For example, the following code performs a simple hyperparam sweep of PPO: +All RLlib trainers are compatible with the `Tune API `__. This enables them to be easily used in experiments with `Tune `__. For example, the following code performs a simple hyperparam sweep of PPO: .. code-block:: python @@ -176,27 +176,27 @@ Tune will schedule the trials to run in parallel on your Ray cluster: Custom Training Workflows ~~~~~~~~~~~~~~~~~~~~~~~~~ -In the `basic training example `__, Tune will call ``train()`` on your agent once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions `__ that can be used to implement `custom training workflows (example) `__. +In the `basic training example `__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions `__ that can be used to implement `custom training workflows (example) `__. Accessing Policy State ~~~~~~~~~~~~~~~~~~~~~~ -It is common to need to access an agent's internal state, e.g., to set or get internal weights. In RLlib an agent's state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``agent.optimizer.foreach_evaluator()`` or ``agent.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. +It is common to need to access a trainer's internal state, e.g., to set or get internal weights. 
In RLlib trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. -You can also access just the "master" copy of the agent state through ``agent.get_policy()`` or ``agent.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``agent.get_policy().get_weights()``. This is also equivalent to ``agent.local_evaluator.policy_map["default_policy"].get_weights()``: +You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.local_evaluator.policy_map["default_policy"].get_weights()``: .. code-block:: python # Get weights of the default local policy - agent.get_policy().get_weights() + trainer.get_policy().get_weights() # Same as above - agent.local_evaluator.policy_map["default_policy"].get_weights() + trainer.local_evaluator.policy_map["default_policy"].get_weights() # Get list of weights of each evaluator, including remote replicas - agent.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights()) + trainer.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights()) # Same as above - agent.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights()) + trainer.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights()) Global Coordination ~~~~~~~~~~~~~~~~~~~ @@ -252,8 +252,8 @@ You can provide callback functions to be called at points during policy evaluati episode.custom_metrics["pole_angle"] = pole_angle def on_train_result(info): - print("agent.train() result: {} -> {} episodes".format( - info["agent"].__name__, info["result"]["episodes_this_iter"])) + print("trainer.train() result: {} -> {} episodes".format( + info["trainer"].__name__, info["result"]["episodes_this_iter"])) ray.init() trials = tune.run( @@ -278,18 +278,18 @@ Example: Curriculum Learning Let's look at two ways to use the above APIs to implement `curriculum learning `__. In curriculum learning, the agent task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time: -Approach 1: Use the Agent API and update the environment between calls to ``train()``. This example shows the agent being run inside a Tune function: +Approach 1: Use the Trainer API and update the environment between calls to ``train()``. This example shows the trainer being run inside a Tune function: .. 
code-block:: python import ray from ray import tune - from ray.rllib.agents.ppo import PPOAgent + from ray.rllib.agents.ppo import PPOTrainer def train(config, reporter): - agent = PPOAgent(config=config, env=YourEnv) + trainer = PPOTrainer(config=config, env=YourEnv) while True: - result = agent.train() + result = trainer.train() reporter(**result) if result["episode_reward_mean"] > 200: phase = 2 @@ -297,7 +297,7 @@ Approach 1: Use the Agent API and update the environment between calls to ``trai phase = 1 else: phase = 0 - agent.optimizer.foreach_evaluator( + trainer.optimizer.foreach_evaluator( lambda ev: ev.foreach_env( lambda env: env.set_phase(phase))) @@ -330,8 +330,8 @@ Approach 2: Use the callbacks API to update the environment on new training resu phase = 1 else: phase = 0 - agent = info["agent"] - agent.optimizer.foreach_evaluator( + trainer = info["trainer"] + trainer.optimizer.foreach_evaluator( lambda ev: ev.foreach_env( lambda env: env.set_phase(phase))) @@ -382,7 +382,7 @@ You can use the `data output API `__ to save episode traces Log Verbosity ~~~~~~~~~~~~~ -You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: +You can control the trainer log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: .. code-block:: bash diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 7c019df9c..60251d9ed 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -42,6 +42,16 @@ Environments * `Interfacing with External Agents `__ * `Advanced Integrations `__ +Models and Preprocessors +------------------------ +* `RLlib Models and Preprocessors Overview `__ +* `Custom Models (TensorFlow) `__ +* `Custom Models (PyTorch) `__ +* `Custom Preprocessors `__ +* `Supervised Model Losses `__ +* `Variable-length / Parametric Action Spaces `__ +* `Customizing Policy Graphs `__ + Algorithms ---------- @@ -79,16 +89,6 @@ Algorithms - `Advantage Re-Weighted Imitation Learning (MARWIL) `__ -Models and Preprocessors ------------------------- -* `RLlib Models and Preprocessors Overview `__ -* `Custom Models (TensorFlow) `__ -* `Custom Models (PyTorch) `__ -* `Custom Preprocessors `__ -* `Supervised Model Losses `__ -* `Variable-length / Parametric Action Spaces `__ -* `Customizing Policy Graphs `__ - Offline Datasets ---------------- * `Working with Offline Datasets `__ diff --git a/python/ray/rllib/agents/__init__.py b/python/ray/rllib/agents/__init__.py index da8494a2b..4eea0e678 100644 --- a/python/ray/rllib/agents/__init__.py +++ b/python/ray/rllib/agents/__init__.py @@ -1,3 +1,4 @@ -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.agent import Agent -__all__ = ["Agent", "with_common_config"] +__all__ = ["Agent", "Trainer", "with_common_config"] diff --git a/python/ray/rllib/agents/a3c/__init__.py b/python/ray/rllib/agents/a3c/__init__.py index 7d4303263..9c8205389 100644 --- a/python/ray/rllib/agents/a3c/__init__.py +++ b/python/ray/rllib/agents/a3c/__init__.py @@ -1,4 +1,10 @@ -from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG -from ray.rllib.agents.a3c.a2c import A2CAgent +from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG +from ray.rllib.agents.a3c.a2c import 
A2CTrainer +from ray.rllib.utils import renamed_class -__all__ = ["A2CAgent", "A3CAgent", "DEFAULT_CONFIG"] +A2CAgent = renamed_class(A2CTrainer) +A3CAgent = renamed_class(A3CTrainer) + +__all__ = [ + "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG" +] diff --git a/python/ray/rllib/agents/a3c/a2c.py b/python/ray/rllib/agents/a3c/a2c.py index 17e84ca5b..08de9e2b6 100644 --- a/python/ray/rllib/agents/a3c/a2c.py +++ b/python/ray/rllib/agents/a3c/a2c.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG as A3C_CONFIG +from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts @@ -17,13 +17,13 @@ A2C_DEFAULT_CONFIG = merge_dicts( ) -class A2CAgent(A3CAgent): - """Synchronous variant of the A3CAgent.""" +class A2CTrainer(A3CTrainer): + """Synchronous variant of the A3CTrainer.""" - _agent_name = "A2C" + _name = "A2C" _default_config = A2C_DEFAULT_CONFIG - @override(A3CAgent) + @override(A3CTrainer) def _make_optimizer(self): return SyncSamplesOptimizer( self.local_evaluator, self.remote_evaluators, diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 586a51123..199155a23 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -5,7 +5,7 @@ from __future__ import print_function import time from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer from ray.rllib.utils.annotations import override @@ -38,14 +38,14 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class A3CAgent(Agent): +class A3CTrainer(Trainer): """A3C implementations in TensorFlow and PyTorch.""" - _agent_name = "A3C" + _name = "A3C" _default_config = DEFAULT_CONFIG _policy_graph = A3CPolicyGraph - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): if config["use_pytorch"]: from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ @@ -63,7 +63,7 @@ class A3CAgent(Agent): env_creator, policy_cls, config["num_workers"]) self.optimizer = self._make_optimizer() - @override(Agent) + @override(Trainer) def _train(self): prev_steps = self.optimizer.num_steps_sampled start = time.time() diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 256168e54..5b0ecf268 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -2,797 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from datetime import datetime -import copy -import logging -import os -import pickle -import six -import tempfile -import tensorflow as tf -from types import FunctionType +from ray.rllib.agents.trainer import Trainer +from ray.rllib.utils import renamed_class -import ray -from ray.exceptions import RayError -from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ - ShuffledInput -from ray.rllib.models import MODEL_DEFAULTS -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ - _validate_multiagent_config -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID -from 
ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI -from ray.rllib.utils import FilterManager, deep_update, merge_dicts -from ray.tune.registry import ENV_CREATOR, register_env, _global_registry -from ray.tune.trainable import Trainable -from ray.tune.trial import Resources, ExportFormat -from ray.tune.logger import UnifiedLogger -from ray.tune.result import DEFAULT_RESULTS_DIR - -logger = logging.getLogger(__name__) - -# Max number of times to retry a worker failure. We shouldn't try too many -# times in a row since that would indicate a persistent cluster issue. -MAX_WORKER_FAILURE_RETRIES = 3 - -# yapf: disable -# __sphinx_doc_begin__ -COMMON_CONFIG = { - # === Debugging === - # Whether to write episode stats and videos to the agent log dir - "monitor": False, - # Set the ray.rllib.* log level for the agent process and its evaluators. - # Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also - # periodically print out summaries of relevant internal dataflow (this is - # also printed out once at startup at the INFO level). - "log_level": "INFO", - # Callbacks that will be run during various phases of training. These all - # take a single "info" dict as an argument. For episode callbacks, custom - # metrics can be attached to the episode by updating the episode object's - # custom metrics dict (see examples/custom_metrics_and_callbacks.py). - "callbacks": { - "on_episode_start": None, # arg: {"env": .., "episode": ...} - "on_episode_step": None, # arg: {"env": .., "episode": ...} - "on_episode_end": None, # arg: {"env": .., "episode": ...} - "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} - "on_train_result": None, # arg: {"agent": ..., "result": ...} - }, - # Whether to attempt to continue training if a worker crashes. - "ignore_worker_failures": False, - - # === Policy === - # Arguments to pass to model. See models/catalog.py for a full list of the - # available model options. - "model": MODEL_DEFAULTS, - # Arguments to pass to the policy optimizer. These vary by optimizer. - "optimizer": {}, - - # === Environment === - # Discount factor of the MDP - "gamma": 0.99, - # Number of steps after which the episode is forced to terminate. Defaults - # to `env.spec.max_episode_steps` (if present) for Gym envs. - "horizon": None, - # Calculate rewards but don't reset the environment when the horizon is - # hit. This allows value estimation and RNN state to span across logical - # episodes denoted by horizon. This only has an effect if horizon != inf. - "soft_horizon": False, - # Arguments to pass to the env creator - "env_config": {}, - # Environment name can also be passed via config - "env": None, - # Whether to clip rewards prior to experience postprocessing. Setting to - # None means clip for Atari only. - "clip_rewards": None, - # Whether to np.clip() actions to the action space low/high range spec. - "clip_actions": True, - # Whether to use rllib or deepmind preprocessors by default - "preprocessor_pref": "deepmind", - - # === Resources === - # Number of actors used for parallelism - "num_workers": 2, - # Number of GPUs to allocate to the driver. Note that not all algorithms - # can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs). - "num_gpus": 0, - # Number of CPUs to allocate per worker. - "num_cpus_per_worker": 1, - # Number of GPUs to allocate per worker. This can be fractional. 
- "num_gpus_per_worker": 0, - # Any custom resources to allocate per worker. - "custom_resources_per_worker": {}, - # Number of CPUs to allocate for the driver. Note: this only takes effect - # when running in Tune. - "num_cpus_for_driver": 1, - - # === Execution === - # Number of environments to evaluate vectorwise per worker. - "num_envs_per_worker": 1, - # Default sample batch size (unroll length). Batches of this size are - # collected from workers until train_batch_size is met. When using - # multiple envs per worker, this is multiplied by num_envs_per_worker. - "sample_batch_size": 200, - # Training batch size, if applicable. Should be >= sample_batch_size. - # Samples batches will be concatenated together to this size for training. - "train_batch_size": 200, - # Whether to rollout "complete_episodes" or "truncate_episodes" - "batch_mode": "truncate_episodes", - # (Deprecated) Use a background thread for sampling (slightly off-policy) - "sample_async": False, - # Element-wise observation filter, either "NoFilter" or "MeanStdFilter" - "observation_filter": "NoFilter", - # Whether to synchronize the statistics of remote filters. - "synchronize_filters": True, - # Configure TF for single-process operation by default - "tf_session_args": { - # note: overriden by `local_evaluator_tf_session_args` - "intra_op_parallelism_threads": 2, - "inter_op_parallelism_threads": 2, - "gpu_options": { - "allow_growth": True, - }, - "log_device_placement": False, - "device_count": { - "CPU": 1 - }, - "allow_soft_placement": True, # required by PPO multi-gpu - }, - # Override the following tf session args on the local evaluator - "local_evaluator_tf_session_args": { - # Allow a higher level of parallelism by default, but not unlimited - # since that can cause crashes with many concurrent drivers. - "intra_op_parallelism_threads": 8, - "inter_op_parallelism_threads": 8, - }, - # Whether to LZ4 compress individual observations - "compress_observations": False, - # Drop metric batches from unresponsive workers after this many seconds - "collect_metrics_timeout": 180, - # Smooth metrics over this many episodes. - "metrics_smoothing_episodes": 100, - # If using num_envs_per_worker > 1, whether to create those new envs in - # remote processes instead of in the same worker. This adds overheads, but - # can make sense if your envs can take much time to step / reset - # (e.g., for StarCraft). Use this cautiously; overheads are significant. - "remote_worker_envs": False, - # Timeout that remote workers are waiting when polling environments. - # 0 (continue when at least one env is ready) is a reasonable default, - # but optimal value could be obtained by measuring your environment - # step / reset and model inference perf. - "remote_env_batch_wait_ms": 0, - - # === Offline Datasets === - # Specify how to generate experiences: - # - "sampler": generate experiences via online simulation (default) - # - a local directory or file glob expression (e.g., "/tmp/*.json") - # - a list of individual file paths/URIs (e.g., ["/tmp/1.json", - # "s3://bucket/2.json"]) - # - a dict with string keys and sampling probabilities as values (e.g., - # {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}). - # - a function that returns a rllib.offline.InputReader - "input": "sampler", - # Specify how to evaluate the current policy. This only has an effect when - # reading offline experiences. Available options: - # - "wis": the weighted step-wise importance sampling estimator. 
- # - "is": the step-wise importance sampling estimator. - # - "simulation": run the environment in the background, but use - # this data for evaluation only and not for learning. - "input_evaluation": ["is", "wis"], - # Whether to run postprocess_trajectory() on the trajectory fragments from - # offline inputs. Note that postprocessing will be done using the *current* - # policy, not the *behaviour* policy, which is typically undesirable for - # on-policy algorithms. - "postprocess_inputs": False, - # If positive, input batches will be shuffled via a sliding window buffer - # of this number of batches. Use this if the input data is not in random - # enough order. Input is delayed until the shuffle buffer is filled. - "shuffle_buffer_size": 0, - # Specify where experiences should be saved: - # - None: don't save any experiences - # - "logdir" to save to the agent log dir - # - a path/URI to save to a custom output directory (e.g., "s3://bucket/") - # - a function that returns a rllib.offline.OutputWriter - "output": None, - # What sample batch columns to LZ4 compress in the output data. - "output_compress_columns": ["obs", "new_obs"], - # Max output file size before rolling over to a new file. - "output_max_file_size": 64 * 1024 * 1024, - - # === Multiagent === - "multiagent": { - # Map from policy ids to tuples of (policy_graph_cls, obs_space, - # act_space, config). See policy_evaluator.py for more info. - "policy_graphs": {}, - # Function mapping agent ids to policy ids. - "policy_mapping_fn": None, - # Optional whitelist of policies to train, or None for all policies. - "policies_to_train": None, - }, -} -# __sphinx_doc_end__ -# yapf: enable - - -@DeveloperAPI -def with_common_config(extra_config): - """Returns the given config dict merged with common agent confs.""" - - return with_base_config(COMMON_CONFIG, extra_config) - - -def with_base_config(base_config, extra_config): - """Returns the given config dict merged with a base agent conf.""" - - config = copy.deepcopy(base_config) - config.update(extra_config) - return config - - -@PublicAPI -class Agent(Trainable): - """All RLlib agents extend this base class. - - Agent objects retain internal model state between calls to train(), so - you should create a new agent instance for each training session. - - Attributes: - env_creator (func): Function that creates a new training env. - config (obj): Algorithm-specific configuration data. - logdir (str): Directory in which training outputs should be placed. - """ - - _allow_unknown_configs = False - _allow_unknown_subkeys = [ - "tf_session_args", "env_config", "model", "optimizer", "multiagent", - "custom_resources_per_worker" - ] - - @PublicAPI - def __init__(self, config=None, env=None, logger_creator=None): - """Initialize an RLLib agent. - - Args: - config (dict): Algorithm-specific configuration data. - env (str): Name of the environment to use. Note that this can also - be specified as the `env` key in config. - logger_creator (func): Function that creates a ray.tune.Logger - object. If unspecified, a default logger is created. - """ - - config = config or {} - - # Vars to synchronize to evaluators on each train call - self.global_vars = {"timestep": 0} - - # Agents allow env ids to be passed directly to the constructor. 
- self._env_id = self._register_if_needed(env or config.get("env")) - - # Create a default logger creator if no logger_creator is specified - if logger_creator is None: - timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") - logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id, - timestr) - - def default_logger_creator(config): - """Creates a Unified logger with a default logdir prefix - containing the agent name and the env id - """ - if not os.path.exists(DEFAULT_RESULTS_DIR): - os.makedirs(DEFAULT_RESULTS_DIR) - logdir = tempfile.mkdtemp( - prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) - return UnifiedLogger(config, logdir, None) - - logger_creator = default_logger_creator - - Trainable.__init__(self, config, logger_creator) - - @classmethod - @override(Trainable) - def default_resource_request(cls, config): - cf = dict(cls._default_config, **config) - Agent._validate_config(cf) - # TODO(ekl): add custom resources here once tune supports them - return Resources( - cpu=cf["num_cpus_for_driver"], - gpu=cf["num_gpus"], - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - - @override(Trainable) - @PublicAPI - def train(self): - """Overrides super.train to synchronize global vars.""" - - if self._has_policy_optimizer(): - self.global_vars["timestep"] = self.optimizer.num_steps_sampled - self.optimizer.local_evaluator.set_global_vars(self.global_vars) - for ev in self.optimizer.remote_evaluators: - ev.set_global_vars.remote(self.global_vars) - logger.debug("updated global vars: {}".format(self.global_vars)) - - result = None - for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): - try: - result = Trainable.train(self) - except RayError as e: - if self.config["ignore_worker_failures"]: - logger.exception( - "Error in train call, attempting to recover") - self._try_recover() - else: - logger.info( - "Worker crashed during call to train(). 
To attempt to " - "continue training without the failed worker, set " - "`'ignore_worker_failures': True`.") - raise e - else: - break - if result is None: - raise RuntimeError("Failed to recover from worker crash") - - if (self.config.get("observation_filter", "NoFilter") != "NoFilter" - and hasattr(self, "local_evaluator")): - FilterManager.synchronize( - self.local_evaluator.filters, - self.remote_evaluators, - update_remote=self.config["synchronize_filters"]) - logger.debug("synchronized filters: {}".format( - self.local_evaluator.filters)) - - if self._has_policy_optimizer(): - result["num_healthy_workers"] = len( - self.optimizer.remote_evaluators) - return result - - @override(Trainable) - def _log_result(self, result): - if self.config["callbacks"].get("on_train_result"): - self.config["callbacks"]["on_train_result"]({ - "agent": self, - "result": result, - }) - # log after the callback is invoked, so that the user has a chance - # to mutate the result - Trainable._log_result(self, result) - - @override(Trainable) - def _setup(self, config): - env = self._env_id - if env: - config["env"] = env - if _global_registry.contains(ENV_CREATOR, env): - self.env_creator = _global_registry.get(ENV_CREATOR, env) - else: - import gym # soft dependency - self.env_creator = lambda env_config: gym.make(env) - else: - self.env_creator = lambda env_config: None - - # Merge the supplied config with the class default - merged_config = copy.deepcopy(self._default_config) - merged_config = deep_update(merged_config, config, - self._allow_unknown_configs, - self._allow_unknown_subkeys) - self.raw_user_config = config - self.config = merged_config - Agent._validate_config(self.config) - if self.config.get("log_level"): - logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) - - # TODO(ekl) setting the graph is unnecessary for PyTorch agents - with tf.Graph().as_default(): - self._init(self.config, self.env_creator) - - @override(Trainable) - def _stop(self): - # Call stop on all evaluators to release resources - if hasattr(self, "local_evaluator"): - self.local_evaluator.stop() - if hasattr(self, "remote_evaluators"): - for ev in self.remote_evaluators: - ev.stop.remote() - - # workaround for https://github.com/ray-project/ray/issues/1516 - if hasattr(self, "remote_evaluators"): - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() - - if hasattr(self, "optimizer"): - self.optimizer.stop() - - @override(Trainable) - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) - return checkpoint_path - - @override(Trainable) - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path, "rb")) - self.__setstate__(extra_data) - - @DeveloperAPI - def _init(self, config, env_creator): - """Subclasses should override this for custom initialization.""" - - raise NotImplementedError - - @PublicAPI - def compute_action(self, - observation, - state=None, - prev_action=None, - prev_reward=None, - info=None, - policy_id=DEFAULT_POLICY_ID): - """Computes an action for the specified policy. - - Note that you can also access the policy object through - self.get_policy(policy_id) and call compute_actions() on it directly. - - Arguments: - observation (obj): observation from the environment. - state (list): RNN hidden state, if any. If state is not None, - then all of compute_single_action(...) 
is returned - (computed action, rnn state, logits dictionary). - Otherwise compute_single_action(...)[0] is - returned (computed action). - prev_action (obj): previous action value, if any - prev_reward (int): previous reward, if any - info (dict): info object, if any - policy_id (str): policy to query (only applies to multi-agent). - """ - - if state is None: - state = [] - preprocessed = self.local_evaluator.preprocessors[policy_id].transform( - observation) - filtered_obs = self.local_evaluator.filters[policy_id]( - preprocessed, update=False) - if state: - return self.get_policy(policy_id).compute_single_action( - filtered_obs, - state, - prev_action, - prev_reward, - info, - clip_actions=self.config["clip_actions"]) - return self.get_policy(policy_id).compute_single_action( - filtered_obs, - state, - prev_action, - prev_reward, - info, - clip_actions=self.config["clip_actions"])[0] - - @property - def iteration(self): - """Current training iter, auto-incremented with each train() call.""" - - return self._iteration - - @property - def _agent_name(self): - """Subclasses should override this to declare their name.""" - - raise NotImplementedError - - @property - def _default_config(self): - """Subclasses should override this to declare their default config.""" - - raise NotImplementedError - - @PublicAPI - def get_policy(self, policy_id=DEFAULT_POLICY_ID): - """Return policy graph for the specified id, or None. - - Arguments: - policy_id (str): id of policy graph to return. - """ - - return self.local_evaluator.get_policy(policy_id) - - @PublicAPI - def get_weights(self, policies=None): - """Return a dictionary of policy ids to weights. - - Arguments: - policies (list): Optional list of policies to return weights for, - or None for all policies. - """ - return self.local_evaluator.get_weights(policies) - - @PublicAPI - def set_weights(self, weights): - """Set policy weights by policy id. - - Arguments: - weights (dict): Map of policy ids to weights to set. - """ - self.local_evaluator.set_weights(weights) - - @DeveloperAPI - def make_local_evaluator(self, - env_creator, - policy_graph, - extra_config=None): - """Convenience method to return configured local evaluator.""" - - return self._make_evaluator( - PolicyEvaluator, - env_creator, - policy_graph, - 0, - merge_dicts( - # important: allow local tf to use more CPUs for optimization - merge_dicts( - self.config, { - "tf_session_args": self. - config["local_evaluator_tf_session_args"] - }), - extra_config or {})) - - @DeveloperAPI - def make_remote_evaluators(self, env_creator, policy_graph, count): - """Convenience method to return a number of remote evaluators.""" - - remote_args = { - "num_cpus": self.config["num_cpus_per_worker"], - "num_gpus": self.config["num_gpus_per_worker"], - "resources": self.config["custom_resources_per_worker"], - } - - cls = PolicyEvaluator.as_remote(**remote_args).remote - - return [ - self._make_evaluator(cls, env_creator, policy_graph, i + 1, - self.config) for i in range(count) - ] - - @DeveloperAPI - def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): - """Export policy model with given policy_id to local directory. - - Arguments: - export_dir (string): Writable local directory. - policy_id (string): Optional policy id to export. 
- - Example: - >>> agent = MyAgent() - >>> for _ in range(10): - >>> agent.train() - >>> agent.export_policy_model("/tmp/export_dir") - """ - self.local_evaluator.export_policy_model(export_dir, policy_id) - - @DeveloperAPI - def export_policy_checkpoint(self, - export_dir, - filename_prefix="model", - policy_id=DEFAULT_POLICY_ID): - """Export tensorflow policy model checkpoint to local directory. - - Arguments: - export_dir (string): Writable local directory. - filename_prefix (string): file name prefix of checkpoint files. - policy_id (string): Optional policy id to export. - - Example: - >>> agent = MyAgent() - >>> for _ in range(10): - >>> agent.train() - >>> agent.export_policy_checkpoint("/tmp/export_dir") - """ - self.local_evaluator.export_policy_checkpoint( - export_dir, filename_prefix, policy_id) - - @DeveloperAPI - def collect_metrics(self, selected_evaluators=None): - """Collects metrics from the remote evaluators of this agent. - - This is the same data as returned by a call to train(). - """ - return self.optimizer.collect_metrics( - self.config["collect_metrics_timeout"], - min_history=self.config["metrics_smoothing_episodes"], - selected_evaluators=selected_evaluators) - - @classmethod - def resource_help(cls, config): - return ("\n\nYou can adjust the resource requests of RLlib agents by " - "setting `num_workers`, `num_gpus`, and other configs. See " - "the DEFAULT_CONFIG defined by each agent for more info.\n\n" - "The config of this agent is: {}".format(config)) - - @staticmethod - def _validate_config(config): - if "gpu" in config: - raise ValueError( - "The `gpu` config is deprecated, please use `num_gpus=0|1` " - "instead.") - if "gpu_fraction" in config: - raise ValueError( - "The `gpu_fraction` config is deprecated, please use " - "`num_gpus=` instead.") - if "use_gpu_for_workers" in config: - raise ValueError( - "The `use_gpu_for_workers` config is deprecated, please use " - "`num_gpus_per_worker=1` instead.") - if type(config["input_evaluation"]) != list: - raise ValueError( - "`input_evaluation` must be a list of strings, got {}".format( - config["input_evaluation"])) - - def _try_recover(self): - """Try to identify and blacklist any unhealthy workers. - - This method is called after an unexpected remote error is encountered - from a worker. It issues check requests to all current workers and - blacklists any that respond with error. If no healthy workers remain, - an error is raised. 
- """ - - if not self._has_policy_optimizer(): - raise NotImplementedError( - "Recovery is not supported for this algorithm") - - logger.info("Health checking all workers...") - checks = [] - for ev in self.optimizer.remote_evaluators: - _, obj_id = ev.sample_with_count.remote() - checks.append(obj_id) - - healthy_evaluators = [] - for i, obj_id in enumerate(checks): - ev = self.optimizer.remote_evaluators[i] - try: - ray.get(obj_id) - healthy_evaluators.append(ev) - logger.info("Worker {} looks healthy".format(i + 1)) - except RayError: - logger.exception("Blacklisting worker {}".format(i + 1)) - try: - ev.__ray_terminate__.remote() - except Exception: - logger.exception("Error terminating unhealthy worker") - - if len(healthy_evaluators) < 1: - raise RuntimeError( - "Not enough healthy workers remain to continue.") - - self.optimizer.reset(healthy_evaluators) - - def _has_policy_optimizer(self): - return hasattr(self, "optimizer") and isinstance( - self.optimizer, PolicyOptimizer) - - def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, - config): - def session_creator(): - logger.debug("Creating TF session {}".format( - config["tf_session_args"])) - return tf.Session( - config=tf.ConfigProto(**config["tf_session_args"])) - - if isinstance(config["input"], FunctionType): - input_creator = config["input"] - elif config["input"] == "sampler": - input_creator = (lambda ioctx: ioctx.default_sampler_input()) - elif isinstance(config["input"], dict): - input_creator = (lambda ioctx: ShuffledInput( - MixedInput(config["input"], ioctx), config[ - "shuffle_buffer_size"])) - else: - input_creator = (lambda ioctx: ShuffledInput( - JsonReader(config["input"], ioctx), config[ - "shuffle_buffer_size"])) - - if isinstance(config["output"], FunctionType): - output_creator = config["output"] - elif config["output"] is None: - output_creator = (lambda ioctx: NoopOutput()) - elif config["output"] == "logdir": - output_creator = (lambda ioctx: JsonWriter( - ioctx.log_dir, - ioctx, - max_file_size=config["output_max_file_size"], - compress_columns=config["output_compress_columns"])) - else: - output_creator = (lambda ioctx: JsonWriter( - config["output"], - ioctx, - max_file_size=config["output_max_file_size"], - compress_columns=config["output_compress_columns"])) - - if config["input"] == "sampler": - input_evaluation = [] - else: - input_evaluation = config["input_evaluation"] - - # Fill in the default policy graph if 'None' is specified in multiagent - if self.config["multiagent"]["policy_graphs"]: - tmp = self.config["multiagent"]["policy_graphs"] - _validate_multiagent_config(tmp, allow_none_graph=True) - for k, v in tmp.items(): - if v[0] is None: - tmp[k] = (policy_graph, v[1], v[2], v[3]) - policy_graph = tmp - - return cls( - env_creator, - policy_graph, - policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], - policies_to_train=self.config["multiagent"]["policies_to_train"], - tf_session_creator=(session_creator - if config["tf_session_args"] else None), - batch_steps=config["sample_batch_size"], - batch_mode=config["batch_mode"], - episode_horizon=config["horizon"], - preprocessor_pref=config["preprocessor_pref"], - sample_async=config["sample_async"], - compress_observations=config["compress_observations"], - num_envs=config["num_envs_per_worker"], - observation_filter=config["observation_filter"], - clip_rewards=config["clip_rewards"], - clip_actions=config["clip_actions"], - env_config=config["env_config"], - model_config=config["model"], - policy_config=config, 
- worker_index=worker_index, - monitor_path=self.logdir if config["monitor"] else None, - log_dir=self.logdir, - log_level=config["log_level"], - callbacks=config["callbacks"], - input_creator=input_creator, - input_evaluation=input_evaluation, - output_creator=output_creator, - remote_worker_envs=config["remote_worker_envs"], - remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"], - soft_horizon=config["soft_horizon"], - _fake_sampler=config.get("_fake_sampler", False)) - - @override(Trainable) - def _export_model(self, export_formats, export_dir): - ExportFormat.validate(export_formats) - exported = {} - if ExportFormat.CHECKPOINT in export_formats: - path = os.path.join(export_dir, ExportFormat.CHECKPOINT) - self.export_policy_checkpoint(path) - exported[ExportFormat.CHECKPOINT] = path - if ExportFormat.MODEL in export_formats: - path = os.path.join(export_dir, ExportFormat.MODEL) - self.export_policy_model(path) - exported[ExportFormat.MODEL] = path - return exported - - def __getstate__(self): - state = {} - if hasattr(self, "local_evaluator"): - state["evaluator"] = self.local_evaluator.save() - if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"): - state["optimizer"] = self.optimizer.save() - return state - - def __setstate__(self, state): - if "evaluator" in state: - self.local_evaluator.restore(state["evaluator"]) - remote_state = ray.put(state["evaluator"]) - for r in self.remote_evaluators: - r.restore.remote(remote_state) - if "optimizer" in state: - self.optimizer.restore(state["optimizer"]) - - def _register_if_needed(self, env_object): - if isinstance(env_object, six.string_types): - return env_object - elif isinstance(env_object, type): - name = env_object.__name__ - register_env(name, lambda config: env_object(config)) - return name - raise ValueError( - "{} is an invalid env specification. 
".format(env_object) + - "You can specify a custom env as either a class " - "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").") +Agent = renamed_class(Trainer) diff --git a/python/ray/rllib/agents/ars/__init__.py b/python/ray/rllib/agents/ars/__init__.py index 5e7809d38..a1120ff8c 100644 --- a/python/ray/rllib/agents/ars/__init__.py +++ b/python/ray/rllib/agents/ars/__init__.py @@ -1,3 +1,6 @@ -from ray.rllib.agents.ars.ars import (ARSAgent, DEFAULT_CONFIG) +from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG) +from ray.rllib.utils import renamed_class -__all__ = ["ARSAgent", "DEFAULT_CONFIG"] +ARSAgent = renamed_class(ARSTrainer) + +__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 4074eff31..9136bf393 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -12,7 +12,7 @@ import numpy as np import time import ray -from ray.rllib.agents import Agent, with_common_config +from ray.rllib.agents import Trainer, with_common_config from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies @@ -157,13 +157,13 @@ class Worker(object): eval_lengths=eval_lengths) -class ARSAgent(Agent): +class ARSTrainer(Trainer): """Large-scale implementation of Augmented Random Search in Ray.""" - _agent_name = "ARS" + _name = "ARS" _default_config = DEFAULT_CONFIG - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): env = env_creator(config["env_config"]) from ray.rllib import models @@ -195,7 +195,7 @@ class ARSAgent(Agent): self.reward_list = [] self.tstart = time.time() - @override(Agent) + @override(Trainer) def _train(self): config = self.config @@ -291,13 +291,13 @@ class ARSAgent(Agent): return result - @override(Agent) + @override(Trainer) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: w.__ray_terminate__.remote() - @override(Agent) + @override(Trainer) def compute_action(self, observation): return self.policy.compute(observation, update=True)[0] diff --git a/python/ray/rllib/agents/ddpg/__init__.py b/python/ray/rllib/agents/ddpg/__init__.py index 7d3390b20..5d2099187 100644 --- a/python/ray/rllib/agents/ddpg/__init__.py +++ b/python/ray/rllib/agents/ddpg/__init__.py @@ -2,7 +2,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.ddpg.apex import ApexDDPGAgent -from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG +from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer +from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG +from ray.rllib.utils import renamed_class -__all__ = ["DDPGAgent", "ApexDDPGAgent", "DEFAULT_CONFIG"] +ApexDDPGAgent = renamed_class(ApexDDPGTrainer) +DDPGAgent = renamed_class(DDPGTrainer) + +__all__ = [ + "DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer", + "DEFAULT_CONFIG" +] diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 994ba7938..24edbb226 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -2,7 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG +from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \ + DEFAULT_CONFIG as DDPG_CONFIG from 
ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts @@ -32,17 +33,17 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts( ) -class ApexDDPGAgent(DDPGAgent): +class ApexDDPGTrainer(DDPGTrainer): """DDPG variant that uses the Ape-X distributed policy optimizer. By default, this is configured for a large single node (32 cores). For running in a large cluster, increase the `num_workers` config var. """ - _agent_name = "APEX_DDPG" + _name = "APEX_DDPG" _default_config = APEX_DDPG_DEFAULT_CONFIG - @override(DDPGAgent) + @override(DDPGTrainer) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index 15699a097..e5031ca11 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -2,8 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.agent import with_common_config -from ray.rllib.agents.dqn.dqn import DQNAgent +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.dqn.dqn import DQNTrainer from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule @@ -132,13 +132,13 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class DDPGAgent(DQNAgent): +class DDPGTrainer(DQNTrainer): """DDPG implementation in TensorFlow.""" - _agent_name = "DDPG" + _name = "DDPG" _default_config = DEFAULT_CONFIG _policy_graph = DDPGPolicyGraph - @override(DQNAgent) + @override(DQNTrainer) def _make_exploration_schedule(self, worker_index): # Override DQN's schedule to take into account `noise_scale` if self.config["per_worker_exploration"]: diff --git a/python/ray/rllib/agents/dqn/__init__.py b/python/ray/rllib/agents/dqn/__init__.py index b46c249e6..415ceae6c 100644 --- a/python/ray/rllib/agents/dqn/__init__.py +++ b/python/ray/rllib/agents/dqn/__init__.py @@ -2,7 +2,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.dqn.apex import ApexAgent -from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG +from ray.rllib.agents.dqn.apex import ApexTrainer +from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG +from ray.rllib.utils import renamed_class -__all__ = ["ApexAgent", "DQNAgent", "DEFAULT_CONFIG"] +DQNAgent = renamed_class(DQNTrainer) +ApexAgent = renamed_class(ApexTrainer) + +__all__ = [ + "DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG" +] diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index fbe130dd3..27bde322a 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG +from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override @@ -36,17 +36,17 @@ APEX_DEFAULT_CONFIG = merge_dicts( # yapf: enable -class ApexAgent(DQNAgent): +class ApexTrainer(DQNTrainer): """DQN variant that uses the Ape-X distributed policy optimizer. 
By default, this is configured for a large single node (32 cores). For running in a large cluster, increase the `num_workers` config var. """ - _agent_name = "APEX" + _name = "APEX" _default_config = APEX_DEFAULT_CONFIG - @override(DQNAgent) + @override(DQNTrainer) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 62c56cbc5..7567c6c5a 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -7,7 +7,7 @@ import time from ray import tune from ray.rllib import optimizers -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID @@ -137,15 +137,15 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class DQNAgent(Agent): +class DQNTrainer(Trainer): """DQN implementation in TensorFlow.""" - _agent_name = "DQN" + _name = "DQN" _default_config = DEFAULT_CONFIG _policy_graph = DQNPolicyGraph _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): self._validate_config() @@ -161,7 +161,7 @@ class DQNAgent(Agent): ] for k in self._optimizer_shared_configs: - if self._agent_name != "DQN" and k in [ + if self._name != "DQN" and k in [ "schedule_max_timesteps", "beta_annealing_fraction", "final_prioritized_replay_beta" ]: @@ -238,7 +238,7 @@ class DQNAgent(Agent): self.last_target_update_ts = 0 self.num_target_updates = 0 - @override(Agent) + @override(Trainer) def _train(self): start_timestep = self.global_timestep @@ -326,7 +326,7 @@ class DQNAgent(Agent): final_p=self.config["exploration_final_eps"]) def __getstate__(self): - state = Agent.__getstate__(self) + state = Trainer.__getstate__(self) state.update({ "num_target_updates": self.num_target_updates, "last_target_update_ts": self.last_target_update_ts, @@ -334,7 +334,7 @@ class DQNAgent(Agent): return state def __setstate__(self, state): - Agent.__setstate__(self, state) + Trainer.__setstate__(self, state) self.num_target_updates = state["num_target_updates"] self.last_target_update_ts = state["last_target_update_ts"] diff --git a/python/ray/rllib/agents/es/__init__.py b/python/ray/rllib/agents/es/__init__.py index 3ea5f2edc..d7bec2a9e 100644 --- a/python/ray/rllib/agents/es/__init__.py +++ b/python/ray/rllib/agents/es/__init__.py @@ -1,3 +1,6 @@ -from ray.rllib.agents.es.es import (ESAgent, DEFAULT_CONFIG) +from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG) +from ray.rllib.utils import renamed_class -__all__ = ["ESAgent", "DEFAULT_CONFIG"] +ESAgent = renamed_class(ESTrainer) + +__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index fdaacc183..e6ee1be7c 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -11,7 +11,7 @@ import numpy as np import time import ray -from ray.rllib.agents import Agent, with_common_config +from ray.rllib.agents import Trainer, with_common_config from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies @@ -163,13 +163,13 @@ class Worker(object): eval_lengths=eval_lengths) -class ESAgent(Agent): +class 
ESTrainer(Trainer): """Large-scale implementation of Evolution Strategies in Ray.""" - _agent_name = "ES" + _name = "ES" _default_config = DEFAULT_CONFIG - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): policy_params = {"action_noise_std": 0.01} @@ -200,7 +200,7 @@ class ESAgent(Agent): self.reward_list = [] self.tstart = time.time() - @override(Agent) + @override(Trainer) def _train(self): config = self.config @@ -288,11 +288,11 @@ class ESAgent(Agent): return result - @override(Agent) + @override(Trainer) def compute_action(self, observation): return self.policy.compute(observation, update=False)[0] - @override(Agent) + @override(Trainer) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: diff --git a/python/ray/rllib/agents/impala/__init__.py b/python/ray/rllib/agents/impala/__init__.py index 087a568c3..81c64e889 100644 --- a/python/ray/rllib/agents/impala/__init__.py +++ b/python/ray/rllib/agents/impala/__init__.py @@ -1,3 +1,6 @@ -from ray.rllib.agents.impala.impala import ImpalaAgent, DEFAULT_CONFIG +from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG +from ray.rllib.utils import renamed_class -__all__ = ["ImpalaAgent", "DEFAULT_CONFIG"] +ImpalaAgent = renamed_class(ImpalaTrainer) + +__all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index c312afd81..67914e2f8 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -6,7 +6,7 @@ import time from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.optimizers import AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.utils.annotations import override @@ -100,14 +100,14 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class ImpalaAgent(Agent): +class ImpalaTrainer(Trainer): """IMPALA implementation using DeepMind's V-trace.""" - _agent_name = "IMPALA" + _name = "IMPALA" _default_config = DEFAULT_CONFIG _policy_graph = VTracePolicyGraph - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): for k in OPTIMIZER_SHARED_CONFIGS: if k not in config["optimizer"]: @@ -136,7 +136,7 @@ class ImpalaAgent(Agent): @override(Trainable) def default_resource_request(cls, config): cf = dict(cls._default_config, **config) - Agent._validate_config(cf) + Trainer._validate_config(cf) return Resources( cpu=cf["num_cpus_for_driver"], gpu=cf["num_gpus"], @@ -144,7 +144,7 @@ class ImpalaAgent(Agent): cf["num_aggregation_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - @override(Agent) + @override(Trainer) def _train(self): prev_steps = self.optimizer.num_steps_sampled start = time.time() diff --git a/python/ray/rllib/agents/marwil/__init__.py b/python/ray/rllib/agents/marwil/__init__.py index eaa8bd87a..3d4a7cf2f 100644 --- a/python/ray/rllib/agents/marwil/__init__.py +++ b/python/ray/rllib/agents/marwil/__init__.py @@ -2,6 +2,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.marwil.marwil import MARWILAgent, DEFAULT_CONFIG +from ray.rllib.agents.marwil.marwil import MARWILTrainer, DEFAULT_CONFIG 
-__all__ = ["MARWILAgent", "DEFAULT_CONFIG"] +__all__ = ["MARWILTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index 782a427ad..d2b43478d 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph from ray.rllib.optimizers import SyncBatchReplayOptimizer from ray.rllib.utils.annotations import override @@ -39,14 +39,14 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class MARWILAgent(Agent): +class MARWILTrainer(Trainer): """MARWIL implementation in TensorFlow.""" - _agent_name = "MARWIL" + _name = "MARWIL" _default_config = DEFAULT_CONFIG _policy_graph = MARWILPolicyGraph - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): self.local_evaluator = self.make_local_evaluator( env_creator, self._policy_graph) @@ -59,7 +59,7 @@ class MARWILAgent(Agent): "train_batch_size": config["train_batch_size"], }) - @override(Agent) + @override(Trainer) def _train(self): prev_steps = self.optimizer.num_steps_sampled fetches = self.optimizer.step() diff --git a/python/ray/rllib/agents/mock.py b/python/ray/rllib/agents/mock.py index 5e421aae7..560573af7 100644 --- a/python/ray/rllib/agents/mock.py +++ b/python/ray/rllib/agents/mock.py @@ -6,13 +6,13 @@ import os import pickle import numpy as np -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config -class _MockAgent(Agent): - """Mock agent for use in tests""" +class _MockTrainer(Trainer): + """Mock trainer for use in tests""" - _agent_name = "MockAgent" + _name = "MockTrainer" _default_config = with_common_config({ "mock_error": False, "persistent_error": False, @@ -57,12 +57,12 @@ class _MockAgent(Agent): return self.info -class _SigmoidFakeData(_MockAgent): - """Agent that returns sigmoid learning curves. +class _SigmoidFakeData(_MockTrainer): + """Trainer that returns sigmoid learning curves. This can be helpful for evaluating early stopping algorithms.""" - _agent_name = "SigmoidFakeData" + _name = "SigmoidFakeData" _default_config = with_common_config({ "width": 100, "height": 100, @@ -84,9 +84,9 @@ class _SigmoidFakeData(_MockAgent): info={}) -class _ParameterTuningAgent(_MockAgent): +class _ParameterTuningTrainer(_MockTrainer): - _agent_name = "ParameterTuningAgent" + _name = "ParameterTuningTrainer" _default_config = with_common_config({ "reward_amt": 10, "dummy_param": 10, @@ -108,8 +108,8 @@ class _ParameterTuningAgent(_MockAgent): def _agent_import_failed(trace): """Returns dummy agent class for if PyTorch etc. 
is not installed.""" - class _AgentImportFailed(Agent): - _agent_name = "AgentImportFailed" + class _AgentImportFailed(Trainer): + _name = "AgentImportFailed" _default_config = with_common_config({}) def _setup(self, config): diff --git a/python/ray/rllib/agents/pg/__init__.py b/python/ray/rllib/agents/pg/__init__.py index f0665f448..2203188a7 100644 --- a/python/ray/rllib/agents/pg/__init__.py +++ b/python/ray/rllib/agents/pg/__init__.py @@ -1,3 +1,6 @@ -from ray.rllib.agents.pg.pg import PGAgent, DEFAULT_CONFIG +from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG +from ray.rllib.utils import renamed_class -__all__ = ["PGAgent", "DEFAULT_CONFIG"] +PGAgent = renamed_class(PGTrainer) + +__all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index 356a5aae1..ba7884dd8 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.optimizers import SyncSamplesOptimizer @@ -22,18 +22,18 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class PGAgent(Agent): +class PGTrainer(Trainer): """Simple policy gradient agent. This is an example agent to show how to implement algorithms in RLlib. In most cases, you will probably want to use the PPO agent instead. """ - _agent_name = "PG" + _name = "PG" _default_config = DEFAULT_CONFIG _policy_graph = PGPolicyGraph - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): if config["use_pytorch"]: from ray.rllib.agents.pg.torch_pg_policy_graph import \ @@ -51,7 +51,7 @@ class PGAgent(Agent): self.optimizer = SyncSamplesOptimizer( self.local_evaluator, self.remote_evaluators, optimizer_config) - @override(Agent) + @override(Trainer) def _train(self): prev_steps = self.optimizer.num_steps_sampled self.optimizer.step() diff --git a/python/ray/rllib/agents/ppo/__init__.py b/python/ray/rllib/agents/ppo/__init__.py index 8e25f9b8f..a02cbc23c 100644 --- a/python/ray/rllib/agents/ppo/__init__.py +++ b/python/ray/rllib/agents/ppo/__init__.py @@ -1,4 +1,7 @@ -from ray.rllib.agents.ppo.ppo import (PPOAgent, DEFAULT_CONFIG) -from ray.rllib.agents.ppo.appo import APPOAgent +from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG +from ray.rllib.agents.ppo.appo import APPOTrainer +from ray.rllib.utils import renamed_class -__all__ = ["APPOAgent", "PPOAgent", "DEFAULT_CONFIG"] +PPOAgent = renamed_class(PPOTrainer) + +__all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index c0ab74b44..ac3251775 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -3,7 +3,7 @@ from __future__ import division from __future__ import print_function from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph -from ray.rllib.agents.agent import with_base_config +from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala from ray.rllib.utils.annotations import override @@ -52,13 +52,13 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, { # yapf: enable -class APPOAgent(impala.ImpalaAgent): +class APPOTrainer(impala.ImpalaTrainer): 
"""PPO surrogate loss with IMPALA-architecture.""" - _agent_name = "APPO" + _name = "APPO" _default_config = DEFAULT_CONFIG _policy_graph = AsyncPPOPolicyGraph - @override(impala.ImpalaAgent) + @override(impala.ImpalaTrainer) def _get_policy_graph(self): return AsyncPPOPolicyGraph diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 4c377af61..c409a4c45 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -4,7 +4,7 @@ from __future__ import print_function import logging -from ray.rllib.agents import Agent, with_common_config +from ray.rllib.agents import Trainer, with_common_config from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer from ray.rllib.utils.annotations import override @@ -63,14 +63,14 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class PPOAgent(Agent): +class PPOTrainer(Trainer): """Multi-GPU optimized implementation of PPO in TensorFlow.""" - _agent_name = "PPO" + _name = "PPO" _default_config = DEFAULT_CONFIG _policy_graph = PPOPolicyGraph - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): self._validate_config() self.local_evaluator = self.make_local_evaluator( @@ -96,7 +96,7 @@ class PPOAgent(Agent): "straggler_mitigation": config["straggler_mitigation"], }) - @override(Agent) + @override(Trainer) def _train(self): if "observation_filter" not in self.raw_user_config: # TODO(ekl) remove this message after a few releases diff --git a/python/ray/rllib/agents/qmix/__init__.py b/python/ray/rllib/agents/qmix/__init__.py index a5a1e4993..0de9ff272 100644 --- a/python/ray/rllib/agents/qmix/__init__.py +++ b/python/ray/rllib/agents/qmix/__init__.py @@ -2,7 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG -from ray.rllib.agents.qmix.apex import ApexQMixAgent +from ray.rllib.agents.qmix.qmix import QMixTrainer, DEFAULT_CONFIG +from ray.rllib.agents.qmix.apex import ApexQMixTrainer -__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"] +__all__ = ["QMixTrainer", "ApexQMixTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py index dac95e0ff..f43a5ac12 100644 --- a/python/ray/rllib/agents/qmix/apex.py +++ b/python/ray/rllib/agents/qmix/apex.py @@ -4,7 +4,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG +from ray.rllib.agents.qmix.qmix import QMixTrainer, \ + DEFAULT_CONFIG as QMIX_CONFIG from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts @@ -34,17 +35,17 @@ APEX_QMIX_DEFAULT_CONFIG = merge_dicts( ) -class ApexQMixAgent(QMixAgent): +class ApexQMixTrainer(QMixTrainer): """QMIX variant that uses the Ape-X distributed policy optimizer. By default, this is configured for a large single node (32 cores). For running in a large cluster, increase the `num_workers` config var. 
""" - _agent_name = "APEX_QMIX" + _name = "APEX_QMIX" _default_config = APEX_QMIX_DEFAULT_CONFIG - @override(QMixAgent) + @override(QMixTrainer) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py index 037f7b80c..420a567d8 100644 --- a/python/ray/rllib/agents/qmix/qmix.py +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -2,8 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.agents.agent import with_common_config -from ray.rllib.agents.dqn.dqn import DQNAgent +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.dqn.dqn import DQNTrainer from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph # yapf: disable @@ -90,10 +90,10 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -class QMixAgent(DQNAgent): +class QMixTrainer(DQNTrainer): """QMix implementation in PyTorch.""" - _agent_name = "QMIX" + _name = "QMIX" _default_config = DEFAULT_CONFIG _policy_graph = QMixPolicyGraph _optimizer_shared_configs = [ diff --git a/python/ray/rllib/agents/registry.py b/python/ray/rllib/agents/registry.py index 1ef7cb124..7b133fa81 100644 --- a/python/ray/rllib/agents/registry.py +++ b/python/ray/rllib/agents/registry.py @@ -11,77 +11,77 @@ from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS def _import_appo(): from ray.rllib.agents import ppo - return ppo.APPOAgent + return ppo.APPOTrainer def _import_qmix(): from ray.rllib.agents import qmix - return qmix.QMixAgent + return qmix.QMixTrainer def _import_apex_qmix(): from ray.rllib.agents import qmix - return qmix.ApexQMixAgent + return qmix.ApexQMixTrainer def _import_ddpg(): from ray.rllib.agents import ddpg - return ddpg.DDPGAgent + return ddpg.DDPGTrainer def _import_apex_ddpg(): from ray.rllib.agents import ddpg - return ddpg.ApexDDPGAgent + return ddpg.ApexDDPGTrainer def _import_ppo(): from ray.rllib.agents import ppo - return ppo.PPOAgent + return ppo.PPOTrainer def _import_es(): from ray.rllib.agents import es - return es.ESAgent + return es.ESTrainer def _import_ars(): from ray.rllib.agents import ars - return ars.ARSAgent + return ars.ARSTrainer def _import_dqn(): from ray.rllib.agents import dqn - return dqn.DQNAgent + return dqn.DQNTrainer def _import_apex(): from ray.rllib.agents import dqn - return dqn.ApexAgent + return dqn.ApexTrainer def _import_a3c(): from ray.rllib.agents import a3c - return a3c.A3CAgent + return a3c.A3CTrainer def _import_a2c(): from ray.rllib.agents import a3c - return a3c.A2CAgent + return a3c.A2CTrainer def _import_pg(): from ray.rllib.agents import pg - return pg.PGAgent + return pg.PGTrainer def _import_impala(): from ray.rllib.agents import impala - return impala.ImpalaAgent + return impala.ImpalaTrainer def _import_marwil(): from ray.rllib.agents import marwil - return marwil.MARWILAgent + return marwil.MARWILTrainer ALGORITHMS = { @@ -122,13 +122,13 @@ def _get_agent_class(alg): from ray.tune import script_runner return script_runner.ScriptRunner elif alg == "__fake": - from ray.rllib.agents.mock import _MockAgent - return _MockAgent + from ray.rllib.agents.mock import _MockTrainer + return _MockTrainer elif alg == "__sigmoid_fake_data": from ray.rllib.agents.mock import _SigmoidFakeData return _SigmoidFakeData elif alg == "__parameter_tuning": - from ray.rllib.agents.mock import 
_ParameterTuningAgent - return _ParameterTuningAgent + from ray.rllib.agents.mock import _ParameterTuningTrainer + return _ParameterTuningTrainer else: raise Exception(("Unknown algorithm {}.").format(alg)) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py new file mode 100644 index 000000000..0eb8096d6 --- /dev/null +++ b/python/ray/rllib/agents/trainer.py @@ -0,0 +1,817 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from datetime import datetime +import copy +import logging +import os +import pickle +import six +import time +import tempfile +import tensorflow as tf +from types import FunctionType + +import ray +from ray.exceptions import RayError +from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ + ShuffledInput +from ray.rllib.models import MODEL_DEFAULTS +from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ + _validate_multiagent_config +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI +from ray.rllib.utils import FilterManager, deep_update, merge_dicts +from ray.tune.registry import ENV_CREATOR, register_env, _global_registry +from ray.tune.trainable import Trainable +from ray.tune.trial import Resources, ExportFormat +from ray.tune.logger import UnifiedLogger +from ray.tune.result import DEFAULT_RESULTS_DIR + +logger = logging.getLogger(__name__) + +# Max number of times to retry a worker failure. We shouldn't try too many +# times in a row since that would indicate a persistent cluster issue. +MAX_WORKER_FAILURE_RETRIES = 3 + +# yapf: disable +# __sphinx_doc_begin__ +COMMON_CONFIG = { + # === Debugging === + # Whether to write episode stats and videos to the agent log dir + "monitor": False, + # Set the ray.rllib.* log level for the agent process and its evaluators. + # Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also + # periodically print out summaries of relevant internal dataflow (this is + # also printed out once at startup at the INFO level). + "log_level": "INFO", + # Callbacks that will be run during various phases of training. These all + # take a single "info" dict as an argument. For episode callbacks, custom + # metrics can be attached to the episode by updating the episode object's + # custom metrics dict (see examples/custom_metrics_and_callbacks.py). + "callbacks": { + "on_episode_start": None, # arg: {"env": .., "episode": ...} + "on_episode_step": None, # arg: {"env": .., "episode": ...} + "on_episode_end": None, # arg: {"env": .., "episode": ...} + "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} + "on_train_result": None, # arg: {"trainer": ..., "result": ...} + "on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...} + }, + # Whether to attempt to continue training if a worker crashes. + "ignore_worker_failures": False, + + # === Policy === + # Arguments to pass to model. See models/catalog.py for a full list of the + # available model options. + "model": MODEL_DEFAULTS, + # Arguments to pass to the policy optimizer. These vary by optimizer. + "optimizer": {}, + + # === Environment === + # Discount factor of the MDP + "gamma": 0.99, + # Number of steps after which the episode is forced to terminate. Defaults + # to `env.spec.max_episode_steps` (if present) for Gym envs. 
+ "horizon": None, + # Calculate rewards but don't reset the environment when the horizon is + # hit. This allows value estimation and RNN state to span across logical + # episodes denoted by horizon. This only has an effect if horizon != inf. + "soft_horizon": False, + # Arguments to pass to the env creator + "env_config": {}, + # Environment name can also be passed via config + "env": None, + # Whether to clip rewards prior to experience postprocessing. Setting to + # None means clip for Atari only. + "clip_rewards": None, + # Whether to np.clip() actions to the action space low/high range spec. + "clip_actions": True, + # Whether to use rllib or deepmind preprocessors by default + "preprocessor_pref": "deepmind", + + # === Resources === + # Number of actors used for parallelism + "num_workers": 2, + # Number of GPUs to allocate to the driver. Note that not all algorithms + # can take advantage of driver GPUs. This can be fraction (e.g., 0.3 GPUs). + "num_gpus": 0, + # Number of CPUs to allocate per worker. + "num_cpus_per_worker": 1, + # Number of GPUs to allocate per worker. This can be fractional. + "num_gpus_per_worker": 0, + # Any custom resources to allocate per worker. + "custom_resources_per_worker": {}, + # Number of CPUs to allocate for the driver. Note: this only takes effect + # when running in Tune. + "num_cpus_for_driver": 1, + + # === Execution === + # Number of environments to evaluate vectorwise per worker. + "num_envs_per_worker": 1, + # Default sample batch size (unroll length). Batches of this size are + # collected from workers until train_batch_size is met. When using + # multiple envs per worker, this is multiplied by num_envs_per_worker. + "sample_batch_size": 200, + # Training batch size, if applicable. Should be >= sample_batch_size. + # Samples batches will be concatenated together to this size for training. + "train_batch_size": 200, + # Whether to rollout "complete_episodes" or "truncate_episodes" + "batch_mode": "truncate_episodes", + # (Deprecated) Use a background thread for sampling (slightly off-policy) + "sample_async": False, + # Element-wise observation filter, either "NoFilter" or "MeanStdFilter" + "observation_filter": "NoFilter", + # Whether to synchronize the statistics of remote filters. + "synchronize_filters": True, + # Configure TF for single-process operation by default + "tf_session_args": { + # note: overriden by `local_evaluator_tf_session_args` + "intra_op_parallelism_threads": 2, + "inter_op_parallelism_threads": 2, + "gpu_options": { + "allow_growth": True, + }, + "log_device_placement": False, + "device_count": { + "CPU": 1 + }, + "allow_soft_placement": True, # required by PPO multi-gpu + }, + # Override the following tf session args on the local evaluator + "local_evaluator_tf_session_args": { + # Allow a higher level of parallelism by default, but not unlimited + # since that can cause crashes with many concurrent drivers. + "intra_op_parallelism_threads": 8, + "inter_op_parallelism_threads": 8, + }, + # Whether to LZ4 compress individual observations + "compress_observations": False, + # Drop metric batches from unresponsive workers after this many seconds + "collect_metrics_timeout": 180, + # Smooth metrics over this many episodes. + "metrics_smoothing_episodes": 100, + # If using num_envs_per_worker > 1, whether to create those new envs in + # remote processes instead of in the same worker. This adds overheads, but + # can make sense if your envs can take much time to step / reset + # (e.g., for StarCraft). 
Use this cautiously; overheads are significant. + "remote_worker_envs": False, + # Timeout that remote workers are waiting when polling environments. + # 0 (continue when at least one env is ready) is a reasonable default, + # but optimal value could be obtained by measuring your environment + # step / reset and model inference perf. + "remote_env_batch_wait_ms": 0, + + # === Offline Datasets === + # Specify how to generate experiences: + # - "sampler": generate experiences via online simulation (default) + # - a local directory or file glob expression (e.g., "/tmp/*.json") + # - a list of individual file paths/URIs (e.g., ["/tmp/1.json", + # "s3://bucket/2.json"]) + # - a dict with string keys and sampling probabilities as values (e.g., + # {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}). + # - a function that returns a rllib.offline.InputReader + "input": "sampler", + # Specify how to evaluate the current policy. This only has an effect when + # reading offline experiences. Available options: + # - "wis": the weighted step-wise importance sampling estimator. + # - "is": the step-wise importance sampling estimator. + # - "simulation": run the environment in the background, but use + # this data for evaluation only and not for learning. + "input_evaluation": ["is", "wis"], + # Whether to run postprocess_trajectory() on the trajectory fragments from + # offline inputs. Note that postprocessing will be done using the *current* + # policy, not the *behaviour* policy, which is typically undesirable for + # on-policy algorithms. + "postprocess_inputs": False, + # If positive, input batches will be shuffled via a sliding window buffer + # of this number of batches. Use this if the input data is not in random + # enough order. Input is delayed until the shuffle buffer is filled. + "shuffle_buffer_size": 0, + # Specify where experiences should be saved: + # - None: don't save any experiences + # - "logdir" to save to the agent log dir + # - a path/URI to save to a custom output directory (e.g., "s3://bucket/") + # - a function that returns a rllib.offline.OutputWriter + "output": None, + # What sample batch columns to LZ4 compress in the output data. + "output_compress_columns": ["obs", "new_obs"], + # Max output file size before rolling over to a new file. + "output_max_file_size": 64 * 1024 * 1024, + + # === Multiagent === + "multiagent": { + # Map from policy ids to tuples of (policy_graph_cls, obs_space, + # act_space, config). See policy_evaluator.py for more info. + "policy_graphs": {}, + # Function mapping agent ids to policy ids. + "policy_mapping_fn": None, + # Optional whitelist of policies to train, or None for all policies. + "policies_to_train": None, + }, +} +# __sphinx_doc_end__ +# yapf: enable + + +@DeveloperAPI +def with_common_config(extra_config): + """Returns the given config dict merged with common agent confs.""" + + return with_base_config(COMMON_CONFIG, extra_config) + + +def with_base_config(base_config, extra_config): + """Returns the given config dict merged with a base agent conf.""" + + config = copy.deepcopy(base_config) + config.update(extra_config) + return config + + +@PublicAPI +class Trainer(Trainable): + """A trainer coordinates the optimization of one or more RL policies. + + All RLlib trainers extend this base class, e.g., the A3CTrainer implements + the A3C algorithm for single and multi-agent training. 
+ + Trainer objects retain internal model state between calls to train(), so + you should create a new trainer instance for each training session. + + Attributes: + env_creator (func): Function that creates a new training env. + config (obj): Algorithm-specific configuration data. + logdir (str): Directory in which training outputs should be placed. + """ + + _allow_unknown_configs = False + _allow_unknown_subkeys = [ + "tf_session_args", "env_config", "model", "optimizer", "multiagent", + "custom_resources_per_worker" + ] + + @PublicAPI + def __init__(self, config=None, env=None, logger_creator=None): + """Initialize an RLLib trainer. + + Args: + config (dict): Algorithm-specific configuration data. + env (str): Name of the environment to use. Note that this can also + be specified as the `env` key in config. + logger_creator (func): Function that creates a ray.tune.Logger + object. If unspecified, a default logger is created. + """ + + config = config or {} + + # Vars to synchronize to evaluators on each train call + self.global_vars = {"timestep": 0} + + # Trainers allow env ids to be passed directly to the constructor. + self._env_id = self._register_if_needed(env or config.get("env")) + + # Create a default logger creator if no logger_creator is specified + if logger_creator is None: + timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + logdir_prefix = "{}_{}_{}".format(self._name, self._env_id, + timestr) + + def default_logger_creator(config): + """Creates a Unified logger with a default logdir prefix + containing the agent name and the env id + """ + if not os.path.exists(DEFAULT_RESULTS_DIR): + os.makedirs(DEFAULT_RESULTS_DIR) + logdir = tempfile.mkdtemp( + prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) + return UnifiedLogger(config, logdir, None) + + logger_creator = default_logger_creator + + Trainable.__init__(self, config, logger_creator) + + @classmethod + @override(Trainable) + def default_resource_request(cls, config): + cf = dict(cls._default_config, **config) + Trainer._validate_config(cf) + # TODO(ekl): add custom resources here once tune supports them + return Resources( + cpu=cf["num_cpus_for_driver"], + gpu=cf["num_gpus"], + extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + + @override(Trainable) + @PublicAPI + def train(self): + """Overrides super.train to synchronize global vars.""" + + if self._has_policy_optimizer(): + self.global_vars["timestep"] = self.optimizer.num_steps_sampled + self.optimizer.local_evaluator.set_global_vars(self.global_vars) + for ev in self.optimizer.remote_evaluators: + ev.set_global_vars.remote(self.global_vars) + logger.debug("updated global vars: {}".format(self.global_vars)) + + result = None + for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): + try: + result = Trainable.train(self) + except RayError as e: + if self.config["ignore_worker_failures"]: + logger.exception( + "Error in train call, attempting to recover") + self._try_recover() + else: + logger.info( + "Worker crashed during call to train(). 
To attempt to " + "continue training without the failed worker, set " + "`'ignore_worker_failures': True`.") + raise e + except Exception as e: + time.sleep(0.5) # allow logs messages to propagate + raise e + else: + break + if result is None: + raise RuntimeError("Failed to recover from worker crash") + + if (self.config.get("observation_filter", "NoFilter") != "NoFilter" + and hasattr(self, "local_evaluator")): + FilterManager.synchronize( + self.local_evaluator.filters, + self.remote_evaluators, + update_remote=self.config["synchronize_filters"]) + logger.debug("synchronized filters: {}".format( + self.local_evaluator.filters)) + + if self._has_policy_optimizer(): + result["num_healthy_workers"] = len( + self.optimizer.remote_evaluators) + return result + + @override(Trainable) + def _log_result(self, result): + if self.config["callbacks"].get("on_train_result"): + self.config["callbacks"]["on_train_result"]({ + "trainer": self, + "result": result, + }) + # log after the callback is invoked, so that the user has a chance + # to mutate the result + Trainable._log_result(self, result) + + @override(Trainable) + def _setup(self, config): + env = self._env_id + if env: + config["env"] = env + if _global_registry.contains(ENV_CREATOR, env): + self.env_creator = _global_registry.get(ENV_CREATOR, env) + else: + import gym # soft dependency + self.env_creator = lambda env_config: gym.make(env) + else: + self.env_creator = lambda env_config: None + + # Merge the supplied config with the class default + merged_config = copy.deepcopy(self._default_config) + merged_config = deep_update(merged_config, config, + self._allow_unknown_configs, + self._allow_unknown_subkeys) + self.raw_user_config = config + self.config = merged_config + Trainer._validate_config(self.config) + if self.config.get("log_level"): + logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) + + # TODO(ekl) setting the graph is unnecessary for PyTorch agents + with tf.Graph().as_default(): + self._init(self.config, self.env_creator) + + @override(Trainable) + def _stop(self): + # Call stop on all evaluators to release resources + if hasattr(self, "local_evaluator"): + self.local_evaluator.stop() + if hasattr(self, "remote_evaluators"): + for ev in self.remote_evaluators: + ev.stop.remote() + + # workaround for https://github.com/ray-project/ray/issues/1516 + if hasattr(self, "remote_evaluators"): + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote() + + if hasattr(self, "optimizer"): + self.optimizer.stop() + + @override(Trainable) + def _save(self, checkpoint_dir): + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + return checkpoint_path + + @override(Trainable) + def _restore(self, checkpoint_path): + extra_data = pickle.load(open(checkpoint_path, "rb")) + self.__setstate__(extra_data) + + @DeveloperAPI + def _init(self, config, env_creator): + """Subclasses should override this for custom initialization.""" + + raise NotImplementedError + + @PublicAPI + def compute_action(self, + observation, + state=None, + prev_action=None, + prev_reward=None, + info=None, + policy_id=DEFAULT_POLICY_ID, + full_fetch=False): + """Computes an action for the specified policy. + + Note that you can also access the policy object through + self.get_policy(policy_id) and call compute_actions() on it directly. + + Arguments: + observation (obj): observation from the environment. + state (list): RNN hidden state, if any. 
If state is not None, + then all of compute_single_action(...) is returned + (computed action, rnn state, logits dictionary). + Otherwise compute_single_action(...)[0] is + returned (computed action). + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any + policy_id (str): policy to query (only applies to multi-agent). + full_fetch (bool): whether to return extra action fetch results. + This is always set to true if RNN state is specified. + + Returns: + Just the computed action if full_fetch=False, or the full output + of policy.compute_actions() otherwise. + """ + + if state is None: + state = [] + preprocessed = self.local_evaluator.preprocessors[policy_id].transform( + observation) + filtered_obs = self.local_evaluator.filters[policy_id]( + preprocessed, update=False) + if state: + return self.get_policy(policy_id).compute_single_action( + filtered_obs, + state, + prev_action, + prev_reward, + info, + clip_actions=self.config["clip_actions"]) + res = self.get_policy(policy_id).compute_single_action( + filtered_obs, + state, + prev_action, + prev_reward, + info, + clip_actions=self.config["clip_actions"]) + if full_fetch: + return res + else: + return res[0] # backwards compatibility + + @property + def iteration(self): + """Current training iter, auto-incremented with each train() call.""" + + return self._iteration + + @property + def _name(self): + """Subclasses should override this to declare their name.""" + + raise NotImplementedError + + @property + def _default_config(self): + """Subclasses should override this to declare their default config.""" + + raise NotImplementedError + + @PublicAPI + def get_policy(self, policy_id=DEFAULT_POLICY_ID): + """Return policy graph for the specified id, or None. + + Arguments: + policy_id (str): id of policy graph to return. + """ + + return self.local_evaluator.get_policy(policy_id) + + @PublicAPI + def get_weights(self, policies=None): + """Return a dictionary of policy ids to weights. + + Arguments: + policies (list): Optional list of policies to return weights for, + or None for all policies. + """ + return self.local_evaluator.get_weights(policies) + + @PublicAPI + def set_weights(self, weights): + """Set policy weights by policy id. + + Arguments: + weights (dict): Map of policy ids to weights to set. + """ + self.local_evaluator.set_weights(weights) + + @DeveloperAPI + def make_local_evaluator(self, + env_creator, + policy_graph, + extra_config=None): + """Convenience method to return configured local evaluator.""" + + return self._make_evaluator( + PolicyEvaluator, + env_creator, + policy_graph, + 0, + merge_dicts( + # important: allow local tf to use more CPUs for optimization + merge_dicts( + self.config, { + "tf_session_args": self. 
+ config["local_evaluator_tf_session_args"] + }), + extra_config or {})) + + @DeveloperAPI + def make_remote_evaluators(self, env_creator, policy_graph, count): + """Convenience method to return a number of remote evaluators.""" + + remote_args = { + "num_cpus": self.config["num_cpus_per_worker"], + "num_gpus": self.config["num_gpus_per_worker"], + "resources": self.config["custom_resources_per_worker"], + } + + cls = PolicyEvaluator.as_remote(**remote_args).remote + + return [ + self._make_evaluator(cls, env_creator, policy_graph, i + 1, + self.config) for i in range(count) + ] + + @DeveloperAPI + def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): + """Export policy model with given policy_id to local directory. + + Arguments: + export_dir (string): Writable local directory. + policy_id (string): Optional policy id to export. + + Example: + >>> trainer = MyTrainer() + >>> for _ in range(10): + >>> trainer.train() + >>> trainer.export_policy_model("/tmp/export_dir") + """ + self.local_evaluator.export_policy_model(export_dir, policy_id) + + @DeveloperAPI + def export_policy_checkpoint(self, + export_dir, + filename_prefix="model", + policy_id=DEFAULT_POLICY_ID): + """Export tensorflow policy model checkpoint to local directory. + + Arguments: + export_dir (string): Writable local directory. + filename_prefix (string): file name prefix of checkpoint files. + policy_id (string): Optional policy id to export. + + Example: + >>> trainer = MyTrainer() + >>> for _ in range(10): + >>> trainer.train() + >>> trainer.export_policy_checkpoint("/tmp/export_dir") + """ + self.local_evaluator.export_policy_checkpoint( + export_dir, filename_prefix, policy_id) + + @DeveloperAPI + def collect_metrics(self, selected_evaluators=None): + """Collects metrics from the remote evaluators of this agent. + + This is the same data as returned by a call to train(). + """ + return self.optimizer.collect_metrics( + self.config["collect_metrics_timeout"], + min_history=self.config["metrics_smoothing_episodes"], + selected_evaluators=selected_evaluators) + + @classmethod + def resource_help(cls, config): + return ("\n\nYou can adjust the resource requests of RLlib agents by " + "setting `num_workers`, `num_gpus`, and other configs. See " + "the DEFAULT_CONFIG defined by each agent for more info.\n\n" + "The config of this agent is: {}".format(config)) + + @staticmethod + def _validate_config(config): + if "gpu" in config: + raise ValueError( + "The `gpu` config is deprecated, please use `num_gpus=0|1` " + "instead.") + if "gpu_fraction" in config: + raise ValueError( + "The `gpu_fraction` config is deprecated, please use " + "`num_gpus=` instead.") + if "use_gpu_for_workers" in config: + raise ValueError( + "The `use_gpu_for_workers` config is deprecated, please use " + "`num_gpus_per_worker=1` instead.") + if type(config["input_evaluation"]) != list: + raise ValueError( + "`input_evaluation` must be a list of strings, got {}".format( + config["input_evaluation"])) + + def _try_recover(self): + """Try to identify and blacklist any unhealthy workers. + + This method is called after an unexpected remote error is encountered + from a worker. It issues check requests to all current workers and + blacklists any that respond with error. If no healthy workers remain, + an error is raised. 
+ """ + + if not self._has_policy_optimizer(): + raise NotImplementedError( + "Recovery is not supported for this algorithm") + + logger.info("Health checking all workers...") + checks = [] + for ev in self.optimizer.remote_evaluators: + _, obj_id = ev.sample_with_count.remote() + checks.append(obj_id) + + healthy_evaluators = [] + for i, obj_id in enumerate(checks): + ev = self.optimizer.remote_evaluators[i] + try: + ray.get(obj_id) + healthy_evaluators.append(ev) + logger.info("Worker {} looks healthy".format(i + 1)) + except RayError: + logger.exception("Blacklisting worker {}".format(i + 1)) + try: + ev.__ray_terminate__.remote() + except Exception: + logger.exception("Error terminating unhealthy worker") + + if len(healthy_evaluators) < 1: + raise RuntimeError( + "Not enough healthy workers remain to continue.") + + self.optimizer.reset(healthy_evaluators) + + def _has_policy_optimizer(self): + return hasattr(self, "optimizer") and isinstance( + self.optimizer, PolicyOptimizer) + + def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, + config): + def session_creator(): + logger.debug("Creating TF session {}".format( + config["tf_session_args"])) + return tf.Session( + config=tf.ConfigProto(**config["tf_session_args"])) + + if isinstance(config["input"], FunctionType): + input_creator = config["input"] + elif config["input"] == "sampler": + input_creator = (lambda ioctx: ioctx.default_sampler_input()) + elif isinstance(config["input"], dict): + input_creator = (lambda ioctx: ShuffledInput( + MixedInput(config["input"], ioctx), config[ + "shuffle_buffer_size"])) + else: + input_creator = (lambda ioctx: ShuffledInput( + JsonReader(config["input"], ioctx), config[ + "shuffle_buffer_size"])) + + if isinstance(config["output"], FunctionType): + output_creator = config["output"] + elif config["output"] is None: + output_creator = (lambda ioctx: NoopOutput()) + elif config["output"] == "logdir": + output_creator = (lambda ioctx: JsonWriter( + ioctx.log_dir, + ioctx, + max_file_size=config["output_max_file_size"], + compress_columns=config["output_compress_columns"])) + else: + output_creator = (lambda ioctx: JsonWriter( + config["output"], + ioctx, + max_file_size=config["output_max_file_size"], + compress_columns=config["output_compress_columns"])) + + if config["input"] == "sampler": + input_evaluation = [] + else: + input_evaluation = config["input_evaluation"] + + # Fill in the default policy graph if 'None' is specified in multiagent + if self.config["multiagent"]["policy_graphs"]: + tmp = self.config["multiagent"]["policy_graphs"] + _validate_multiagent_config(tmp, allow_none_graph=True) + for k, v in tmp.items(): + if v[0] is None: + tmp[k] = (policy_graph, v[1], v[2], v[3]) + policy_graph = tmp + + return cls( + env_creator, + policy_graph, + policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], + policies_to_train=self.config["multiagent"]["policies_to_train"], + tf_session_creator=(session_creator + if config["tf_session_args"] else None), + batch_steps=config["sample_batch_size"], + batch_mode=config["batch_mode"], + episode_horizon=config["horizon"], + preprocessor_pref=config["preprocessor_pref"], + sample_async=config["sample_async"], + compress_observations=config["compress_observations"], + num_envs=config["num_envs_per_worker"], + observation_filter=config["observation_filter"], + clip_rewards=config["clip_rewards"], + clip_actions=config["clip_actions"], + env_config=config["env_config"], + model_config=config["model"], + policy_config=config, 
+ worker_index=worker_index, + monitor_path=self.logdir if config["monitor"] else None, + log_dir=self.logdir, + log_level=config["log_level"], + callbacks=config["callbacks"], + input_creator=input_creator, + input_evaluation=input_evaluation, + output_creator=output_creator, + remote_worker_envs=config["remote_worker_envs"], + remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"], + soft_horizon=config["soft_horizon"], + _fake_sampler=config.get("_fake_sampler", False)) + + @override(Trainable) + def _export_model(self, export_formats, export_dir): + ExportFormat.validate(export_formats) + exported = {} + if ExportFormat.CHECKPOINT in export_formats: + path = os.path.join(export_dir, ExportFormat.CHECKPOINT) + self.export_policy_checkpoint(path) + exported[ExportFormat.CHECKPOINT] = path + if ExportFormat.MODEL in export_formats: + path = os.path.join(export_dir, ExportFormat.MODEL) + self.export_policy_model(path) + exported[ExportFormat.MODEL] = path + return exported + + def __getstate__(self): + state = {} + if hasattr(self, "local_evaluator"): + state["evaluator"] = self.local_evaluator.save() + if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"): + state["optimizer"] = self.optimizer.save() + return state + + def __setstate__(self, state): + if "evaluator" in state: + self.local_evaluator.restore(state["evaluator"]) + remote_state = ray.put(state["evaluator"]) + for r in self.remote_evaluators: + r.restore.remote(remote_state) + if "optimizer" in state: + self.optimizer.restore(state["optimizer"]) + + def _register_if_needed(self, env_object): + if isinstance(env_object, six.string_types): + return env_object + elif isinstance(env_object, type): + name = env_object.__name__ + register_env(name, lambda config: env_object(config)) + return name + raise ValueError( + "{} is an invalid env specification. 
".format(env_object) + + "You can specify a custom env as either a class " + "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").") diff --git a/python/ray/rllib/contrib/random_agent/random_agent.py b/python/ray/rllib/contrib/random_agent/random_agent.py index dbf55e35b..17b4c8d09 100644 --- a/python/ray/rllib/contrib/random_agent/random_agent.py +++ b/python/ray/rllib/contrib/random_agent/random_agent.py @@ -4,25 +4,25 @@ from __future__ import print_function import numpy as np -from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ -class RandomAgent(Agent): - """Agent that takes random actions and never learns.""" +class RandomAgent(Trainer): + """Policy that takes random actions and never learns.""" - _agent_name = "RandomAgent" + _name = "RandomAgent" _default_config = with_common_config({ "rollouts_per_iteration": 10, }) - @override(Agent) + @override(Trainer) def _init(self, config, env_creator): self.env = env_creator(config["env_config"]) - @override(Agent) + @override(Trainer) def _train(self): rewards = [] steps = 0 @@ -45,8 +45,8 @@ class RandomAgent(Agent): if __name__ == "__main__": - agent = RandomAgent( + trainer = RandomAgent( env="CartPole-v0", config={"rollouts_per_iteration": 10}) - result = agent.train() + result = trainer.train() assert result["episode_reward_mean"] > 10, result print("Test: OK") diff --git a/python/ray/rllib/env/external_env.py b/python/ray/rllib/env/external_env.py index 4f70e2299..46d15a2de 100644 --- a/python/ray/rllib/env/external_env.py +++ b/python/ray/rllib/env/external_env.py @@ -34,9 +34,9 @@ class ExternalEnv(threading.Thread): Examples: >>> register_env("my_env", lambda config: YourExternalEnv(config)) - >>> agent = DQNAgent(env="my_env") + >>> trainer = DQNTrainer(env="my_env") >>> while True: - print(agent.train()) + print(trainer.train()) """ @PublicAPI diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index e4e834ba9..4c5a80c86 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -35,6 +35,19 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder logger = logging.getLogger(__name__) +# Handle to the current evaluator, which will be set to the most recently +# created PolicyEvaluator in this process. This can be helpful to access in +# custom env or policy classes for debugging or advanced use cases. +_global_evaluator = None + + +@DeveloperAPI +def get_global_evaluator(): + """Returns a handle to the active policy evaluator in this process.""" + + global _global_evaluator + return _global_evaluator + @DeveloperAPI class PolicyEvaluator(EvaluatorInterface): @@ -215,6 +228,9 @@ class PolicyEvaluator(EvaluatorInterface): _fake_sampler (bool): Use a fake (inf speed) sampler for testing. """ + global _global_evaluator + _global_evaluator = self + if log_level: logging.getLogger("ray.rllib").setLevel(log_level) diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index 2387f5cd0..973388075 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -71,12 +71,13 @@ class MultiAgentSampleBatchBuilder(object): corresponding policy batch for the agent's policy. 
""" - def __init__(self, policy_map, clip_rewards): + def __init__(self, policy_map, clip_rewards, postp_callback): """Initialize a MultiAgentSampleBatchBuilder. Arguments: policy_map (dict): Maps policy ids to policy graph instances. clip_rewards (bool): Whether to clip rewards before postprocessing. + postp_callback: function to call on each postprocessed batch. """ self.policy_map = policy_map @@ -87,6 +88,7 @@ class MultiAgentSampleBatchBuilder(object): } self.agent_builders = {} self.agent_to_policy = {} + self.postp_callback = postp_callback self.count = 0 # increment this manually def total(self): @@ -158,6 +160,8 @@ class MultiAgentSampleBatchBuilder(object): for agent_id, post_batch in sorted(post_batches.items()): self.policy_builders[self.agent_to_policy[agent_id]].add_batch( post_batch) + if self.postp_callback: + self.postp_callback({"episode": episode, "batch": post_batch}) self.agent_builders.clear() self.agent_to_policy.clear() diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 08f6f2003..25dd1ef5d 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -279,7 +279,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn, if batch_builder_pool: return batch_builder_pool.pop() else: - return MultiAgentSampleBatchBuilder(policies, clip_rewards) + return MultiAgentSampleBatchBuilder( + policies, clip_rewards, callbacks.get("on_postprocess_traj")) def new_episode(): episode = MultiAgentEpisode(policies, policy_mapping_fn, diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py index a9c36ce35..27d91331f 100644 --- a/python/ray/rllib/examples/custom_metrics_and_callbacks.py +++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py @@ -38,12 +38,21 @@ def on_sample_end(info): def on_train_result(info): - print("agent.train() result: {} -> {} episodes".format( - info["agent"], info["result"]["episodes_this_iter"])) + print("trainer.train() result: {} -> {} episodes".format( + info["trainer"], info["result"]["episodes_this_iter"])) # you can mutate the result dict to add new fields to return info["result"]["callback_ok"] = True +def on_postprocess_traj(info): + episode = info["episode"] + batch = info["batch"] + print("postprocessed {} steps".format(batch.count)) + if "num_batches" not in episode.custom_metrics: + episode.custom_metrics["num_batches"] = 0 + episode.custom_metrics["num_batches"] += 1 + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num-iters", type=int, default=2000) @@ -63,6 +72,7 @@ if __name__ == "__main__": "on_episode_end": tune.function(on_episode_end), "on_sample_end": tune.function(on_sample_end), "on_train_result": tune.function(on_train_result), + "on_postprocess_traj": tune.function(on_postprocess_traj), }, }, ) @@ -73,5 +83,6 @@ if __name__ == "__main__": assert "pole_angle_mean" in custom_metrics assert "pole_angle_min" in custom_metrics assert "pole_angle_max" in custom_metrics + assert "num_batches_mean" in custom_metrics assert type(custom_metrics["pole_angle_mean"]) is float assert "callback_ok" in trials[0].last_result diff --git a/python/ray/rllib/examples/custom_train_fn.py b/python/ray/rllib/examples/custom_train_fn.py index 224a54d51..3ae418ee7 100644 --- a/python/ray/rllib/examples/custom_train_fn.py +++ b/python/ray/rllib/examples/custom_train_fn.py @@ -12,12 +12,12 @@ from __future__ import print_function 
import ray from ray import tune -from ray.rllib.agents.ppo import PPOAgent +from ray.rllib.agents.ppo import PPOTrainer def my_train_fn(config, reporter): # Train for 100 iterations with high LR - agent1 = PPOAgent(env="CartPole-v0", config=config) + agent1 = PPOTrainer(env="CartPole-v0", config=config) for _ in range(10): result = agent1.train() result["phase"] = 1 @@ -28,7 +28,7 @@ def my_train_fn(config, reporter): # Train for 100 iterations with low LR config["lr"] = 0.0001 - agent2 = PPOAgent(env="CartPole-v0", config=config) + agent2 = PPOTrainer(env="CartPole-v0", config=config) agent2.restore(state) for _ in range(10): result = agent2.train() diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index 7bdfb0435..2c18f2bf4 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -15,9 +15,9 @@ import argparse import gym import ray -from ray.rllib.agents.dqn.dqn import DQNAgent +from ray.rllib.agents.dqn.dqn import DQNTrainer from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph -from ray.rllib.agents.ppo.ppo import PPOAgent +from ray.rllib.agents.ppo.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.logger import pretty_print @@ -49,7 +49,7 @@ if __name__ == "__main__": else: return "dqn_policy" - ppo_trainer = PPOAgent( + ppo_trainer = PPOTrainer( env="multi_cartpole", config={ "multiagent": { @@ -62,7 +62,7 @@ if __name__ == "__main__": "observation_filter": "NoFilter", }) - dqn_trainer = DQNAgent( + dqn_trainer = DQNTrainer( env="multi_cartpole", config={ "multiagent": { diff --git a/python/ray/rllib/examples/serving/cartpole_server.py b/python/ray/rllib/examples/serving/cartpole_server.py index 40260350c..c9c794c3a 100755 --- a/python/ray/rllib/examples/serving/cartpole_server.py +++ b/python/ray/rllib/examples/serving/cartpole_server.py @@ -13,7 +13,7 @@ from gym import spaces import numpy as np import ray -from ray.rllib.agents.dqn import DQNAgent +from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.env.external_env import ExternalEnv from ray.rllib.utils.policy_server import PolicyServer from ray.tune.logger import pretty_print @@ -43,7 +43,7 @@ if __name__ == "__main__": # We use DQN since it supports off-policy actions, but you can choose and # configure any agent. - dqn = DQNAgent( + dqn = DQNTrainer( env="srv", config={ # Use a single process to avoid needing to set up a load balancer diff --git a/python/ray/rllib/optimizers/aso_minibatch_buffer.py b/python/ray/rllib/optimizers/aso_minibatch_buffer.py index 25b685d4d..0288e7aff 100644 --- a/python/ray/rllib/optimizers/aso_minibatch_buffer.py +++ b/python/ray/rllib/optimizers/aso_minibatch_buffer.py @@ -35,7 +35,7 @@ class MinibatchBuffer(object): released: True if the item is now removed from the ring buffer. 
""" if self.ttl[self.idx] <= 0: - self.buffers[self.idx] = self.inqueue.get(timeout=60.0) + self.buffers[self.idx] = self.inqueue.get(timeout=300.0) self.ttl[self.idx] = self.cur_max_ttl if self.cur_max_ttl < self.max_ttl: self.cur_max_ttl += 1 diff --git a/python/ray/rllib/tests/test_avail_actions_qmix.py b/python/ray/rllib/tests/test_avail_actions_qmix.py index 606f35873..fc45e5d2a 100644 --- a/python/ray/rllib/tests/test_avail_actions_qmix.py +++ b/python/ray/rllib/tests/test_avail_actions_qmix.py @@ -7,7 +7,7 @@ from gym.spaces import Tuple, Discrete, Dict, Box import ray from ray.tune import register_env from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.agents.qmix import QMixAgent +from ray.rllib.agents.qmix import QMixTrainer class AvailActionsTestEnv(MultiAgentEnv): @@ -55,7 +55,7 @@ if __name__ == "__main__": grouping, obs_space=obs_space, act_space=act_space)) ray.init() - agent = QMixAgent( + agent = QMixTrainer( env="action_mask_test", config={ "num_envs_per_worker": 5, # test with vectorization on diff --git a/python/ray/rllib/tests/test_evaluators.py b/python/ray/rllib/tests/test_evaluators.py index 6dc532a15..840fd88c2 100644 --- a/python/ray/rllib/tests/test_evaluators.py +++ b/python/ray/rllib/tests/test_evaluators.py @@ -5,7 +5,7 @@ from __future__ import print_function import unittest import ray -from ray.rllib.agents.dqn import DQNAgent +from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep @@ -26,7 +26,8 @@ class DQNTest(unittest.TestCase): def testEvaluationOption(self): ray.init() - agent = DQNAgent(env="CartPole-v0", config={"evaluation_interval": 2}) + agent = DQNTrainer( + env="CartPole-v0", config={"evaluation_interval": 2}) r0 = agent.train() r1 = agent.train() r2 = agent.train() diff --git a/python/ray/rllib/tests/test_external_env.py b/python/ray/rllib/tests/test_external_env.py index 5fbfcfb4c..337963961 100644 --- a/python/ray/rllib/tests/test_external_env.py +++ b/python/ray/rllib/tests/test_external_env.py @@ -9,8 +9,8 @@ import unittest import uuid import ray -from ray.rllib.agents.dqn import DQNAgent -from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.dqn import DQNTrainer +from ray.rllib.agents.pg import PGTrainer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_env import ExternalEnv from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph, @@ -163,7 +163,7 @@ class TestExternalEnv(unittest.TestCase): register_env( "test3", lambda _: PartOffPolicyServing( gym.make("CartPole-v0"), off_pol_frac=0.2)) - dqn = DQNAgent(env="test3", config={"exploration_fraction": 0.001}) + dqn = DQNTrainer(env="test3", config={"exploration_fraction": 0.001}) for i in range(100): result = dqn.train() print("Iteration {}, reward {}, timesteps {}".format( @@ -174,7 +174,7 @@ class TestExternalEnv(unittest.TestCase): def testTrainCartpole(self): register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0"))) - pg = PGAgent(env="test", config={"num_workers": 0}) + pg = PGTrainer(env="test", config={"num_workers": 0}) for i in range(100): result = pg.train() print("Iteration {}, reward {}, timesteps {}".format( @@ -186,7 +186,7 @@ class TestExternalEnv(unittest.TestCase): def testTrainCartpoleMulti(self): register_env("test2", lambda _: MultiServing(lambda: gym.make("CartPole-v0"))) - pg = PGAgent(env="test2", config={"num_workers": 0}) + pg = PGTrainer(env="test2", config={"num_workers": 0}) for i in range(100): result = pg.train() 
print("Iteration {}, reward {}, timesteps {}".format( diff --git a/python/ray/rllib/tests/test_io.py b/python/ray/rllib/tests/test_io.py index ef3969615..9f92c9107 100644 --- a/python/ray/rllib/tests/test_io.py +++ b/python/ray/rllib/tests/test_io.py @@ -14,7 +14,7 @@ import time import unittest import ray -from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.pg import PGTrainer from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.evaluation import SampleBatch from ray.rllib.offline import IOContext, JsonWriter, JsonReader @@ -44,7 +44,7 @@ class AgentIOTest(unittest.TestCase): shutil.rmtree(self.test_dir) def writeOutputs(self, output): - agent = PGAgent( + agent = PGTrainer( env="CartPole-v0", config={ "output": output, @@ -65,7 +65,7 @@ class AgentIOTest(unittest.TestCase): def testAgentInputDir(self): self.writeOutputs(self.test_dir) - agent = PGAgent( + agent = PGTrainer( env="CartPole-v0", config={ "input": self.test_dir, @@ -97,7 +97,7 @@ class AgentIOTest(unittest.TestCase): for data in out: f.write(json.dumps(data)) - agent = PGAgent( + agent = PGTrainer( env="CartPole-v0", config={ "input": self.test_dir, @@ -111,7 +111,7 @@ class AgentIOTest(unittest.TestCase): def testAgentInputEvalSim(self): self.writeOutputs(self.test_dir) - agent = PGAgent( + agent = PGTrainer( env="CartPole-v0", config={ "input": self.test_dir, @@ -126,7 +126,7 @@ class AgentIOTest(unittest.TestCase): def testAgentInputList(self): self.writeOutputs(self.test_dir) - agent = PGAgent( + agent = PGTrainer( env="CartPole-v0", config={ "input": glob.glob(self.test_dir + "/*.json"), @@ -139,7 +139,7 @@ class AgentIOTest(unittest.TestCase): def testAgentInputDict(self): self.writeOutputs(self.test_dir) - agent = PGAgent( + agent = PGTrainer( env="CartPole-v0", config={ "input": { @@ -161,7 +161,7 @@ class AgentIOTest(unittest.TestCase): act_space = single_env.action_space return (PGPolicyGraph, obs_space, act_space, {}) - pg = PGAgent( + pg = PGTrainer( env="multi_cartpole", config={ "num_workers": 0, @@ -180,7 +180,7 @@ class AgentIOTest(unittest.TestCase): self.assertEqual(len(os.listdir(self.test_dir)), 1) pg.stop() - pg = PGAgent( + pg = PGTrainer( env="multi_cartpole", config={ "num_workers": 0, diff --git a/python/ray/rllib/tests/test_local.py b/python/ray/rllib/tests/test_local.py index 2f76de374..67f7bc7d5 100644 --- a/python/ray/rllib/tests/test_local.py +++ b/python/ray/rllib/tests/test_local.py @@ -4,7 +4,7 @@ from __future__ import print_function import unittest -from ray.rllib.agents.ppo import PPOAgent, DEFAULT_CONFIG +from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG import ray @@ -12,7 +12,7 @@ class LocalModeTest(unittest.TestCase): def testLocal(self): ray.init(local_mode=True) cf = DEFAULT_CONFIG.copy() - agent = PPOAgent(cf, "CartPole-v0") + agent = PPOTrainer(cf, "CartPole-v0") print(agent.train()) diff --git a/python/ray/rllib/tests/test_lstm.py b/python/ray/rllib/tests/test_lstm.py index 304478eb8..56e2dfd5c 100644 --- a/python/ray/rllib/tests/test_lstm.py +++ b/python/ray/rllib/tests/test_lstm.py @@ -10,7 +10,7 @@ import tensorflow as tf import tensorflow.contrib.rnn as rnn import ray -from ray.rllib.agents.ppo import PPOAgent +from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.models import ModelCatalog from ray.rllib.models.lstm import add_time_dimension, chop_into_sequences from ray.rllib.models.misc import linear, normc_initializer @@ -149,7 +149,7 @@ class RNNSequencing(unittest.TestCase): def testSimpleOptimizerSequencing(self): 
ModelCatalog.register_custom_model("rnn", RNNSpyModel) register_env("counter", lambda _: DebugCounterEnv()) - ppo = PPOAgent( + ppo = PPOTrainer( env="counter", config={ "num_workers": 0, @@ -205,7 +205,7 @@ class RNNSequencing(unittest.TestCase): def testMinibatchSequencing(self): ModelCatalog.register_custom_model("rnn", RNNSpyModel) register_env("counter", lambda _: DebugCounterEnv()) - ppo = PPOAgent( + ppo = PPOTrainer( env="counter", config={ "num_workers": 0, diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index 376e05dc0..955093b84 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -7,7 +7,7 @@ import random import unittest import ray -from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.pg import PGTrainer from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer, @@ -519,7 +519,7 @@ class TestMultiAgentEnv(unittest.TestCase): def testTrainMultiCartpoleSinglePolicy(self): n = 10 register_env("multi_cartpole", lambda _: MultiCartpole(n)) - pg = PGAgent(env="multi_cartpole", config={"num_workers": 0}) + pg = PGTrainer(env="multi_cartpole", config={"num_workers": 0}) for i in range(100): result = pg.train() print("Iteration {}, reward {}, timesteps {}".format( @@ -542,7 +542,7 @@ class TestMultiAgentEnv(unittest.TestCase): act_space = single_env.action_space return (None, obs_space, act_space, config) - pg = PGAgent( + pg = PGTrainer( env="multi_cartpole", config={ "num_workers": 0, diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index bc506152c..dc45ca3f6 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -12,8 +12,8 @@ import tensorflow as tf import unittest import ray -from ray.rllib.agents.a3c import A2CAgent -from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.a3c import A2CTrainer +from ray.rllib.agents.pg import PGTrainer from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv @@ -215,7 +215,7 @@ class TupleSpyModel(Model): class NestedSpacesTest(unittest.TestCase): def testInvalidModel(self): ModelCatalog.register_custom_model("invalid", InvalidModel) - self.assertRaises(ValueError, lambda: PGAgent( + self.assertRaises(ValueError, lambda: PGTrainer( env="CartPole-v0", config={ "model": { "custom_model": "invalid", @@ -226,7 +226,7 @@ class NestedSpacesTest(unittest.TestCase): ModelCatalog.register_custom_model("invalid2", InvalidModel2) self.assertRaisesRegexp( ValueError, "Expected output.*", - lambda: PGAgent( + lambda: PGTrainer( env="CartPole-v0", config={ "model": { "custom_model": "invalid2", @@ -236,7 +236,7 @@ class NestedSpacesTest(unittest.TestCase): def doTestNestedDict(self, make_env, test_lstm=False): ModelCatalog.register_custom_model("composite", DictSpyModel) register_env("nested", make_env) - pg = PGAgent( + pg = PGTrainer( env="nested", config={ "num_workers": 0, @@ -265,7 +265,7 @@ class NestedSpacesTest(unittest.TestCase): def doTestNestedTuple(self, make_env): ModelCatalog.register_custom_model("composite2", TupleSpyModel) register_env("nested2", make_env) - pg = PGAgent( + pg = PGTrainer( env="nested2", config={ "num_workers": 0, @@ -323,7 +323,7 @@ class 
NestedSpacesTest(unittest.TestCase): ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel) register_env("nested_ma", lambda _: NestedMultiAgentEnv()) act_space = spaces.Discrete(2) - pg = PGAgent( + pg = PGTrainer( env="nested_ma", config={ "num_workers": 0, @@ -370,13 +370,13 @@ class NestedSpacesTest(unittest.TestCase): def testRolloutDictSpace(self): register_env("nested", lambda _: NestedDictEnv()) - agent = PGAgent(env="nested") + agent = PGTrainer(env="nested") agent.train() path = agent.save() agent.stop() # Test train works on restore - agent2 = PGAgent(env="nested") + agent2 = PGTrainer(env="nested") agent2.restore(path) agent2.train() @@ -386,7 +386,7 @@ class NestedSpacesTest(unittest.TestCase): def testPyTorchModel(self): ModelCatalog.register_custom_model("composite", TorchSpyModel) register_env("nested", lambda _: NestedDictEnv()) - a2c = A2CAgent( + a2c = A2CTrainer( env="nested", config={ "num_workers": 0, diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index 90652ae9e..dcd1140a7 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -9,7 +9,7 @@ import time import unittest import ray -from ray.rllib.agents.ppo import PPOAgent +from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.evaluation import SampleBatch from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator @@ -41,7 +41,7 @@ class PPOCollectTest(unittest.TestCase): ray.init(num_cpus=4) # Check we at least collect the initial wave of samples - ppo = PPOAgent( + ppo = PPOTrainer( env="CartPole-v0", config={ "sample_batch_size": 200, @@ -53,7 +53,7 @@ class PPOCollectTest(unittest.TestCase): ppo.stop() # Check we collect at least the specified amount of samples - ppo = PPOAgent( + ppo = PPOTrainer( env="CartPole-v0", config={ "sample_batch_size": 200, @@ -65,7 +65,7 @@ class PPOCollectTest(unittest.TestCase): ppo.stop() # Check in vectorized mode - ppo = PPOAgent( + ppo = PPOTrainer( env="CartPole-v0", config={ "sample_batch_size": 200, @@ -78,7 +78,7 @@ class PPOCollectTest(unittest.TestCase): ppo.stop() # Check legacy mode - ppo = PPOAgent( + ppo = PPOTrainer( env="CartPole-v0", config={ "sample_batch_size": 200, diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py index 093a2b00d..32160b0e5 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_policy_evaluator.py @@ -10,8 +10,8 @@ import unittest from collections import Counter import ray -from ray.rllib.agents.pg import PGAgent -from ray.rllib.agents.a3c import A2CAgent +from ray.rllib.agents.pg import PGTrainer +from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.policy_graph import PolicyGraph @@ -170,7 +170,7 @@ class TestPolicyEvaluator(unittest.TestCase): print() def testGlobalVarsUpdate(self): - agent = A2CAgent( + agent = A2CTrainer( env="CartPole-v0", config={ "lr_schedule": [[0, 0.1], [400, 0.000001]], @@ -182,12 +182,12 @@ class TestPolicyEvaluator(unittest.TestCase): def testNoStepOnInit(self): register_env("fail", lambda _: FailOnStepEnv()) - pg = PGAgent(env="fail", config={"num_workers": 1}) + pg = PGTrainer(env="fail", config={"num_workers": 1}) self.assertRaises(Exception, lambda: pg.train()) def testCallbacks(self): counts = 
Counter() - pg = PGAgent( + pg = PGTrainer( env="CartPole-v0", config={ "num_workers": 0, "sample_batch_size": 50, @@ -211,7 +211,7 @@ class TestPolicyEvaluator(unittest.TestCase): def testQueryEvaluators(self): register_env("test", lambda _: gym.make("CartPole-v0")) - pg = PGAgent( + pg = PGTrainer( env="test", config={ "num_workers": 2, diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index c25cefeb4..5e7b2a141 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -1,10 +1,33 @@ +import logging + from ray.rllib.utils.filter_manager import FilterManager from ray.rllib.utils.filter import Filter from ray.rllib.utils.policy_client import PolicyClient from ray.rllib.utils.policy_server import PolicyServer from ray.tune.util import merge_dicts, deep_update +logger = logging.getLogger(__name__) + + +def renamed_class(cls): + class DeprecationWrapper(cls): + def __init__(self, *args, **kwargs): + old_name = cls.__name__.replace("Trainer", "Agent") + new_name = cls.__name__ + logger.warn("DeprecationWarning: {} has been renamed to {}. ". + format(old_name, new_name) + + "This will raise an error in the future.") + cls.__init__(self, *args, **kwargs) + + return DeprecationWrapper + + __all__ = [ - "Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts", - "deep_update" + "Filter", + "FilterManager", + "PolicyClient", + "PolicyServer", + "merge_dicts", + "deep_update", + "renamed_class", ] diff --git a/python/ray/rllib/utils/annotations.py b/python/ray/rllib/utils/annotations.py index 2f22b63d9..0dcd960a4 100644 --- a/python/ray/rllib/utils/annotations.py +++ b/python/ray/rllib/utils/annotations.py @@ -27,10 +27,10 @@ def PublicAPI(obj): can expect these APIs to remain stable across RLlib releases. Subclasses that inherit from a ``@PublicAPI`` base class can be - assumed part of the RLlib public API as well (e.g., all agent classes - are in public API because Agent is ``@PublicAPI``). + assumed part of the RLlib public API as well (e.g., all trainer classes + are in public API because Trainer is ``@PublicAPI``). - In addition, you can assume all agent configurations are part of their + In addition, you can assume all trainer configurations are part of their public API as well. """ diff --git a/python/ray/rllib/utils/policy_server.py b/python/ray/rllib/utils/policy_server.py index 7e4f2443d..04dafc8ac 100644 --- a/python/ray/rllib/utils/policy_server.py +++ b/python/ray/rllib/utils/policy_server.py @@ -39,7 +39,7 @@ class PolicyServer(ThreadingMixIn, HTTPServer): server = PolicyServer(self, "localhost", 8900) server.serve_forever() >>> register_env("srv", lambda _: CartpoleServing()) - >>> pg = PGAgent(env="srv", config={"num_workers": 0}) + >>> pg = PGTrainer(env="srv", config={"num_workers": 0}) >>> while True: pg.train()
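Usage note (illustrative, not part of the diff): the renamed_class() helper added to ray/rllib/utils/__init__.py above is the deprecation shim for the Agent to Trainer rename, but its call sites are not visible in the hunks shown here. A typical use would look roughly like the sketch below; the module path, imported names, and __all__ contents are assumptions for illustration only, not a quote of the patch.

    # Hypothetical ray/rllib/agents/ppo/__init__.py after this patch (sketch).
    from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
    from ray.rllib.utils import renamed_class

    # Keep the old name importable; constructing it logs the deprecation
    # warning emitted by the DeprecationWrapper inside renamed_class().
    PPOAgent = renamed_class(PPOTrainer)

    __all__ = ["PPOAgent", "PPOTrainer", "DEFAULT_CONFIG"]

The new "on_postprocess_traj" callback wired through MultiAgentSampleBatchBuilder is registered the same way as the other callbacks in custom_metrics_and_callbacks.py above; its info dict carries the "episode" and the postprocessed per-agent "batch".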