[rllib] [rfc] add contrib module and guideline for merging (#3565)

This adds guidelines for merging code into `rllib/contrib` vs `rllib/agents`. Also, clean up the agent import code to make registration easier.
Eric Liang 2018-12-21 03:44:34 +09:00 committed by Richard Liaw
parent cf0c4745f4
commit 303883a3b6
17 changed files with 280 additions and 79 deletions

View file

@@ -94,6 +94,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learning
   rllib-env.rst
   rllib-algorithms.rst
   rllib-models.rst
   rllib-dev.rst
   rllib-concepts.rst
   rllib-package-ref.rst

View file

@@ -248,7 +248,7 @@ Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/pyt
QMIX Monotonic Value Factorisation (QMIX, VDN, IQN)
---------------------------------------------------
`[paper] <https://arxiv.org/abs/1803.11485>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/qmix.py>`__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/group_agents_wrapper.py>`__ in the environment (see the `two-step game example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X.
`[paper] <https://arxiv.org/abs/1803.11485>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/qmix.py>`__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping <rllib-env.html#grouping-agents>`__ in the environment (see the `two-step game example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X.
Q-Mix is implemented in `PyTorch <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/qmix_policy_graph.py>`__ and is currently *experimental*.
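For reference, the linked two-step game example builds the grouping by wrapping the multi-agent environment roughly as in the sketch below (``YourMultiAgentEnv`` and the example spaces are placeholders, not part of this change):

.. code-block:: python

    from gym.spaces import Discrete, Tuple
    from ray.tune.registry import register_env

    # Merge the two cooperating agents into one logical agent so that QMIX
    # sees joint observations and actions.
    grouping = {"group_1": ["agent_1", "agent_2"]}
    obs_space = Tuple([Discrete(6), Discrete(6)])
    act_space = Tuple([Discrete(2), Discrete(2)])

    register_env(
        "grouped_env",
        lambda config: YourMultiAgentEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))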

doc/source/rllib-dev.rst Normal file
View file

@@ -0,0 +1,65 @@
RLlib Development
=================

Features
--------

Feature development and upcoming priorities are tracked on the `RLlib project board <https://github.com/ray-project/ray/projects/6>`__ (note that this may not include all development efforts). For discussion of issues and new features, we use the `Ray dev list <https://groups.google.com/forum/#!forum/ray-dev>`__ and `GitHub issues page <https://github.com/ray-project/ray/issues>`__.

Benchmarks
----------

Currently, we host a number of full training run results in the `rl-experiments repo <https://github.com/ray-project/rl-experiments>`__ and maintain a list of working hyperparameter configurations in `tuned_examples <https://github.com/ray-project/ray/tree/master/python/ray/rllib/tuned_examples>`__. Benchmark results are extremely valuable to the community, so if you have results that may be of interest, consider making a pull request to either repo.

Contributing Algorithms
-----------------------

These are the guidelines for merging new algorithms into RLlib:

* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__):

  - must subclass Agent and implement the ``_train()`` method
  - must include a lightweight test (<30s to run) to sanity check functionality
  - should include tuned hyperparameter examples and documentation
  - should offer functionality not present in existing algorithms

* Fully integrated algorithms (`rllib/agents <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents>`__) have the following additional requirements:

  - must fully implement the Agent API
  - must offer substantial new functionality not possible to add to other algorithms
  - should support custom models and preprocessors
  - should use RLlib abstractions and support distributed execution

Both contributed and fully integrated algorithms ship with the ``ray`` PyPI package and are tested as part of Ray's automated tests. The main difference is that the latter are maintained by the Ray team to a much greater extent with respect to bug fixes and integration with new RLlib features.
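The lightweight test requirement can be satisfied with a self-check in the module's ``__main__`` block. The ``contrib/random_agent`` example added in this change does exactly this (and is run directly by the automated tests); a minimal sketch:

.. code-block:: python

    from ray.rllib.contrib.random_agent.random_agent import RandomAgent

    if __name__ == "__main__":
        # A single training iteration finishes in well under 30 seconds.
        agent = RandomAgent(
            env="CartPole-v0", config={"rollouts_per_iteration": 10})
        result = agent.train()
        assert result["episode_reward_mean"] > 10, result
        print("Test: OK")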
How to add an algorithm to ``contrib``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Agent <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:

.. literalinclude:: ../../python/ray/rllib/contrib/random_agent/random_agent.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__

Second, register the agent with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/contrib/registry.py>`__:

.. code-block:: python

    def _import_random_agent():
        from ray.rllib.contrib.random_agent.random_agent import RandomAgent
        return RandomAgent

    def _import_random_agent_2():
        from ray.rllib.contrib.random_agent_2.random_agent_2 import RandomAgent2
        return RandomAgent2

    CONTRIBUTED_ALGORITHMS = {
        "contrib/RandomAgent": _import_random_agent,
        "contrib/RandomAgent2": _import_random_agent_2,
        # ...
    }

After registration, you can run and visualize agent progress using ``rllib train``:

.. code-block:: bash

    rllib train --run=contrib/RandomAgent --env=CartPole-v0
    tensorboard --logdir=~/ray_results
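The registered name can also be resolved programmatically through the registry added in this change; a minimal sketch:

.. code-block:: python

    from ray.rllib.agents.registry import get_agent_class

    # Contributed algorithms resolve through the same lookup as built-in ones.
    agent_cls = get_agent_class("contrib/RandomAgent")
    agent = agent_cls(env="CartPole-v0")
    print(agent.train()["episode_reward_mean"])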

View file

@@ -20,6 +20,7 @@ APEX-DQN **Yes** `+parametric`_ No **Yes** No
APEX-DDPG No **Yes** **Yes** No
ES **Yes** **Yes** No No
ARS **Yes** **Yes** No No
QMIX **Yes** No **Yes** **Yes**
============= ======================= ================== =========== ==================
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces

View file

@@ -68,7 +68,7 @@ Algorithms
- `Evolution Strategies <rllib-algorithms.html#evolution-strategies>`__
* Multi-agent / specialized
* Multi-agent specific
- `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__
@@ -82,6 +82,13 @@ Models and Preprocessors
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Model-Based Rollouts <rllib-models.html#model-based-rollouts>`__

RLlib Development
-----------------

* `Features <rllib-dev.html#features>`__
* `Benchmarks <rllib-dev.html#benchmarks>`__
* `Contributing Algorithms <rllib-dev.html#contributing-algorithms>`__

RLlib Concepts
--------------
* `Policy Graphs <rllib-concepts.html>`__

View file

@@ -31,12 +31,11 @@ def _setup_logger():
def _register_all():
    for key in [
            "PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG",
            "IMPALA", "ARS", "A2C", "QMIX", "APEX_QMIX", "__fake",
            "__sigmoid_fake_data", "__parameter_tuning"
    ]:
        from ray.rllib.agents.agent import get_agent_class
    from ray.rllib.agents.registry import ALGORITHMS
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
    for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
    )) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
        from ray.rllib.agents.registry import get_agent_class
        register_trainable(key, get_agent_class(key))
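# Usage sketch (illustrative, not part of this patch): _register_all() runs on
# `import ray.rllib`, so built-in and contributed algorithms can be referenced
# by name in a Tune experiment spec.
import ray
from ray import tune

ray.init()
tune.run_experiments({
    "random_demo": {
        "run": "contrib/RandomAgent",
        "env": "CartPole-v0",
        "stop": {"training_iteration": 1},
    },
})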

View file

@@ -10,7 +10,6 @@ import pickle
import six
import tempfile
import tensorflow as tf
import traceback
from types import FunctionType
import ray
@@ -542,69 +541,3 @@ def _register_if_needed(env_object):
        name = env_object.__name__
        register_env(name, lambda config: env_object(config))
        return name


def get_agent_class(alg):
    """Returns the class of a known agent given its name."""
    try:
        return _get_agent_class(alg)
    except ImportError:
        from ray.rllib.agents.mock import _agent_import_failed
        return _agent_import_failed(traceback.format_exc())


def _get_agent_class(alg):
    if alg == "DDPG":
        from ray.rllib.agents import ddpg
        return ddpg.DDPGAgent
    elif alg == "APEX_DDPG":
        from ray.rllib.agents import ddpg
        return ddpg.ApexDDPGAgent
    elif alg == "PPO":
        from ray.rllib.agents import ppo
        return ppo.PPOAgent
    elif alg == "ES":
        from ray.rllib.agents import es
        return es.ESAgent
    elif alg == "ARS":
        from ray.rllib.agents import ars
        return ars.ARSAgent
    elif alg == "DQN":
        from ray.rllib.agents import dqn
        return dqn.DQNAgent
    elif alg == "APEX":
        from ray.rllib.agents import dqn
        return dqn.ApexAgent
    elif alg == "A3C":
        from ray.rllib.agents import a3c
        return a3c.A3CAgent
    elif alg == "A2C":
        from ray.rllib.agents import a3c
        return a3c.A2CAgent
    elif alg == "PG":
        from ray.rllib.agents import pg
        return pg.PGAgent
    elif alg == "IMPALA":
        from ray.rllib.agents import impala
        return impala.ImpalaAgent
    elif alg == "QMIX":
        from ray.rllib.agents import qmix
        return qmix.QMixAgent
    elif alg == "APEX_QMIX":
        from ray.rllib.agents import qmix
        return qmix.ApexQMixAgent
    elif alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner
    elif alg == "__fake":
        from ray.rllib.agents.mock import _MockAgent
        return _MockAgent
    elif alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData
        return _SigmoidFakeData
    elif alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningAgent
        return _ParameterTuningAgent
    else:
        raise Exception(("Unknown algorithm {}.").format(alg))

View file

@@ -0,0 +1,122 @@
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import traceback

from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS


def _import_qmix():
    from ray.rllib.agents import qmix
    return qmix.QMixAgent


def _import_apex_qmix():
    from ray.rllib.agents import qmix
    return qmix.ApexQMixAgent


def _import_ddpg():
    from ray.rllib.agents import ddpg
    return ddpg.DDPGAgent


def _import_apex_ddpg():
    from ray.rllib.agents import ddpg
    return ddpg.ApexDDPGAgent


def _import_ppo():
    from ray.rllib.agents import ppo
    return ppo.PPOAgent


def _import_es():
    from ray.rllib.agents import es
    return es.ESAgent


def _import_ars():
    from ray.rllib.agents import ars
    return ars.ARSAgent


def _import_dqn():
    from ray.rllib.agents import dqn
    return dqn.DQNAgent


def _import_apex():
    from ray.rllib.agents import dqn
    return dqn.ApexAgent


def _import_a3c():
    from ray.rllib.agents import a3c
    return a3c.A3CAgent


def _import_a2c():
    from ray.rllib.agents import a3c
    return a3c.A2CAgent


def _import_pg():
    from ray.rllib.agents import pg
    return pg.PGAgent


def _import_impala():
    from ray.rllib.agents import impala
    return impala.ImpalaAgent


ALGORITHMS = {
    "DDPG": _import_ddpg,
    "APEX_DDPG": _import_apex_ddpg,
    "PPO": _import_ppo,
    "ES": _import_es,
    "ARS": _import_ars,
    "DQN": _import_dqn,
    "APEX": _import_apex,
    "A3C": _import_a3c,
    "A2C": _import_a2c,
    "PG": _import_pg,
    "IMPALA": _import_impala,
    "QMIX": _import_qmix,
    "APEX_QMIX": _import_apex_qmix,
}


def get_agent_class(alg):
    """Returns the class of a known agent given its name."""
    try:
        return _get_agent_class(alg)
    except ImportError:
        from ray.rllib.agents.mock import _agent_import_failed
        return _agent_import_failed(traceback.format_exc())


def _get_agent_class(alg):
    if alg in ALGORITHMS:
        return ALGORITHMS[alg]()
    elif alg in CONTRIBUTED_ALGORITHMS:
        return CONTRIBUTED_ALGORITHMS[alg]()
    elif alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner
    elif alg == "__fake":
        from ray.rllib.agents.mock import _MockAgent
        return _MockAgent
    elif alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData
        return _SigmoidFakeData
    elif alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningAgent
        return _ParameterTuningAgent
    else:
        raise Exception(("Unknown algorithm {}.").format(alg))
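# Usage sketch (illustrative, not part of this file): names map to import
# thunks, so `import ray.rllib` stays lightweight and an algorithm's module is
# only imported when that algorithm is actually requested.
from ray.rllib.agents.registry import ALGORITHMS, get_agent_class

assert "PPO" in ALGORITHMS
ppo_cls = get_agent_class("PPO")  # imports ray.rllib.agents.ppo only at this point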

View file

@@ -0,0 +1,3 @@
Contributed algorithms, which can be run via `rllib train --run=contrib/<alg_name>`
See https://ray.readthedocs.io/en/latest/rllib-dev.html for guidelines.

View file

View file

@@ -0,0 +1,52 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.utils.annotations import override


# yapf: disable
# __sphinx_doc_begin__
class RandomAgent(Agent):
    """Agent that takes random actions and never learns."""

    _agent_name = "RandomAgent"
    _default_config = with_common_config({
        "rollouts_per_iteration": 10,
    })

    @override(Agent)
    def _init(self):
        self.env = self.env_creator(self.config["env_config"])

    @override(Agent)
    def _train(self):
        rewards = []
        steps = 0
        for _ in range(self.config["rollouts_per_iteration"]):
            obs = self.env.reset()
            done = False
            reward = 0.0
            while not done:
                action = self.env.action_space.sample()
                obs, r, done, info = self.env.step(action)
                reward += r
                steps += 1
            rewards.append(reward)
        return {
            "episode_reward_mean": np.mean(rewards),
            "timesteps_this_iter": steps,
        }
# __sphinx_doc_end__
# don't enable yapf after, it's buggy here


if __name__ == "__main__":
    agent = RandomAgent(
        env="CartPole-v0", config={"rollouts_per_iteration": 10})
    result = agent.train()
    assert result["episode_reward_mean"] > 10, result
    print("Test: OK")

View file

@@ -0,0 +1,15 @@
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _import_random_agent():
    from ray.rllib.contrib.random_agent.random_agent import RandomAgent
    return RandomAgent


CONTRIBUTED_ALGORITHMS = {
    "contrib/RandomAgent": _import_random_agent,
}

View file

@@ -11,7 +11,7 @@ import pickle
import gym
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import get_agent_class
EXAMPLE_USAGE = """
Example Usage via RLlib CLI:

View file

@@ -7,7 +7,7 @@ from __future__ import print_function
import numpy as np
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import get_agent_class
def get_mean_action(alg, obs):

View file

@@ -8,7 +8,7 @@ import numpy as np
import sys
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.test.test_multi_agent_env import MultiCartpole, MultiMountainCar
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.tune.registry import register_env

View file

@@ -339,4 +339,4 @@ class Trial(object):
        identifier = self.trainable_name
        if self.experiment_tag:
            identifier += "_" + self.experiment_tag
        return identifier
        return identifier.replace("/", "_")
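# Why the replace (illustrative, not part of this patch): contributed
# algorithms register under names containing a slash, e.g.
# "contrib/RandomAgent", and the trial identifier feeds into result directory
# names, so slashes are replaced to keep log paths flat.
assert "contrib/RandomAgent_CartPole-v0".replace("/", "_") == \
    "contrib_RandomAgent_CartPole-v0"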

View file

@@ -385,6 +385,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/contrib/random_agent/random_agent.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/examples/twostep_game.py --stop=2000 --run=PG