mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] [rfc] add contrib module and guideline for merging (#3565)
This adds guidelines for merging code into `rllib/contrib` vs `rllib/agents`. Also, clean up the agent import code to make registration easier.
This commit is contained in:
parent cf0c4745f4
commit 303883a3b6
17 changed files with 280 additions and 79 deletions
@@ -94,6 +94,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin
    rllib-env.rst
    rllib-algorithms.rst
    rllib-models.rst
+   rllib-dev.rst
    rllib-concepts.rst
    rllib-package-ref.rst
@@ -248,7 +248,7 @@ Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/pyt
QMIX Monotonic Value Factorisation (QMIX, VDN, IQN)
---------------------------------------------------
-`[paper] <https://arxiv.org/abs/1803.11485>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/qmix.py>`__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/group_agents_wrapper.py>`__ in the environment (see the `two-step game example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X.
+`[paper] <https://arxiv.org/abs/1803.11485>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/qmix.py>`__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping <rllib-env.html#grouping-agents>`__ in the environment (see the `two-step game example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/twostep_game.py>`__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X.

Q-Mix is implemented in `PyTorch <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/qmix/qmix_policy_graph.py>`__ and is currently *experimental*.
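For readers following the grouping requirement above: the two-step game example referenced in that paragraph builds its group with ``MultiAgentEnv.with_agent_groups``. A rough sketch of that pattern, not part of this commit (the registered name and the space sizes are illustrative and may differ from the example file):

.. code-block:: python

    from gym.spaces import Discrete, Tuple

    from ray.rllib.examples.twostep_game import TwoStepGame
    from ray.tune.registry import register_env

    # Put both agents of the two-step game into a single group so QMIX
    # sees one logical agent with tupled observations and actions.
    grouping = {"group_1": [0, 1]}
    obs_space = Tuple([Discrete(6), Discrete(6)])
    act_space = Tuple([Discrete(2), Discrete(2)])

    # with_agent_groups() wraps the MultiAgentEnv so grouped agents are
    # exposed to RLlib as one agent.
    register_env(
        "grouped_twostep",
        lambda config: TwoStepGame(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))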
65  doc/source/rllib-dev.rst  Normal file
@@ -0,0 +1,65 @@
RLlib Development
=================

Features
--------

Feature development and upcoming priorities are tracked on the `RLlib project board <https://github.com/ray-project/ray/projects/6>`__ (note that this may not include all development efforts). For discussion of issues and new features, we use the `Ray dev list <https://groups.google.com/forum/#!forum/ray-dev>`__ and `GitHub issues page <https://github.com/ray-project/ray/issues>`__.

Benchmarks
----------

Currently we host a number of full training run results in the `rl-experiments repo <https://github.com/ray-project/rl-experiments>`__, and maintain a list of working hyperparameter configurations in `tuned_examples <https://github.com/ray-project/ray/tree/master/python/ray/rllib/tuned_examples>`__. Benchmark results are extremely valuable to the community, so if you happen to have results that may be of interest, consider making a pull request to either repo.

Contributing Algorithms
-----------------------

These are the guidelines for merging new algorithms into RLlib:

* Contributed algorithms (`rllib/contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__):
    - must subclass Agent and implement the ``_train()`` method
    - must include a lightweight test (<30s to run) to sanity check functionality
    - should include tuned hyperparameter examples and documentation
    - should offer functionality not present in existing algorithms

* Fully integrated algorithms (`rllib/agents <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents>`__) have the following additional requirements:
    - must fully implement the Agent API
    - must offer substantial new functionality not possible to add to other algorithms
    - should support custom models and preprocessors
    - should use RLlib abstractions and support distributed execution

Both integrated and contributed algorithms ship with the ``ray`` PyPI package, and are tested as part of Ray's automated tests. The main difference between contributed and fully integrated algorithms is that the latter will be maintained by the Ray team to a much greater extent with respect to bugs and integration with RLlib features.

How to add an algorithm to ``contrib``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It takes just two changes to add an algorithm to `contrib <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib>`__. A minimal example can be found `here <https://github.com/ray-project/ray/tree/master/python/ray/rllib/contrib/random_agent/random_agent.py>`__. First, subclass `Agent <https://github.com/ray-project/ray/tree/master/python/ray/rllib/agents/agent.py>`__ and implement the ``_init`` and ``_train`` methods:

.. literalinclude:: ../../python/ray/rllib/contrib/random_agent/random_agent.py
   :language: python
   :start-after: __sphinx_doc_begin__
   :end-before: __sphinx_doc_end__
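The literalinclude above pulls in the RandomAgent example (the full file appears later in this diff). For orientation, a minimal sketch of the required structure looks roughly like the following; the class name ``MyAgent`` and the ``my_option`` key are placeholders, not RLlib names:

.. code-block:: python

    from ray.rllib.agents.agent import Agent, with_common_config
    from ray.rllib.utils.annotations import override


    class MyAgent(Agent):
        """Hypothetical contrib algorithm skeleton (illustrative only)."""

        _agent_name = "MyAgent"
        _default_config = with_common_config({
            "my_option": 1,  # algorithm-specific options go here
        })

        @override(Agent)
        def _init(self):
            # Build whatever state the algorithm needs (envs, optimizers, ...).
            self.env = self.env_creator(self.config["env_config"])

        @override(Agent)
        def _train(self):
            # Run one iteration of training and return a result dict.
            return {"episode_reward_mean": 0.0, "timesteps_this_iter": 0}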
Second, register the agent with a name in `contrib/registry.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/contrib/registry.py>`__.

.. code-block:: python

    def _import_random_agent():
        from ray.rllib.contrib.random_agent.random_agent import RandomAgent
        return RandomAgent

    def _import_random_agent_2():
        from ray.rllib.contrib.random_agent_2.random_agent_2 import RandomAgent2
        return RandomAgent2

    CONTRIBUTED_ALGORITHMS = {
        "contrib/RandomAgent": _import_random_agent,
        "contrib/RandomAgent2": _import_random_agent_2,
        # ...
    }
After registration, you can run and visualize agent progress using ``rllib train``:

.. code-block:: bash

    rllib train --run=contrib/RandomAgent --env=CartPole-v0
    tensorboard --logdir=~/ray_results
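Because contributed algorithms are registered as Tune trainables (see the ``_register_all()`` diff below), the same agent can also be launched programmatically. A sketch, assuming the ``tune.run_experiments`` API of this release; the experiment name and stopping criterion are arbitrary:

.. code-block:: python

    import ray
    from ray import tune

    ray.init()

    # "contrib/RandomAgent" resolves through CONTRIBUTED_ALGORITHMS.
    tune.run_experiments({
        "random_demo": {
            "run": "contrib/RandomAgent",
            "env": "CartPole-v0",
            "stop": {"training_iteration": 3},
        },
    })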
@@ -20,6 +20,7 @@ APEX-DQN **Yes** `+parametric`_ No **Yes** No
APEX-DDPG     No                      **Yes**            **Yes**     No
ES            **Yes**                 **Yes**            No          No
ARS           **Yes**                 **Yes**            No          No
QMIX          **Yes**                 No                 **Yes**     **Yes**
============= ======================= ================== =========== ==================

.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
@@ -68,7 +68,7 @@ Algorithms

  - `Evolution Strategies <rllib-algorithms.html#evolution-strategies>`__

-* Multi-agent / specialized
+* Multi-agent specific

  - `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__
|
@ -82,6 +82,13 @@ Models and Preprocessors
|
|||
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
|
||||
* `Model-Based Rollouts <rllib-models.html#model-based-rollouts>`__
|
||||
|
||||
RLlib Development
|
||||
-----------------
|
||||
|
||||
* `Features <rllib-dev.html#feature-development>`__
|
||||
* `Benchmarks <rllib-dev.html#benchmarks>`__
|
||||
* `Contributing Algorithms <rllib-dev.html#contributing-algorithms>`__
|
||||
|
||||
RLlib Concepts
|
||||
--------------
|
||||
* `Policy Graphs <rllib-concepts.html>`__
|
||||
|
|
|
@@ -31,12 +31,11 @@ def _setup_logger():


def _register_all():
-    for key in [
-            "PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG",
-            "IMPALA", "ARS", "A2C", "QMIX", "APEX_QMIX", "__fake",
-            "__sigmoid_fake_data", "__parameter_tuning"
-    ]:
-        from ray.rllib.agents.agent import get_agent_class
+    from ray.rllib.agents.registry import ALGORITHMS
+    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
+    for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
+    )) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
+        from ray.rllib.agents.registry import get_agent_class
        register_trainable(key, get_agent_class(key))
@@ -10,7 +10,6 @@ import pickle
import six
import tempfile
import tensorflow as tf
-import traceback
from types import FunctionType

import ray
@@ -542,69 +541,3 @@ def _register_if_needed(env_object):
    name = env_object.__name__
    register_env(name, lambda config: env_object(config))
    return name
-
-
-def get_agent_class(alg):
-    """Returns the class of a known agent given its name."""
-
-    try:
-        return _get_agent_class(alg)
-    except ImportError:
-        from ray.rllib.agents.mock import _agent_import_failed
-        return _agent_import_failed(traceback.format_exc())
-
-
-def _get_agent_class(alg):
-    if alg == "DDPG":
-        from ray.rllib.agents import ddpg
-        return ddpg.DDPGAgent
-    elif alg == "APEX_DDPG":
-        from ray.rllib.agents import ddpg
-        return ddpg.ApexDDPGAgent
-    elif alg == "PPO":
-        from ray.rllib.agents import ppo
-        return ppo.PPOAgent
-    elif alg == "ES":
-        from ray.rllib.agents import es
-        return es.ESAgent
-    elif alg == "ARS":
-        from ray.rllib.agents import ars
-        return ars.ARSAgent
-    elif alg == "DQN":
-        from ray.rllib.agents import dqn
-        return dqn.DQNAgent
-    elif alg == "APEX":
-        from ray.rllib.agents import dqn
-        return dqn.ApexAgent
-    elif alg == "A3C":
-        from ray.rllib.agents import a3c
-        return a3c.A3CAgent
-    elif alg == "A2C":
-        from ray.rllib.agents import a3c
-        return a3c.A2CAgent
-    elif alg == "PG":
-        from ray.rllib.agents import pg
-        return pg.PGAgent
-    elif alg == "IMPALA":
-        from ray.rllib.agents import impala
-        return impala.ImpalaAgent
-    elif alg == "QMIX":
-        from ray.rllib.agents import qmix
-        return qmix.QMixAgent
-    elif alg == "APEX_QMIX":
-        from ray.rllib.agents import qmix
-        return qmix.ApexQMixAgent
-    elif alg == "script":
-        from ray.tune import script_runner
-        return script_runner.ScriptRunner
-    elif alg == "__fake":
-        from ray.rllib.agents.mock import _MockAgent
-        return _MockAgent
-    elif alg == "__sigmoid_fake_data":
-        from ray.rllib.agents.mock import _SigmoidFakeData
-        return _SigmoidFakeData
-    elif alg == "__parameter_tuning":
-        from ray.rllib.agents.mock import _ParameterTuningAgent
-        return _ParameterTuningAgent
-    else:
-        raise Exception(("Unknown algorithm {}.").format(alg))
122  python/ray/rllib/agents/registry.py  Normal file
@@ -0,0 +1,122 @@
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import traceback

from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS


def _import_qmix():
    from ray.rllib.agents import qmix
    return qmix.QMixAgent


def _import_apex_qmix():
    from ray.rllib.agents import qmix
    return qmix.ApexQMixAgent


def _import_ddpg():
    from ray.rllib.agents import ddpg
    return ddpg.DDPGAgent


def _import_apex_ddpg():
    from ray.rllib.agents import ddpg
    return ddpg.ApexDDPGAgent


def _import_ppo():
    from ray.rllib.agents import ppo
    return ppo.PPOAgent


def _import_es():
    from ray.rllib.agents import es
    return es.ESAgent


def _import_ars():
    from ray.rllib.agents import ars
    return ars.ARSAgent


def _import_dqn():
    from ray.rllib.agents import dqn
    return dqn.DQNAgent


def _import_apex():
    from ray.rllib.agents import dqn
    return dqn.ApexAgent


def _import_a3c():
    from ray.rllib.agents import a3c
    return a3c.A3CAgent


def _import_a2c():
    from ray.rllib.agents import a3c
    return a3c.A2CAgent


def _import_pg():
    from ray.rllib.agents import pg
    return pg.PGAgent


def _import_impala():
    from ray.rllib.agents import impala
    return impala.ImpalaAgent


ALGORITHMS = {
    "DDPG": _import_ddpg,
    "APEX_DDPG": _import_apex_ddpg,
    "PPO": _import_ppo,
    "ES": _import_es,
    "ARS": _import_ars,
    "DQN": _import_dqn,
    "APEX": _import_apex,
    "A3C": _import_a3c,
    "A2C": _import_a2c,
    "PG": _import_pg,
    "IMPALA": _import_impala,
    "QMIX": _import_qmix,
    "APEX_QMIX": _import_apex_qmix,
}


def get_agent_class(alg):
    """Returns the class of a known agent given its name."""

    try:
        return _get_agent_class(alg)
    except ImportError:
        from ray.rllib.agents.mock import _agent_import_failed
        return _agent_import_failed(traceback.format_exc())


def _get_agent_class(alg):
    if alg in ALGORITHMS:
        return ALGORITHMS[alg]()
    elif alg in CONTRIBUTED_ALGORITHMS:
        return CONTRIBUTED_ALGORITHMS[alg]()
    elif alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner
    elif alg == "__fake":
        from ray.rllib.agents.mock import _MockAgent
        return _MockAgent
    elif alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData
        return _SigmoidFakeData
    elif alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningAgent
        return _ParameterTuningAgent
    else:
        raise Exception(("Unknown algorithm {}.").format(alg))
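The registry above imports an algorithm's module only when its entry is called, so importing the registry itself stays cheap. A usage sketch, not part of this commit; it assumes the ``Agent`` constructor accepts ``env`` and ``config`` as in the RandomAgent example below:

.. code-block:: python

    import ray
    from ray.rllib.agents.registry import get_agent_class

    ray.init()

    # Only the PPO module gets imported here; other ALGORITHMS entries stay
    # untouched until their import function is called.
    cls = get_agent_class("PPO")
    agent = cls(env="CartPole-v0", config={"num_workers": 1})
    print(agent.train()["episode_reward_mean"])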
3  python/ray/rllib/contrib/README.rst  Normal file
@@ -0,0 +1,3 @@
Contributed algorithms, which can be run via `rllib train --run=contrib/<alg_name>`

See https://ray.readthedocs.io/en/latest/rllib-dev.html for guidelines.
0  python/ray/rllib/contrib/__init__.py  Normal file
52  python/ray/rllib/contrib/random_agent/random_agent.py  Normal file
@@ -0,0 +1,52 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.utils.annotations import override


# yapf: disable
# __sphinx_doc_begin__
class RandomAgent(Agent):
    """Agent that takes random actions and never learns."""

    _agent_name = "RandomAgent"
    _default_config = with_common_config({
        "rollouts_per_iteration": 10,
    })

    @override(Agent)
    def _init(self):
        self.env = self.env_creator(self.config["env_config"])

    @override(Agent)
    def _train(self):
        rewards = []
        steps = 0
        for _ in range(self.config["rollouts_per_iteration"]):
            obs = self.env.reset()
            done = False
            reward = 0.0
            while not done:
                action = self.env.action_space.sample()
                obs, r, done, info = self.env.step(action)
                reward += r
                steps += 1
            rewards.append(reward)
        return {
            "episode_reward_mean": np.mean(rewards),
            "timesteps_this_iter": steps,
        }
# __sphinx_doc_end__
# don't enable yapf after, it's buggy here


if __name__ == "__main__":
    agent = RandomAgent(
        env="CartPole-v0", config={"rollouts_per_iteration": 10})
    result = agent.train()
    assert result["episode_reward_mean"] > 10, result
    print("Test: OK")
15  python/ray/rllib/contrib/registry.py  Normal file
@@ -0,0 +1,15 @@
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _import_random_agent():
    from ray.rllib.contrib.random_agent.random_agent import RandomAgent
    return RandomAgent


CONTRIBUTED_ALGORITHMS = {
    "contrib/RandomAgent": _import_random_agent,
}
@@ -11,7 +11,7 @@ import pickle

import gym
import ray
-from ray.rllib.agents.agent import get_agent_class
+from ray.rllib.agents.registry import get_agent_class

EXAMPLE_USAGE = """
Example Usage via RLlib CLI:
@@ -7,7 +7,7 @@ from __future__ import print_function
import numpy as np
import ray

-from ray.rllib.agents.agent import get_agent_class
+from ray.rllib.agents.registry import get_agent_class


def get_mean_action(alg, obs):
@@ -8,7 +8,7 @@ import numpy as np
import sys

import ray
-from ray.rllib.agents.agent import get_agent_class
+from ray.rllib.agents.registry import get_agent_class
from ray.rllib.test.test_multi_agent_env import MultiCartpole, MultiMountainCar
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.tune.registry import register_env
@@ -339,4 +339,4 @@ class Trial(object):
        identifier = self.trainable_name
        if self.experiment_tag:
            identifier += "_" + self.experiment_tag
-        return identifier
+        return identifier.replace("/", "_")
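For context on the one-line ``Trial`` change above: contributed algorithm names contain a slash (``contrib/RandomAgent``), so the trial identifier is sanitized, which matters wherever the identifier is embedded in file or directory names. A toy illustration (the experiment tag ``"0"`` is made up):

.. code-block:: python

    # "contrib/RandomAgent" + "_" + "0" -> "contrib/RandomAgent_0"
    identifier = "contrib/RandomAgent" + "_" + "0"
    print(identifier.replace("/", "_"))  # contrib_RandomAgent_0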
@@ -385,6 +385,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2

+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    python /ray/python/ray/rllib/contrib/random_agent/random_agent.py
+
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/rllib/examples/twostep_game.py --stop=2000 --run=PG