2019-04-07 00:36:18 -07:00
|
|
|
import logging
|
2019-05-10 20:36:18 -07:00
|
|
|
import os
|
2019-04-07 00:36:18 -07:00
|
|
|
|
2017-12-30 00:24:54 -08:00
|
|
|
from ray.rllib.utils.filter_manager import FilterManager
|
2018-07-01 00:05:08 -07:00
|
|
|
from ray.rllib.utils.filter import Filter
|
|
|
|
from ray.rllib.utils.policy_client import PolicyClient
|
|
|
|
from ray.rllib.utils.policy_server import PolicyServer
|
2018-12-26 03:07:11 -08:00
|
|
|
from ray.tune.util import merge_dicts, deep_update
|
2017-12-30 00:24:54 -08:00
|
|
|
|
2019-04-07 00:36:18 -07:00
|
|
|
# Module-level logger, named after this module, used for deprecation and
# import-skipping messages below.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
2019-05-20 16:46:05 -07:00
|
|
|
def renamed_class(cls, old_name):
    """Helper class for renaming classes with a warning.

    Returns a subclass of ``cls`` whose constructor logs a deprecation
    message naming ``old_name`` and the new fully-qualified name
    (``cls.__module__ + "." + cls.__name__``), then delegates to
    ``cls.__init__`` with the same arguments.

    Args:
        cls: The class under its new name.
        old_name (str): The former name, shown in the warning.

    Returns:
        type: A subclass of ``cls`` reporting ``cls``'s name/qualname.
    """

    class DeprecationWrapper(cls):
        # note: **kw not supported for ray.remote classes
        def __init__(self, *args, **kw):
            new_name = cls.__module__ + "." + cls.__name__
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                "DeprecationWarning: %s has been renamed to %s. "
                "This will raise an error in the future.", old_name, new_name)
            cls.__init__(self, *args, **kw)

    # Make the wrapper look like the wrapped class (name, repr, qualname).
    DeprecationWrapper.__name__ = cls.__name__
    DeprecationWrapper.__qualname__ = getattr(cls, "__qualname__",
                                              cls.__name__)

    return DeprecationWrapper
|
|
|
|
|
|
|
|
|
2019-06-07 16:45:36 -07:00
|
|
|
def add_mixins(base, mixins):
    """Returns a new class with mixins applied in priority order.

    Earlier entries of ``mixins`` take precedence in the resulting MRO over
    later ones, and every mixin takes precedence over ``base``. When
    ``mixins`` is ``None`` or empty, ``base`` itself is returned.

    Args:
        base: The class to extend.
        mixins: Iterable of mixin classes (or None).

    Returns:
        type: ``base`` or a new subclass layering the mixins on top of it.
    """
    # Wrap from the last mixin towards the first, so the first mixin ends up
    # on the outermost class (highest MRO priority).
    for mixin in reversed(list(mixins or [])):

        class new_base(mixin, base):
            pass

        base = new_base
    return base
|
|
|
|
|
|
|
|
|
2019-05-20 16:46:05 -07:00
|
|
|
def renamed_agent(cls):
    """Helper class for renaming Agent => Trainer with a warning.

    Returns a subclass of ``cls`` whose constructor logs a deprecation
    message (old name derived by replacing "Trainer" with "Agent" in
    ``cls.__name__``), then calls ``cls.__init__(config, env,
    logger_creator)``.

    Args:
        cls: The Trainer class under its new name.

    Returns:
        type: A subclass of ``cls`` reporting ``cls``'s name/qualname.
    """

    class DeprecationWrapper(cls):
        def __init__(self, config=None, env=None, logger_creator=None):
            old_name = cls.__name__.replace("Trainer", "Agent")
            new_name = cls.__module__ + "." + cls.__name__
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                "DeprecationWarning: %s has been renamed to %s. "
                "This will raise an error in the future.", old_name, new_name)
            cls.__init__(self, config, env, logger_creator)

    # Make the wrapper look like the wrapped class (name, repr, qualname).
    DeprecationWrapper.__name__ = cls.__name__
    DeprecationWrapper.__qualname__ = getattr(cls, "__qualname__",
                                              cls.__name__)

    return DeprecationWrapper
|
|
|
|
|
|
|
|
|
2019-05-10 20:36:18 -07:00
|
|
|
def try_import_tf():
    """Import TensorFlow, preferring the ``tensorflow.compat.v1`` API.

    Returns:
        ``tensorflow.compat.v1`` (TF logging set to ERROR and v2 behavior
        disabled) when importable; otherwise the plain ``tensorflow``
        module; ``None`` if neither can be imported or if the
        ``RLLIB_TEST_NO_TF_IMPORT`` environment variable is set.
    """
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes")
        return None

    # Quiet TF's C++ logging unless the user has already chosen a level.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

    try:
        import tensorflow.compat.v1 as tf
        tf.logging.set_verbosity(tf.logging.ERROR)
        tf.disable_v2_behavior()
        return tf
    except ImportError:
        pass

    # Fallback: TF builds without the compat.v1 module.
    try:
        import tensorflow as tf
    except ImportError:
        return None
    return tf
|
2019-05-10 20:36:18 -07:00
|
|
|
|
|
|
|
|
2019-08-01 23:37:36 -07:00
|
|
|
def try_import_tfp():
    """Import TensorFlow Probability if available.

    Returns:
        The ``tensorflow_probability`` module, or ``None`` if it cannot be
        imported or the ``RLLIB_TEST_NO_TF_IMPORT`` environment variable
        is set.
    """
    if "RLLIB_TEST_NO_TF_IMPORT" not in os.environ:
        try:
            import tensorflow_probability as tfp
        except ImportError:
            tfp = None
        return tfp

    logger.warning(
        "Not importing TensorFlow Probability for test purposes.")
    return None
|
|
|
|
|
|
|
|
|
2018-12-26 03:07:11 -08:00
|
|
|
# Public API of this package; keep in sync with the helpers defined above.
__all__ = [
    "Filter",
    "FilterManager",
    "PolicyClient",
    "PolicyServer",
    "merge_dicts",
    "deep_update",
    "renamed_class",
    "renamed_agent",
    "add_mixins",
    "try_import_tf",
    "try_import_tfp",
]
|