2018-12-21 03:44:34 +09:00
|
|
|
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""
|
|
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_a2c():
    """Import `ray.rllib.algorithms.a2c` on call and build its registry entry.

    Returns:
        A `(a2c.A2C, a2c.A2CConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.a2c as a2c_module

    algo_cls = a2c_module.A2C
    default_config = a2c_module.A2CConfig().to_dict()
    return algo_cls, default_config
|
2020-09-09 17:33:21 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _import_a3c():
    """Import `ray.rllib.algorithms.a3c` on call and build its registry entry.

    Returns:
        A `(a3c.A3C, a3c.A3CConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.a3c as a3c_module

    algo_cls = a3c_module.A3C
    default_config = a3c_module.A3CConfig().to_dict()
    return algo_cls, default_config
|
2020-09-09 17:33:21 +02:00
|
|
|
|
|
|
|
|
2022-02-08 16:43:00 +01:00
|
|
|
def _import_alpha_star():
    """Import `ray.rllib.algorithms.alpha_star` on call and build its entry.

    Returns:
        A `(alpha_star.AlphaStar, alpha_star.AlphaStarConfig().to_dict())`
        tuple.
    """
    import ray.rllib.algorithms.alpha_star as alpha_star_module

    algo_cls = alpha_star_module.AlphaStar
    default_config = alpha_star_module.AlphaStarConfig().to_dict()
    return algo_cls, default_config
|
2022-02-08 16:43:00 +01:00
|
|
|
|
|
|
|
|
2022-05-18 09:58:25 +02:00
|
|
|
def _import_alpha_zero():
    """Import `ray.rllib.algorithms.alpha_zero` on call and build its entry.

    Returns:
        A `(alpha_zero.AlphaZero, alpha_zero.AlphaZeroConfig().to_dict())`
        tuple.
    """
    import ray.rllib.algorithms.alpha_zero as alpha_zero_module

    algo_cls = alpha_zero_module.AlphaZero
    default_config = alpha_zero_module.AlphaZeroConfig().to_dict()
    return algo_cls, default_config
|
2022-05-18 09:58:25 +02:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_apex():
    """Import `ray.rllib.algorithms.apex_dqn` on call and build its entry.

    Returns:
        A `(apex_dqn.ApexDQN, apex_dqn.ApexDQNConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.apex_dqn as apex_dqn_module

    algo_cls = apex_dqn_module.ApexDQN
    default_config = apex_dqn_module.ApexDQNConfig().to_dict()
    return algo_cls, default_config
|
2020-09-09 17:33:21 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _import_apex_ddpg():
    """Import `ray.rllib.algorithms.apex_ddpg` on call and build its entry.

    Returns:
        A `(apex_ddpg.ApexDDPG, apex_ddpg.ApexDDPGConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.apex_ddpg as apex_ddpg_module

    algo_cls = apex_ddpg_module.ApexDDPG
    default_config = apex_ddpg_module.ApexDDPGConfig().to_dict()
    return algo_cls, default_config
|
2019-08-01 23:37:36 -07:00
|
|
|
|
|
|
|
|
2019-01-18 13:40:26 -08:00
|
|
|
def _import_appo():
    """Import `ray.rllib.algorithms.appo` on call and build its registry entry.

    Returns:
        A `(appo.APPO, appo.APPOConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.appo as appo_module

    algo_cls = appo_module.APPO
    default_config = appo_module.APPOConfig().to_dict()
    return algo_cls, default_config
|
2019-01-18 13:40:26 -08:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_ars():
    """Import `ray.rllib.algorithms.ars` on call and build its registry entry.

    Returns:
        A `(ars.ARS, ars.ARSConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.ars as ars_module

    algo_cls = ars_module.ARS
    default_config = ars_module.ARSConfig().to_dict()
    return algo_cls, default_config
|
2020-02-10 15:28:27 -08:00
|
|
|
|
|
|
|
|
2022-01-27 13:58:12 +01:00
|
|
|
def _import_bandit_lints():
    """Import `BanditLinTS` on call and build its registry entry.

    Returns:
        A `(BanditLinTS, BanditLinTS.get_default_config())` tuple.
    """
    from ray.rllib.algorithms.bandit.bandit import BanditLinTS as algo_cls

    default_config = algo_cls.get_default_config()
    return algo_cls, default_config
|
2022-01-27 13:58:12 +01:00
|
|
|
|
|
|
|
|
|
|
|
def _import_bandit_linucb():
    """Import `BanditLinUCB` on call and build its registry entry.

    Returns:
        A `(BanditLinUCB, BanditLinUCB.get_default_config())` tuple.
    """
    from ray.rllib.algorithms.bandit.bandit import BanditLinUCB as algo_cls

    default_config = algo_cls.get_default_config()
    return algo_cls, default_config
|
2022-01-27 13:58:12 +01:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_bc():
    """Import `ray.rllib.algorithms.bc` on call and build its registry entry.

    Returns:
        A `(bc.BC, bc.BCConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.bc as bc_module

    algo_cls = bc_module.BC
    default_config = bc_module.BCConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-12-30 07:11:57 -08:00
|
|
|
def _import_cql():
    """Import `ray.rllib.algorithms.cql` on call and build its registry entry.

    Returns:
        A `(cql.CQL, cql.CQLConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.cql as cql_module

    algo_cls = cql_module.CQL
    default_config = cql_module.CQLConfig().to_dict()
    return algo_cls, default_config
|
2020-12-30 07:11:57 -08:00
|
|
|
|
|
|
|
|
2022-06-08 02:42:02 -07:00
|
|
|
def _import_crr():
    """Import `crr` from `ray.rllib.algorithms` on call and build its entry.

    Returns:
        A `(crr.CRR, crr.CRRConfig().to_dict())` tuple.
    """
    from ray.rllib.algorithms import crr as crr_module

    algo_cls = crr_module.CRR
    default_config = crr_module.CRRConfig().to_dict()
    return algo_cls, default_config
|
2022-06-08 02:42:02 -07:00
|
|
|
|
|
|
|
|
2018-12-21 03:44:34 +09:00
|
|
|
def _import_ddpg():
    """Import `ray.rllib.algorithms.ddpg` on call and build its registry entry.

    Returns:
        A `(ddpg.DDPG, ddpg.DDPGConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.ddpg as ddpg_module

    algo_cls = ddpg_module.DDPG
    default_config = ddpg_module.DDPGConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_ddppo():
    """Import `ray.rllib.algorithms.ddppo` on call and build its entry.

    Returns:
        A `(ddppo.DDPPO, ddppo.DDPPOConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.ddppo as ddppo_module

    algo_cls = ddppo_module.DDPPO
    default_config = ddppo_module.DDPPOConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_dqn():
    """Import `ray.rllib.algorithms.dqn` on call and build its registry entry.

    Returns:
        A `(dqn.DQN, dqn.DQNConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.dqn as dqn_module

    algo_cls = dqn_module.DQN
    default_config = dqn_module.DQNConfig().to_dict()
    return algo_cls, default_config
|
2019-04-26 17:49:53 -07:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_dreamer():
    """Import `ray.rllib.algorithms.dreamer` on call and build its entry.

    Returns:
        A `(dreamer.Dreamer, dreamer.DreamerConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.dreamer as dreamer_module

    algo_cls = dreamer_module.Dreamer
    default_config = dreamer_module.DreamerConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
|
|
|
def _import_es():
    """Import `ray.rllib.algorithms.es` on call and build its registry entry.

    Returns:
        A `(es.ES, es.ESConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.es as es_module

    algo_cls = es_module.ES
    default_config = es_module.ESConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_impala():
    """Import `ray.rllib.algorithms.impala` on call and build its entry.

    Returns:
        A `(impala.Impala, impala.ImpalaConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.impala as impala_module

    algo_cls = impala_module.Impala
    default_config = impala_module.ImpalaConfig().to_dict()
    return algo_cls, default_config
|
2019-07-03 15:59:47 -07:00
|
|
|
|
|
|
|
|
2022-05-06 12:35:21 +02:00
|
|
|
def _import_maddpg():
    """Import `ray.rllib.algorithms.maddpg` on call and build its entry.

    Returns:
        A `(maddpg.MADDPG, maddpg.MADDPGConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.maddpg as maddpg_module

    algo_cls = maddpg_module.MADDPG
    default_config = maddpg_module.MADDPGConfig().to_dict()
    return algo_cls, default_config
|
2022-05-06 12:35:21 +02:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_maml():
    """Import `ray.rllib.algorithms.maml` on call and build its registry entry.

    Returns:
        A `(maml.MAML, maml.MAMLConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.maml as maml_module

    algo_cls = maml_module.MAML
    default_config = maml_module.MAMLConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_marwil():
    """Import `ray.rllib.algorithms.marwil` on call and build its entry.

    Returns:
        A `(marwil.MARWIL, marwil.MARWILConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.marwil as marwil_module

    algo_cls = marwil_module.MARWIL
    default_config = marwil_module.MARWILConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_mbmpo():
    """Import `ray.rllib.algorithms.mbmpo` on call and build its entry.

    Returns:
        A `(mbmpo.MBMPO, mbmpo.MBMPOConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.mbmpo as mbmpo_module

    algo_cls = mbmpo_module.MBMPO
    default_config = mbmpo_module.MBMPOConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
|
|
|
def _import_pg():
    """Import `ray.rllib.algorithms.pg` on call and build its registry entry.

    Returns:
        A `(pg.PG, pg.PGConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.pg as pg_module

    algo_cls = pg_module.PG
    default_config = pg_module.PGConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_ppo():
    """Import `ray.rllib.algorithms.ppo` on call and build its registry entry.

    Returns:
        A `(ppo.PPO, ppo.PPOConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.ppo as ppo_module

    algo_cls = ppo_module.PPO
    default_config = ppo_module.PPOConfig().to_dict()
    return algo_cls, default_config
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_qmix():
    """Import `ray.rllib.algorithms.qmix` on call and build its registry entry.

    Returns:
        A `(qmix.QMix, qmix.QMixConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.qmix as qmix_module

    algo_cls = qmix_module.QMix
    default_config = qmix_module.QMixConfig().to_dict()
    return algo_cls, default_config
|
2019-01-17 11:00:43 +08:00
|
|
|
|
|
|
|
|
2021-02-25 12:18:11 +01:00
|
|
|
def _import_r2d2():
    """Import `ray.rllib.algorithms.r2d2` on call and build its registry entry.

    Returns:
        A `(r2d2.R2D2, r2d2.R2D2Config().to_dict())` tuple.
    """
    import ray.rllib.algorithms.r2d2 as r2d2_module

    algo_cls = r2d2_module.R2D2
    default_config = r2d2_module.R2D2Config().to_dict()
    return algo_cls, default_config
|
2021-02-25 12:18:11 +01:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_sac():
    """Import `ray.rllib.algorithms.sac` on call and build its registry entry.

    Returns:
        A `(sac.SAC, sac.SACConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.sac as sac_module

    algo_cls = sac_module.SAC
    default_config = sac_module.SACConfig().to_dict()
    return algo_cls, default_config
|
2020-06-23 09:48:23 -07:00
|
|
|
|
|
|
|
|
2021-07-25 16:04:52 +02:00
|
|
|
def _import_rnnsac():
    """Import `sac` from `ray.rllib.algorithms` on call; build the RNNSAC entry.

    Returns:
        A `(sac.RNNSAC, sac.RNNSACConfig().to_dict())` tuple.
    """
    from ray.rllib.algorithms import sac as sac_module

    algo_cls = sac_module.RNNSAC
    default_config = sac_module.RNNSACConfig().to_dict()
    return algo_cls, default_config
|
2021-07-25 16:04:52 +02:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_simple_q():
    """Import `ray.rllib.algorithms.simple_q` on call and build its entry.

    Returns:
        A `(simple_q.SimpleQ, simple_q.SimpleQConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.simple_q as simple_q_module

    algo_cls = simple_q_module.SimpleQ
    default_config = simple_q_module.SimpleQConfig().to_dict()
    return algo_cls, default_config
|
2020-08-02 09:12:09 -07:00
|
|
|
|
|
|
|
|
2020-11-03 00:52:04 -08:00
|
|
|
def _import_slate_q():
    """Import `ray.rllib.algorithms.slateq` on call and build its entry.

    Returns:
        A `(slateq.SlateQ, slateq.SlateQConfig().to_dict())` tuple.
    """
    import ray.rllib.algorithms.slateq as slateq_module

    algo_cls = slateq_module.SlateQ
    default_config = slateq_module.SlateQConfig().to_dict()
    return algo_cls, default_config
|
2020-11-03 00:52:04 -08:00
|
|
|
|
|
|
|
|
2020-09-09 17:33:21 +02:00
|
|
|
def _import_td3():
    """Import `ray.rllib.algorithms.td3` on call and build its registry entry.

    Returns:
        A `(td3.TD3, td3.TD3Config().to_dict())` tuple.
    """
    import ray.rllib.algorithms.td3 as td3_module

    algo_cls = td3_module.TD3
    default_config = td3_module.TD3Config().to_dict()
    return algo_cls, default_config
|
2020-08-26 04:24:05 -07:00
|
|
|
|
|
|
|
|
2018-12-21 03:44:34 +09:00
|
|
|
# Maps each registered algorithm name to a zero-argument loader. Calling the
# loader performs the import and returns a `(class, config)` tuple, so no
# algorithm module is imported until its entry is actually used.
# Insertion order is kept as-is, since it determines iteration order.
ALGORITHMS = {
    "A2C": _import_a2c,
    "A3C": _import_a3c,
    "AlphaZero": _import_alpha_zero,
    "APEX": _import_apex,
    "APEX_DDPG": _import_apex_ddpg,
    "ARS": _import_ars,
    "BanditLinTS": _import_bandit_lints,
    "BanditLinUCB": _import_bandit_linucb,
    "BC": _import_bc,
    "CQL": _import_cql,
    "CRR": _import_crr,
    "ES": _import_es,
    "DDPG": _import_ddpg,
    "DDPPO": _import_ddppo,
    "DQN": _import_dqn,
    "Dreamer": _import_dreamer,
    "IMPALA": _import_impala,
    "APPO": _import_appo,
    "AlphaStar": _import_alpha_star,
    "MADDPG": _import_maddpg,
    "MAML": _import_maml,
    "MARWIL": _import_marwil,
    "MBMPO": _import_mbmpo,
    "PG": _import_pg,
    "PPO": _import_ppo,
    "QMIX": _import_qmix,
    "R2D2": _import_r2d2,
    "RNNSAC": _import_rnnsac,
    "SAC": _import_sac,
    "SimpleQ": _import_simple_q,
    "SlateQ": _import_slate_q,
    "TD3": _import_td3,
}
|
|
|
|
|
|
|
|
|
2022-06-11 15:10:39 +02:00
|
|
|
def get_algorithm_class(alg: str, return_config=False) -> type:
    """Resolve the algorithm class registered under the name `alg`.

    Args:
        alg: Name of the algorithm to look up.
        return_config: If True, return a `(class, config)` tuple instead of
            only the class.

    Returns:
        Whatever `_get_algorithm_class` returns for `alg`. If that lookup
        raises an ImportError, the class produced by
        `_algorithm_import_failed` (built from the formatted traceback) is
        returned instead, paired with its `get_default_config()` result when
        `return_config` is True.
    """
    try:
        resolved = _get_algorithm_class(alg, return_config=return_config)
    except ImportError:
        from ray.rllib.algorithms.mock import _algorithm_import_failed

        # Must run inside the handler: format_exc() reads the exception
        # that is currently being handled.
        failed_cls = _algorithm_import_failed(traceback.format_exc())
        failed_config = failed_cls.get_default_config()
        return (failed_cls, failed_config) if return_config else failed_cls
    else:
        return resolved
|
|
|
|
|
2018-12-21 03:44:34 +09:00
|
|
|
|
2022-06-11 15:10:39 +02:00
|
|
|
# Backward-compatibility alias, kept so that callers still using the name
# `get_trainer_class` continue to work.
|
|
|
|
get_trainer_class = get_algorithm_class
|
|
|
|
|
|
|
|
|
|
|
|
def _get_algorithm_class(alg: str, return_config=False) -> type:
    """Resolve `alg` to its class and config; import errors are not handled.

    The name is checked against `ALGORITHMS` first, then
    `CONTRIBUTED_ALGORITHMS`, then the special names "script", "__fake",
    "__sigmoid_fake_data" and "__parameter_tuning".

    Args:
        alg: Name of the algorithm to look up.
        return_config: If True, return a `(class, config)` tuple instead of
            only the class.

    Returns:
        The resolved class, or a `(class, config)` tuple when `return_config`
        is True.

    Raises:
        Exception: If `alg` matches none of the known names.
    """
    if alg in ALGORITHMS:
        found_cls, found_config = ALGORITHMS[alg]()
    elif alg in CONTRIBUTED_ALGORITHMS:
        found_cls, found_config = CONTRIBUTED_ALGORITHMS[alg]()
    elif alg == "script":
        from ray.tune import script_runner

        found_cls = script_runner.ScriptRunner
        found_config = {}
    elif alg == "__fake":
        from ray.rllib.algorithms.mock import _MockTrainer

        found_cls = _MockTrainer
        found_config = _MockTrainer.get_default_config()
    elif alg == "__sigmoid_fake_data":
        from ray.rllib.algorithms.mock import _SigmoidFakeData

        found_cls = _SigmoidFakeData
        found_config = _SigmoidFakeData.get_default_config()
    elif alg == "__parameter_tuning":
        from ray.rllib.algorithms.mock import _ParameterTuningTrainer

        found_cls = _ParameterTuningTrainer
        found_config = _ParameterTuningTrainer.get_default_config()
    else:
        raise Exception("Unknown algorithm {}.".format(alg))

    return (found_cls, found_config) if return_config else found_cls
|