2018-10-21 23:43:57 -07:00
|
|
|
import logging
|
|
|
|
|
2022-06-17 08:41:18 +02:00
|
|
|
from ray._private.usage import usage_lib
|
|
|
|
|
2018-01-24 16:55:17 -08:00
|
|
|
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
|
|
|
|
# This file is imported from the tune module in order to register RLlib agents.
|
2019-01-23 21:27:26 -08:00
|
|
|
from ray.rllib.env.base_env import BaseEnv
|
2019-08-06 19:22:06 -04:00
|
|
|
from ray.rllib.env.external_env import ExternalEnv
|
2018-07-01 00:05:08 -07:00
|
|
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
|
|
|
from ray.rllib.env.vector_env import VectorEnv
|
2019-08-06 19:22:06 -04:00
|
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
2019-05-20 16:46:05 -07:00
|
|
|
from ray.rllib.policy.policy import Policy
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2019-08-06 19:22:06 -04:00
|
|
|
from ray.rllib.policy.tf_policy import TFPolicy
|
2020-08-20 17:05:57 +02:00
|
|
|
from ray.rllib.policy.torch_policy import TorchPolicy
|
2019-08-06 19:22:06 -04:00
|
|
|
from ray.tune.registry import register_trainable
|
2018-06-09 00:21:35 -07:00
|
|
|
|
2017-11-20 17:52:43 -08:00
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
def _setup_logger():
|
|
|
|
logger = logging.getLogger("ray.rllib")
|
|
|
|
handler = logging.StreamHandler()
|
|
|
|
handler.setFormatter(
|
|
|
|
logging.Formatter(
|
|
|
|
"%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
|
|
|
|
)
|
2022-01-29 18:41:57 -08:00
|
|
|
)
|
2018-10-21 23:43:57 -07:00
|
|
|
logger.addHandler(handler)
|
|
|
|
logger.propagate = False
|
|
|
|
|
|
|
|
|
2017-11-20 17:52:43 -08:00
|
|
|
def _register_all():
    """Register RLlib algorithms with Tune under their registry names.

    Every key of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS`` is registered,
    followed by three extra internal names, each mapped to the class returned
    by ``get_algorithm_class``.

    Each contributed algorithm is also registered under its short name (the
    key with the ``contrib/`` prefix removed), unless a built-in algorithm
    already uses that name. The short-name entry is a placeholder whose
    ``setup`` raises ``NameError`` pointing the user to the prefixed name.
    """
    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    def _see_contrib(name):
        """Build a placeholder Algorithm class redirecting users to contrib/."""

        class _SeeContrib(Algorithm):
            def setup(self, config):
                raise NameError("Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Extra internal names resolved through get_algorithm_class as well
    # (presumably test/debug entries -- confirm in the algorithm registry).
    internal_names = ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    names_to_register = [*ALGORITHMS, *CONTRIBUTED_ALGORITHMS, *internal_names]
    for algo_name in names_to_register:
        register_trainable(algo_name, get_algorithm_class(algo_name))

    # Register the un-prefixed aliases so that using the short name of a
    # contributed algorithm yields the explanatory NameError above.
    contrib_keys = list(CONTRIBUTED_ALGORITHMS)
    for contrib_key in contrib_keys:
        assert contrib_key.startswith("contrib/")
        short_name = contrib_key.split("/", 1)[1]
        if short_name in ALGORITHMS:
            # A built-in algorithm owns this name; leave it alone.
            continue
        register_trainable(short_name, _see_contrib(short_name))
|
2019-08-06 19:22:06 -04:00
|
|
|
|
2017-11-20 17:52:43 -08:00
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
# Import-time side effects: install RLlib's console log handler, then
# report "rllib" to Ray's usage library.
_setup_logger()
usage_lib.record_library_usage("rllib")

# Names re-exported as the public API of ``ray.rllib``.
__all__ = [
    # Policy classes.
    "Policy",
    "TFPolicy",
    "TorchPolicy",
    # Sample collection.
    "RolloutWorker",
    "SampleBatch",
    # Environment interfaces.
    "BaseEnv",
    "MultiAgentEnv",
    "VectorEnv",
    "ExternalEnv",
]
|