import logging

# Note: do not introduce unnecessary library dependencies here, e.g. gym.
# This file is imported from the tune module in order to register RLlib agents.
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy

from ray.tune.registry import register_trainable


def _setup_logger():
    """Attach a formatted stream handler to the "ray.rllib" logger.

    Each record is rendered as timestamp, level, source file and line number,
    followed by the message. Propagation to ancestor loggers is switched off,
    so records logged under "ray.rllib" are handled only by the handlers
    attached to that logger itself.
    """
    record_format = logging.Formatter(
        "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(record_format)

    rllib_logger = logging.getLogger("ray.rllib")
    # Keep records from also being passed up to ancestor loggers' handlers.
    rllib_logger.propagate = False
    rllib_logger.addHandler(stream_handler)


def _register_all():
    """Register the RLlib trainables with Tune.

    Every key of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS``, plus three
    built-in "__"-prefixed names, is registered under its own name with the
    class returned by ``get_agent_class``. In addition, each contributed key
    is registered a second time without its "contrib/" prefix, mapped to a
    placeholder ``Trainer`` whose ``_setup`` raises ``NameError`` naming the
    "contrib/..." entry to run instead.
    """
    # Imported inside the function rather than at module level.
    from ray.rllib.agents.trainer import Trainer, with_common_config
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    def _see_contrib(name):
        """Build a placeholder trainer class that redirects to contrib/."""

        class _SeeContrib(Trainer):
            _name = "SeeContrib"
            _default_config = with_common_config({})

            def _setup(self, config):
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Core algorithms first, then contributed ones, then the built-in extras.
    trainable_names = list(ALGORITHMS.keys())
    trainable_names.extend(CONTRIBUTED_ALGORITHMS.keys())
    trainable_names.extend(
        ["__fake", "__sigmoid_fake_data", "__parameter_tuning"])
    for trainable_name in trainable_names:
        agent_cls = get_agent_class(trainable_name)
        register_trainable(trainable_name, agent_cls)

    # Register each contributed algorithm's un-prefixed name as well, so that
    # asking for it without "contrib/" yields a helpful error message.
    for contrib_key in list(CONTRIBUTED_ALGORITHMS.keys()):
        assert contrib_key.startswith("contrib/")
        pieces = contrib_key.split("/", 1)
        short_name = pieces[1]
        register_trainable(short_name, _see_contrib(short_name))


# Import-time side effects: configure the "ray.rllib" logger, then register
# the RLlib trainables with Tune.
_setup_logger()
_register_all()

# Names exported by `from <this module> import *`. Every entry must be bound
# at module level; an unbound entry makes the star-import raise
# AttributeError. "PolicyEvaluator" is deliberately not listed because nothing
# in this module imports or defines that name.
__all__ = [
    "Policy",
    "PolicyGraph",
    "TFPolicy",
    "TFPolicyGraph",
    "RolloutWorker",
    "SampleBatch",
    "BaseEnv",
    "MultiAgentEnv",
    "VectorEnv",
    "ExternalEnv",
]