mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
154 lines
3.3 KiB
Python
154 lines
3.3 KiB
Python
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""
|
|
|
|
import traceback
|
|
|
|
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
|
|
|
|
|
def _import_sac():
    """Return ``SACTrainer`` from ``ray.rllib.agents.sac``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import sac as sac_module

    return sac_module.SACTrainer
|
|
|
|
|
|
def _import_appo():
    """Return ``APPOTrainer`` from ``ray.rllib.agents.ppo``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ppo as ppo_module

    return ppo_module.APPOTrainer
|
|
|
|
|
|
def _import_ddppo():
    """Return ``DDPPOTrainer`` from ``ray.rllib.agents.ppo``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ppo as ppo_module

    return ppo_module.DDPPOTrainer
|
|
|
|
|
|
def _import_qmix():
    """Return ``QMixTrainer`` from ``ray.rllib.agents.qmix``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import qmix as qmix_module

    return qmix_module.QMixTrainer
|
|
|
|
|
|
def _import_apex_qmix():
    """Return ``ApexQMixTrainer`` from ``ray.rllib.agents.qmix``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import qmix as qmix_module

    return qmix_module.ApexQMixTrainer
|
|
|
|
|
|
def _import_ddpg():
    """Return ``DDPGTrainer`` from ``ray.rllib.agents.ddpg``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ddpg as ddpg_module

    return ddpg_module.DDPGTrainer
|
|
|
|
|
|
def _import_apex_ddpg():
    """Return ``ApexDDPGTrainer`` from ``ray.rllib.agents.ddpg``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ddpg as ddpg_module

    return ddpg_module.ApexDDPGTrainer
|
|
|
|
|
|
def _import_td3():
    """Return ``TD3Trainer`` from ``ray.rllib.agents.ddpg``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ddpg as ddpg_module

    return ddpg_module.TD3Trainer
|
|
|
|
|
|
def _import_ppo():
    """Return ``PPOTrainer`` from ``ray.rllib.agents.ppo``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ppo as ppo_module

    return ppo_module.PPOTrainer
|
|
|
|
|
|
def _import_es():
    """Return ``ESTrainer`` from ``ray.rllib.agents.es``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import es as es_module

    return es_module.ESTrainer
|
|
|
|
|
|
def _import_ars():
    """Return ``ARSTrainer`` from ``ray.rllib.agents.ars``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import ars as ars_module

    return ars_module.ARSTrainer
|
|
|
|
|
|
def _import_dqn():
    """Return ``DQNTrainer`` from ``ray.rllib.agents.dqn``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import dqn as dqn_module

    return dqn_module.DQNTrainer
|
|
|
|
|
|
def _import_simple_q():
    """Return ``SimpleQTrainer`` from ``ray.rllib.agents.dqn``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import dqn as dqn_module

    return dqn_module.SimpleQTrainer
|
|
|
|
|
|
def _import_apex():
    """Return ``ApexTrainer`` from ``ray.rllib.agents.dqn``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import dqn as dqn_module

    return dqn_module.ApexTrainer
|
|
|
|
|
|
def _import_a3c():
    """Return ``A3CTrainer`` from ``ray.rllib.agents.a3c``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import a3c as a3c_module

    return a3c_module.A3CTrainer
|
|
|
|
|
|
def _import_a2c():
    """Return ``A2CTrainer`` from ``ray.rllib.agents.a3c``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import a3c as a3c_module

    return a3c_module.A2CTrainer
|
|
|
|
|
|
def _import_pg():
    """Return ``PGTrainer`` from ``ray.rllib.agents.pg``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import pg as pg_module

    return pg_module.PGTrainer
|
|
|
|
|
|
def _import_impala():
    """Return ``ImpalaTrainer`` from ``ray.rllib.agents.impala``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import impala as impala_module

    return impala_module.ImpalaTrainer
|
|
|
|
|
|
def _import_marwil():
    """Return ``MARWILTrainer`` from ``ray.rllib.agents.marwil``.

    The import statement lives inside the function body, so it runs when
    this function is called rather than when this module is imported.
    """
    from ray.rllib.agents import marwil as marwil_module

    return marwil_module.MARWILTrainer
|
|
|
|
|
|
# Maps an algorithm name to a zero-argument loader function. Each loader
# performs its import when called and returns the matching trainer, so
# building this table does not itself import any algorithm package.
ALGORITHMS = {
    "SAC": _import_sac,
    "DDPG": _import_ddpg,
    "APEX_DDPG": _import_apex_ddpg,
    "TD3": _import_td3,
    "PPO": _import_ppo,
    "ES": _import_es,
    "ARS": _import_ars,
    "DQN": _import_dqn,
    "SimpleQ": _import_simple_q,
    "APEX": _import_apex,
    "A3C": _import_a3c,
    "A2C": _import_a2c,
    "PG": _import_pg,
    "IMPALA": _import_impala,
    "QMIX": _import_qmix,
    "APEX_QMIX": _import_apex_qmix,
    "APPO": _import_appo,
    "DDPPO": _import_ddppo,
    "MARWIL": _import_marwil,
}
|
|
|
|
|
|
def get_agent_class(alg):
    """Return the class of a known agent given its name.

    The lookup is delegated to ``_get_agent_class``. If that lookup raises
    ``ImportError``, the formatted traceback is handed to
    ``ray.rllib.agents.mock._agent_import_failed`` and its result is
    returned instead (presumably a placeholder standing in for the agent
    class -- confirm against the mock module).

    Args:
        alg: Name of the algorithm to look up.

    Returns:
        The value produced by ``_get_agent_class(alg)``, or the result of
        ``_agent_import_failed`` when an ``ImportError`` occurred.
    """
    try:
        agent_cls = _get_agent_class(alg)
    except ImportError:
        # Imported here so the mock module is only needed on failure.
        from ray.rllib.agents.mock import _agent_import_failed

        agent_cls = _agent_import_failed(traceback.format_exc())
    return agent_cls
|
|
|
|
|
|
def _get_agent_class(alg):
    """Resolve an algorithm name to its trainer.

    Names are checked in this order: the ``ALGORITHMS`` table, the
    ``CONTRIBUTED_ALGORITHMS`` table, then the special names ``"script"``,
    ``"__fake"``, ``"__sigmoid_fake_data"`` and ``"__parameter_tuning"``.

    Args:
        alg: Name of the algorithm to look up.

    Returns:
        The result of calling the registered loader for ``alg``, or the
        object imported for one of the special names.

    Raises:
        Exception: If ``alg`` matches none of the known names.
    """
    # Registry lookups: entries are loaders that must be called.
    if alg in ALGORITHMS:
        loader = ALGORITHMS[alg]
        return loader()

    if alg in CONTRIBUTED_ALGORITHMS:
        loader = CONTRIBUTED_ALGORITHMS[alg]
        return loader()

    # Special names, each with its own on-demand import.
    if alg == "script":
        from ray.tune import script_runner

        return script_runner.ScriptRunner

    if alg == "__fake":
        from ray.rllib.agents.mock import _MockTrainer

        return _MockTrainer

    if alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData

        return _SigmoidFakeData

    if alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningTrainer

        return _ParameterTuningTrainer

    raise Exception("Unknown algorithm {}.".format(alg))
|