"""Registry of algorithm names for `rllib train --run=`"""

import traceback
from typing import Tuple, Union

from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS


# Each `_import_*` helper below imports its algorithm's package inside the
# function body, so the (potentially heavy) import only happens when that
# algorithm is actually requested. Every helper returns a
# `(trainer_class, default_config)` pair.


def _import_a2c():
    from ray.rllib.agents import a3c

    return a3c.A2CTrainer, a3c.a2c.A2C_DEFAULT_CONFIG


def _import_a3c():
    from ray.rllib.agents import a3c

    return a3c.A3CTrainer, a3c.DEFAULT_CONFIG


def _import_alpha_star():
    from ray.rllib.agents.alpha_star.alpha_star import (
        AlphaStarTrainer,
        DEFAULT_CONFIG,
    )

    return AlphaStarTrainer, DEFAULT_CONFIG


def _import_apex():
    from ray.rllib.agents import dqn

    return dqn.ApexTrainer, dqn.apex.APEX_DEFAULT_CONFIG


def _import_apex_ddpg():
    from ray.rllib.agents import ddpg

    return ddpg.ApexDDPGTrainer, ddpg.apex.APEX_DDPG_DEFAULT_CONFIG


def _import_appo():
    from ray.rllib.agents import ppo

    return ppo.APPOTrainer, ppo.appo.DEFAULT_CONFIG


def _import_ars():
    from ray.rllib.agents import ars

    return ars.ARSTrainer, ars.DEFAULT_CONFIG


def _import_bandit_lints():
    from ray.rllib.agents.bandit.bandit import BanditLinTSTrainer

    return BanditLinTSTrainer, BanditLinTSTrainer.get_default_config()


def _import_bandit_linucb():
    from ray.rllib.agents.bandit.bandit import BanditLinUCBTrainer

    return BanditLinUCBTrainer, BanditLinUCBTrainer.get_default_config()


def _import_bc():
    from ray.rllib.agents import marwil

    # NOTE(review): this returns the same marwil.DEFAULT_CONFIG that MARWIL
    # uses; confirm whether a BC-specific default config was intended.
    return marwil.BCTrainer, marwil.DEFAULT_CONFIG


def _import_cql():
    from ray.rllib.agents import cql

    return cql.CQLTrainer, cql.CQL_DEFAULT_CONFIG


def _import_ddpg():
    from ray.rllib.agents import ddpg

    return ddpg.DDPGTrainer, ddpg.DEFAULT_CONFIG


def _import_ddppo():
    from ray.rllib.agents import ppo

    # NOTE(review): this returns plain PPO's DEFAULT_CONFIG, whereas APPO
    # uses its own ppo.appo.DEFAULT_CONFIG; confirm that a DDPPO-specific
    # default config was not intended here.
    return ppo.DDPPOTrainer, ppo.DEFAULT_CONFIG


def _import_dqn():
    from ray.rllib.agents import dqn

    return dqn.DQNTrainer, dqn.DEFAULT_CONFIG


def _import_dreamer():
    from ray.rllib.agents import dreamer

    return dreamer.DREAMERTrainer, dreamer.DEFAULT_CONFIG


def _import_es():
    from ray.rllib.agents import es

    return es.ESTrainer, es.DEFAULT_CONFIG


def _import_impala():
    from ray.rllib.agents import impala

    return impala.ImpalaTrainer, impala.DEFAULT_CONFIG


def _import_maml():
    from ray.rllib.agents import maml

    return maml.MAMLTrainer, maml.DEFAULT_CONFIG


def _import_marwil():
    from ray.rllib.agents import marwil

    return marwil.MARWILTrainer, marwil.DEFAULT_CONFIG


def _import_mbmpo():
    from ray.rllib.agents import mbmpo

    return mbmpo.MBMPOTrainer, mbmpo.DEFAULT_CONFIG


def _import_pg():
    from ray.rllib.agents import pg

    return pg.PGTrainer, pg.DEFAULT_CONFIG


def _import_ppo():
    from ray.rllib.agents import ppo

    return ppo.PPOTrainer, ppo.DEFAULT_CONFIG


def _import_qmix():
    from ray.rllib.agents import qmix

    return qmix.QMixTrainer, qmix.DEFAULT_CONFIG


def _import_r2d2():
    from ray.rllib.agents import dqn

    return dqn.R2D2Trainer, dqn.R2D2_DEFAULT_CONFIG


def _import_sac():
    from ray.rllib.agents import sac

    return sac.SACTrainer, sac.DEFAULT_CONFIG


def _import_rnnsac():
    from ray.rllib.agents import sac

    return sac.RNNSACTrainer, sac.RNNSAC_DEFAULT_CONFIG


def _import_simple_q():
    from ray.rllib.agents import dqn

    return dqn.SimpleQTrainer, dqn.simple_q.DEFAULT_CONFIG


def _import_slate_q():
    from ray.rllib.agents import slateq

    return slateq.SlateQTrainer, slateq.DEFAULT_CONFIG


def _import_td3():
    from ray.rllib.agents import ddpg

    return ddpg.TD3Trainer, ddpg.td3.TD3_DEFAULT_CONFIG


# Maps each `--run=` algorithm name to its lazy import helper. The helper is
# only called (and its package only imported) when the name is resolved.
ALGORITHMS = {
    "A2C": _import_a2c,
    "A3C": _import_a3c,
    "APPO": _import_appo,
    "APEX": _import_apex,
    "APEX_DDPG": _import_apex_ddpg,
    "ARS": _import_ars,
    "BanditLinTS": _import_bandit_lints,
    "BanditLinUCB": _import_bandit_linucb,
    "BC": _import_bc,
    "CQL": _import_cql,
    "ES": _import_es,
    "DDPG": _import_ddpg,
    "DDPPO": _import_ddppo,
    "DQN": _import_dqn,
    "DREAMER": _import_dreamer,
    "IMPALA": _import_impala,
    "MAML": _import_maml,
    "MARWIL": _import_marwil,
    "MBMPO": _import_mbmpo,
    "PG": _import_pg,
    "PPO": _import_ppo,
    "QMIX": _import_qmix,
    "R2D2": _import_r2d2,
    "RNNSAC": _import_rnnsac,
    "SAC": _import_sac,
    "SimpleQ": _import_simple_q,
    "SlateQ": _import_slate_q,
    "TD3": _import_td3,
    "AlphaStar": _import_alpha_star,
}


def get_trainer_class(
    alg: str, return_config: bool = False
) -> Union[type, Tuple[type, dict]]:
    """Returns the class of a known Trainer given its name.

    Args:
        alg: Algorithm name. Either a key of ``ALGORITHMS`` or
            ``CONTRIBUTED_ALGORITHMS``, or one of the special names
            ``"script"``, ``"__fake"``, ``"__sigmoid_fake_data"`` or
            ``"__parameter_tuning"``.
        return_config: If True, return ``(trainer_class, default_config)``
            instead of only the class.

    Returns:
        The Trainer class, or a ``(class, default_config)`` tuple when
        ``return_config`` is True. If resolving ``alg`` raises ImportError
        (e.g. an optional dependency of that algorithm is missing), the
        class returned is instead the one built by passing the formatted
        traceback to ``ray.rllib.agents.mock._trainer_import_failed``.

    Raises:
        ValueError: If ``alg`` is not a known algorithm name.
    """
    try:
        return _get_trainer_class(alg, return_config=return_config)
    except ImportError:
        from ray.rllib.agents.mock import _trainer_import_failed

        class_ = _trainer_import_failed(traceback.format_exc())
        # Only build the default config when the caller asked for it.
        if return_config:
            return class_, class_.get_default_config()
        return class_


def _get_trainer_class(
    alg: str, return_config: bool = False
) -> Union[type, Tuple[type, dict]]:
    """Resolves ``alg`` to its Trainer class (and optionally its config).

    Same contract as ``get_trainer_class``, except that any ImportError
    raised while importing the algorithm propagates to the caller.
    """
    if alg in ALGORITHMS:
        class_, config = ALGORITHMS[alg]()
    elif alg in CONTRIBUTED_ALGORITHMS:
        class_, config = CONTRIBUTED_ALGORITHMS[alg]()
    elif alg == "script":
        from ray.tune import script_runner

        class_, config = script_runner.ScriptRunner, {}
    elif alg == "__fake":
        from ray.rllib.agents.mock import _MockTrainer

        class_, config = _MockTrainer, _MockTrainer.get_default_config()
    elif alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData

        class_, config = (
            _SigmoidFakeData,
            _SigmoidFakeData.get_default_config(),
        )
    elif alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningTrainer

        class_, config = (
            _ParameterTuningTrainer,
            _ParameterTuningTrainer.get_default_config(),
        )
    else:
        # ValueError is a subclass of Exception, so existing callers that
        # catch Exception keep working; the message is unchanged.
        raise ValueError("Unknown algorithm {}.".format(alg))

    if return_config:
        return class_, config
    return class_