2018-12-21 03:44:34 +09:00
|
|
|
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""
|
|
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_a2c():
|
|
|
|
from ray.rllib.agents import a3c
|
|
|
|
return a3c.A2CTrainer
|
|
|
|
|
|
|
|
|
|
|
|
def _import_a3c():
|
|
|
|
from ray.rllib.agents import a3c
|
|
|
|
return a3c.A3CTrainer
|
|
|
|
|
|
|
|
|
|
|
|
def _import_apex():
|
|
|
|
from ray.rllib.agents import dqn
|
|
|
|
return dqn.ApexTrainer
|
|
|
|
|
|
|
|
|
|
|
|
def _import_apex_ddpg():
|
|
|
|
from ray.rllib.agents import ddpg
|
|
|
|
return ddpg.ApexDDPGTrainer
|
2019-08-01 23:37:36 -07:00
|
|
|
|
|
|
|
|
2019-01-18 13:40:26 -08:00
|
|
|
def _import_appo():
|
|
|
|
from ray.rllib.agents import ppo
|
2019-04-07 00:36:18 -07:00
|
|
|
return ppo.APPOTrainer
|
2019-01-18 13:40:26 -08:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_ars():
|
|
|
|
from ray.rllib.agents import ars
|
|
|
|
return ars.ARSTrainer
|
2020-02-10 15:28:27 -08:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_bc():
|
|
|
|
from ray.rllib.agents import marwil
|
|
|
|
return marwil.BCTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
|
|
|
def _import_ddpg():
|
|
|
|
from ray.rllib.agents import ddpg
|
2019-04-07 00:36:18 -07:00
|
|
|
return ddpg.DDPGTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_ddppo():
|
|
|
|
from ray.rllib.agents import ppo
|
|
|
|
return ppo.DDPPOTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_dqn():
|
|
|
|
from ray.rllib.agents import dqn
|
|
|
|
return dqn.DQNTrainer
|
2019-04-26 17:49:53 -07:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_dreamer():
|
|
|
|
from ray.rllib.agents import dreamer
|
|
|
|
return dreamer.DREAMERTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
|
|
|
def _import_es():
|
|
|
|
from ray.rllib.agents import es
|
2019-04-07 00:36:18 -07:00
|
|
|
return es.ESTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_impala():
|
|
|
|
from ray.rllib.agents import impala
|
|
|
|
return impala.ImpalaTrainer
|
2019-07-03 15:59:47 -07:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_maml():
|
|
|
|
from ray.rllib.agents import maml
|
|
|
|
return maml.MAMLTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_marwil():
|
|
|
|
from ray.rllib.agents import marwil
|
|
|
|
return marwil.MARWILTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_mbmpo():
|
|
|
|
from ray.rllib.agents import mbmpo
|
|
|
|
return mbmpo.MBMPOTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
|
|
|
def _import_pg():
|
|
|
|
from ray.rllib.agents import pg
|
2019-04-07 00:36:18 -07:00
|
|
|
return pg.PGTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_ppo():
|
|
|
|
from ray.rllib.agents import ppo
|
|
|
|
return ppo.PPOTrainer
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_qmix():
|
|
|
|
from ray.rllib.agents import qmix
|
|
|
|
return qmix.QMixTrainer
|
2019-01-17 11:00:43 +08:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_sac():
|
|
|
|
from ray.rllib.agents import sac
|
|
|
|
return sac.SACTrainer
|
2020-06-23 09:48:23 -07:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_simple_q():
|
|
|
|
from ray.rllib.agents import dqn
|
|
|
|
return dqn.SimpleQTrainer
|
2020-08-02 09:12:09 -07:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_td3():
|
|
|
|
from ray.rllib.agents import ddpg
|
|
|
|
return ddpg.TD3Trainer
|
2020-08-26 04:24:05 -07:00
|
|
|
|
|
|
|
|
2018-12-21 03:44:34 +09:00
|
|
|
# Maps each built-in `--run=<alg_name>` string to a zero-argument hook that
# imports and returns the trainer class. Storing the hooks (not the classes)
# keeps every algorithm package unimported until its name is requested.
# Insertion order is preserved from the original literal.
ALGORITHMS = dict(
    A2C=_import_a2c,
    A3C=_import_a3c,
    APEX=_import_apex,
    APEX_DDPG=_import_apex_ddpg,
    APPO=_import_appo,
    ARS=_import_ars,
    BC=_import_bc,
    ES=_import_es,
    DDPG=_import_ddpg,
    DDPPO=_import_ddppo,
    DQN=_import_dqn,
    DREAMER=_import_dreamer,
    IMPALA=_import_impala,
    MAML=_import_maml,
    MARWIL=_import_marwil,
    MBMPO=_import_mbmpo,
    PG=_import_pg,
    PPO=_import_ppo,
    QMIX=_import_qmix,
    SAC=_import_sac,
    SimpleQ=_import_simple_q,
    TD3=_import_td3,
)
|
|
|
|
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
def get_agent_class(alg: str) -> type:
    """Return the trainer class registered under ``alg``.

    Args:
        alg: Algorithm name, as passed to ``rllib train --run=<alg_name>``.

    Returns:
        The resolved trainer class. If resolving it raises ``ImportError``,
        the formatted traceback is passed to
        ``ray.rllib.agents.mock._agent_import_failed`` and that call's result
        is returned instead (presumably a stand-in that reports the failure
        later -- confirm in the mock module).
    """
    try:
        agent_cls = _get_agent_class(alg)
    except ImportError:
        # Imported only on the failure path; the traceback must be captured
        # while still inside the `except` clause.
        from ray.rllib.agents.mock import _agent_import_failed
        failure_tb = traceback.format_exc()
        return _agent_import_failed(failure_tb)
    return agent_cls
|
|
|
|
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
def _get_agent_class(alg: str) -> type:
    """Resolve an algorithm name to its class, without handling import errors.

    Lookup order: the built-in ``ALGORITHMS`` registry, then
    ``CONTRIBUTED_ALGORITHMS``, then a fixed set of special names resolved
    from ``ray.tune.script_runner`` ("script") and ``ray.rllib.agents.mock``
    (the ``__``-prefixed names).

    Args:
        alg: Algorithm name, e.g. "PPO".

    Returns:
        The class returned by the matching hook or import.

    Raises:
        ValueError: If ``alg`` matches none of the names above. This is a
            subclass of ``Exception``, so callers that caught the previously
            raised bare ``Exception`` keep working.
        ImportError: Propagated from the lazy import of the selected class;
            ``get_agent_class`` handles this case.
    """
    if alg in ALGORITHMS:
        return ALGORITHMS[alg]()
    if alg in CONTRIBUTED_ALGORITHMS:
        return CONTRIBUTED_ALGORITHMS[alg]()
    if alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner
    if alg == "__fake":
        from ray.rllib.agents.mock import _MockTrainer
        return _MockTrainer
    if alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData
        return _SigmoidFakeData
    if alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningTrainer
        return _ParameterTuningTrainer
    # A specific exception type lets callers catch unknown names without
    # also swallowing unrelated errors.
    raise ValueError("Unknown algorithm {}.".format(alg))
|