[RLlib] Fix test_dependency_torch and fix custom logger support for RLlib. (#15120)

Sven Mika 2021-04-24 08:13:41 +02:00 committed by GitHub
parent 57c0bd9912
commit 354c960fff
4 changed files with 160 additions and 25 deletions

View file

@@ -506,19 +506,36 @@ class Trainer(Trainable):
         # Create a default logger creator if no logger_creator is specified
         if logger_creator is None:
             # Default logdir prefix containing the agent's name and the
             # env id.
             timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
             logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
                                               timestr)
+            if not os.path.exists(DEFAULT_RESULTS_DIR):
+                os.makedirs(DEFAULT_RESULTS_DIR)
+            logdir = tempfile.mkdtemp(
+                prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
 
-            def default_logger_creator(config):
-                """Creates a Unified logger with a default logdir prefix
-                containing the agent name and the env id
-                """
-                if not os.path.exists(DEFAULT_RESULTS_DIR):
-                    os.makedirs(DEFAULT_RESULTS_DIR)
-                logdir = tempfile.mkdtemp(
-                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
-                return UnifiedLogger(config, logdir, loggers=None)
+            # Allow users to more precisely configure the created logger
+            # via "logger_config.type".
+            if config.get(
+                    "logger_config") and "type" in config["logger_config"]:
+
+                def default_logger_creator(config):
+                    """Creates a custom logger with the default prefix."""
+                    cfg = config["logger_config"].copy()
+                    cls = cfg.pop("type")
+                    # Provide default for logdir, in case the user does
+                    # not specify this in the "logger_config" dict.
+                    logdir_ = cfg.pop("logdir", logdir)
+                    return from_config(cls=cls, _args=[cfg], logdir=logdir_)
+
+            # If no `type` given, use tune's UnifiedLogger as last resort.
+            else:
+
+                def default_logger_creator(config):
+                    """Creates a Unified logger with the default prefix."""
+                    return UnifiedLogger(config, logdir, loggers=None)
 
             logger_creator = default_logger_creator
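For orientation, here is a minimal usage sketch of the new code path above (not part of the diff): a "logger_config" dict carrying a "type" key makes the Trainer build that logger via `from_config()` instead of tune's default UnifiedLogger. It assumes a standard `ray[rllib]` install of this era; the choice of PPOTrainer and JsonLogger is illustrative only.

# Hedged usage sketch of the new "logger_config" branch (illustrative,
# not part of this commit).
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 0,
        "logger_config": {
            # Either a class object or a fully qualified class path;
            # resolved by `from_config()` in the branch above.
            "type": "ray.tune.logger.JsonLogger",
            # Optional: if omitted, the auto-created
            # ~/ray_results/<name>_<env>_<time> dir is used as logdir.
            # "logdir": "/tmp/my_logs",
        },
    })
print(trainer.train())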

View file

@@ -0,0 +1,115 @@
"""
This example script demonstrates how one can define a custom logger
object for any RLlib Trainer via the Trainer's config dict's
"logger_config" key.
By default (logger_config=None), RLlib will construct a tune
UnifiedLogger object, which logs JSON, CSV, and TBX output.
The examples below include:
- Disabling logging entirely.
- Using only one of tune's Json, CSV, or TBX loggers.
- Defining a custom logger (by sub-classing tune.logger.py::Logger).
"""
import argparse
import os

from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import Logger

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=150.0)


class MyPrintLogger(Logger):
    """Logs results by simply printing out everything."""

    def _init(self):
        # Custom init function.
        print("Initializing ...")
        # Set up our log-line prefix.
        self.prefix = self.config.get("prefix")

    def on_result(self, result: dict):
        # Define what should happen on receiving a `result` (dict).
        print(f"{self.prefix}: {result}")

    def close(self):
        # Release all resources used by this logger.
        print("Closing")

    def flush(self):
        # Flush all possible disk writes to permanent storage.
        print("Flushing ;)", flush=True)


if __name__ == "__main__":
    import ray
    from ray import tune

    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    config = {
        "env": "CartPole-v0"
        if args.run not in ["DDPG", "TD3"] else "Pendulum-v0",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2.
        "eager_tracing": args.framework in ["tfe", "tf2"],
        # Setting up a custom logger config.
        # ----------------------------------
        # The following are different examples of custom logging setups:
        # 1) Disable logging entirely.
        # "logger_config": {
        #     # Use the tune.logger.NoopLogger class for no logging.
        #     "type": "ray.tune.logger.NoopLogger",
        # },
        # 2) Use tune's JsonLogger only.
        # Alternatively, use `CSVLogger` or `TBXLogger` instead of
        # `JsonLogger` in the "type" key below.
        # "logger_config": {
        #     "type": "ray.tune.logger.JsonLogger",
        #     # Optional: Custom logdir (do not define this here
        #     # for using ~/ray_results/...).
        #     "logdir": "/tmp",
        # },
        # 3) Custom logger (see `MyPrintLogger` class above).
        "logger_config": {
            # Provide the class directly or via fully qualified class
            # path.
            "type": MyPrintLogger,
            # `config` keys:
            "prefix": "ABC",
            # Optional: Custom logdir (do not define this here
            # for using ~/ray_results/...).
            # "logdir": "/somewhere/on/my/file/system/"
        },
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(
        args.run,
        config=config,
        stop=stop,
        verbose=2,
        loggers=[MyPrintLogger])

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
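As a quick sanity check of the Logger subclass above, one can also instantiate it directly: in the tune Logger API of this Ray version, the constructor takes the trial config and a logdir, then calls `_init()`. A hypothetical standalone snippet, not part of the example file:

# Standalone check of MyPrintLogger (hypothetical; assumes the class
# definition above is available in the same session).
logger = MyPrintLogger(config={"prefix": "ABC"}, logdir="/tmp")
logger.on_result({"episode_reward_mean": 123.4})  # prints "ABC: {...}"
logger.flush()
logger.close()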

View file

@@ -9,9 +9,9 @@ if __name__ == "__main__":
     from ray.rllib.agents.a3c import A2CTrainer
 
     assert "tensorflow" not in sys.modules, \
-        "TF initially present, when it shouldn't."
+        "`tensorflow` initially present, when it shouldn't!"
 
-    # note: no ray.init(), to test it works without Ray
+    # Note: No ray.init(), to test it works without Ray
     trainer = A2CTrainer(
         env="CartPole-v0", config={
             "framework": "torch",
@@ -19,7 +19,9 @@ if __name__ == "__main__":
         })
     trainer.train()
 
-    assert "tensorflow" not in sys.modules, "TF should not be imported"
+    assert "tensorflow" not in sys.modules, \
+        "`tensorflow` should not be imported after creating and " \
+        "training A3CTrainer!"
 
     # Clean up.
     del os.environ["RLLIB_TEST_NO_TF_IMPORT"]

View file

@@ -2,33 +2,34 @@
 import os
 import sys
 
-import pytest
-
-
-@pytest.mark.skip(reason="Upstream change make it failed.")
-def test_dependency_torch():
+if __name__ == "__main__":
     # Do not import torch for testing purposes.
     os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
 
     from ray.rllib.agents.a3c import A2CTrainer
 
     assert "torch" not in sys.modules, \
-        "Torch initially present, when it shouldn't."
+        "`torch` initially present, when it shouldn't!"
 
-    # note: no ray.init(), to test it works without Ray
+    # Note: No ray.init(), to test it works without Ray
     trainer = A2CTrainer(
-        env="CartPole-v0", config={
+        env="CartPole-v0",
+        config={
             "framework": "tf",
-            "num_workers": 0
+            "num_workers": 0,
+            # Disable the logger due to a soft-import attempt of torch
+            # inside the tensorboardX.SummaryWriter class.
+            "logger_config": {
+                "type": "ray.tune.logger.NoopLogger",
+            },
         })
     trainer.train()
 
-    assert "torch" not in sys.modules, "Torch should not be imported"
+    assert "torch" not in sys.modules, \
+        "`torch` should not be imported after creating and " \
+        "training A3CTrainer!"
 
     # Clean up.
     del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
 
+    print("ok")
-
-
-if __name__ == "__main__":
-    sys.exit(pytest.main(["-sv", __file__]))
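The NoopLogger workaround above exists because tune's default logger set includes TBXLogger, which pulls in tensorboardX, whose SummaryWriter soft-imports torch. A hedged way to observe this in your own environment (assumes tensorboardX is installed; the outcome depends on your install), not part of the commit:

# Illustrative check: does importing tensorboardX drag torch into
# sys.modules on this machine?
import sys

import tensorboardX  # noqa: F401

print("torch in sys.modules:", "torch" in sys.modules)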