mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00

* Remove all __future__ imports from RLlib. * Remove (object) again from tf_run_builder.py::TFRunBuilder. * Fix 2xLINT warnings. * Fix broken appo_policy import (must be appo_tf_policy) * Remove future imports from all other ray files (not just RLlib). * Remove future imports from all other ray files (not just RLlib). * Remove future import blocks that contain `unicode_literals` as well. Revert appo_tf_policy.py to appo_policy.py (belongs to another PR). * Add two empty lines before Schedule class. * Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
import logging
|
|
|
|
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
|
|
# This file is imported from the tune module in order to register RLlib agents.
|
|
from ray.rllib.env.base_env import BaseEnv
|
|
from ray.rllib.env.external_env import ExternalEnv
|
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
|
from ray.rllib.env.vector_env import VectorEnv
|
|
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
|
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.policy.tf_policy import TFPolicy
|
|
|
|
from ray.tune.registry import register_trainable
|
|
|
|
|
|
def _setup_logger():
|
|
logger = logging.getLogger("ray.rllib")
|
|
handler = logging.StreamHandler()
|
|
handler.setFormatter(
|
|
logging.Formatter(
|
|
"%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
|
|
))
|
|
logger.addHandler(handler)
|
|
logger.propagate = False
|
|
|
|
|
|
def _register_all():
    """Register every RLlib algorithm with Tune's trainable registry.

    Every key of ``ALGORITHMS`` and ``CONTRIBUTED_ALGORITHMS``, plus the three
    extra names "__fake", "__sigmoid_fake_data" and "__parameter_tuning", is
    registered under its own name using the class returned by
    ``get_agent_class``.

    Afterwards, each contributed key "contrib/<name>" is additionally
    registered under the bare "<name>" as a placeholder trainer whose
    ``_setup`` raises NameError telling the user to run "contrib/<name>".
    """
    # NOTE(review): these imports are function-local, presumably to defer
    # loading the agent modules until registration actually runs.
    from ray.rllib.agents.trainer import Trainer, with_common_config
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    extra_names = ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    names_to_register = (list(ALGORITHMS.keys()) +
                         list(CONTRIBUTED_ALGORITHMS.keys()) + extra_names)
    for algo_name in names_to_register:
        register_trainable(algo_name, get_agent_class(algo_name))

    def _see_contrib(name):
        """Build a stand-in trainer class pointing the user at contrib/<name>."""

        class _SeeContrib(Trainer):
            _name = "SeeContrib"
            _default_config = with_common_config({})

            def _setup(self, config):
                # Any attempt to set up the bare alias fails with a hint
                # naming the correct, "contrib/"-prefixed algorithm.
                raise NameError(
                    "Please run `contrib/{}` instead.".format(name))

        return _SeeContrib

    # Register each contributed algorithm a second time under its name with
    # the "contrib/" prefix stripped, so using the bare name yields the
    # explanatory error above.
    contrib_keys = list(CONTRIBUTED_ALGORITHMS.keys())
    for contrib_key in contrib_keys:
        assert contrib_key.startswith("contrib/")
        short_name = contrib_key.split("/", 1)[1]
        register_trainable(short_name, _see_contrib(short_name))
|
|
|
|
|
|
# Module import side effects: configure the "ray.rllib" logger, then register
# all RLlib algorithms with Tune (this module is imported from Tune for that
# purpose).
_setup_logger()
_register_all()

# Public names exported by ``from ray.rllib import *``.
# Every entry must be bound in this module. A listed name that is never
# imported makes the star-import fail with AttributeError, which is why the
# stale "PolicyEvaluator" entry (never imported here) has been removed.
__all__ = [
    "Policy",
    "PolicyGraph",
    "TFPolicy",
    "TFPolicyGraph",
    "RolloutWorker",
    "SampleBatch",
    "BaseEnv",
    "MultiAgentEnv",
    "VectorEnv",
    "ExternalEnv",
]
|