From 354c960fff36b60fbcc59b7e49e2f8c0d0955e35 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sat, 24 Apr 2021 08:13:41 +0200 Subject: [PATCH] [RLlib] Fix test_dependency_torch and fix custom logger support for RLlib. (#15120) --- rllib/agents/trainer.py | 35 +++++--- rllib/examples/custom_logger.py | 115 +++++++++++++++++++++++++++ rllib/tests/test_dependency_tf.py | 8 +- rllib/tests/test_dependency_torch.py | 27 ++++--- 4 files changed, 160 insertions(+), 25 deletions(-) create mode 100644 rllib/examples/custom_logger.py diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 9a8bda0ee..d99fc53a0 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -506,19 +506,36 @@ class Trainer(Trainable): # Create a default logger creator if no logger_creator is specified if logger_creator is None: + # Default logdir prefix containing the agent's name and the + # env id. timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") logdir_prefix = "{}_{}_{}".format(self._name, self._env_id, timestr) + if not os.path.exists(DEFAULT_RESULTS_DIR): + os.makedirs(DEFAULT_RESULTS_DIR) + logdir = tempfile.mkdtemp( + prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) - def default_logger_creator(config): - """Creates a Unified logger with a default logdir prefix - containing the agent name and the env id - """ - if not os.path.exists(DEFAULT_RESULTS_DIR): - os.makedirs(DEFAULT_RESULTS_DIR) - logdir = tempfile.mkdtemp( - prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) - return UnifiedLogger(config, logdir, loggers=None) + # Allow users to more precisely configure the created logger + # via "logger_config.type". + if config.get( + "logger_config") and "type" in config["logger_config"]: + + def default_logger_creator(config): + """Creates a custom logger with the default prefix.""" + cfg = config["logger_config"].copy() + cls = cfg.pop("type") + # Provide default for logdir, in case the user does + # not specify this in the "logger_config" dict. 
+ logdir_ = cfg.pop("logdir", logdir) + return from_config(cls=cls, _args=[cfg], logdir=logdir_) + + # If no `type` given, use tune's UnifiedLogger as last resort. + else: + + def default_logger_creator(config): + """Creates a Unified logger with the default prefix.""" + return UnifiedLogger(config, logdir, loggers=None) logger_creator = default_logger_creator diff --git a/rllib/examples/custom_logger.py b/rllib/examples/custom_logger.py new file mode 100644 index 000000000..ccb90fa93 --- /dev/null +++ b/rllib/examples/custom_logger.py @@ -0,0 +1,115 @@ +""" +This example script demonstrates how one can define a custom logger +object for any RLlib Trainer via the Trainer's config dict's +"logger_config" key. +By default (logger_config=None), RLlib will construct a tune +UnifiedLogger object, which logs JSON, CSV, and TBX output. + +Below examples include: +- Disable logging entirely. +- Using only one of tune's Json, CSV, or TBX loggers. +- Defining a custom logger (by sub-classing tune.logger.py::Logger). +""" + +import argparse +import os + +from ray.rllib.utils.test_utils import check_learning_achieved +from ray.tune.logger import Logger + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") +parser.add_argument("--num-cpus", type=int, default=0) +parser.add_argument( + "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf") +parser.add_argument("--as-test", action="store_true") +parser.add_argument("--stop-iters", type=int, default=200) +parser.add_argument("--stop-timesteps", type=int, default=100000) +parser.add_argument("--stop-reward", type=float, default=150.0) + + +class MyPrintLogger(Logger): + """Logs results by simply printing out everything. + """ + + def _init(self): + # Custom init function. + print("Initializing ...") + # Setting up our log-line prefix. + self.prefix = self.config.get("prefix") + + def on_result(self, result: dict): + # Define, what should happen on receiving a `result` (dict). 
+ print(f"{self.prefix}: {result}") + + def close(self): + # Releases all resources used by this logger. + print("Closing") + + def flush(self): + # Flushing all possible disk writes to permanent storage. + print("Flushing ;)", flush=True) + + +if __name__ == "__main__": + import ray + from ray import tune + + args = parser.parse_args() + + ray.init(num_cpus=args.num_cpus or None) + + config = { + "env": "CartPole-v0" + if args.run not in ["DDPG", "TD3"] else "Pendulum-v0", + # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. + "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), + "framework": args.framework, + # Run with tracing enabled for tfe/tf2. + "eager_tracing": args.framework in ["tfe", "tf2"], + + # Setting up a custom logger config. + # ---------------------------------- + # The following are different examples of custom logging setups: + + # 1) Disable logging entirely. + # "logger_config": { + # # Use the tune.logger.NoopLogger class for no logging. + # "type": "ray.tune.logger.NoopLogger", + # }, + + # 2) Use tune's JsonLogger only. + # Alternatively, use `CSVLogger` or `TBXLogger` instead of + # `JsonLogger` in the "type" key below. + # "logger_config": { + # "type": "ray.tune.logger.JsonLogger", + # # Optional: Custom logdir (do not define this here + # # for using ~/ray_results/...). + # "logdir": "/tmp", + # }, + + # 3) Custom logger (see `MyPrintLogger` class above). + "logger_config": { + # Provide the class directly or via fully qualified class + # path. + "type": MyPrintLogger, + # `config` keys: + "prefix": "ABC", + # Optional: Custom logdir (do not define this here + # for using ~/ray_results/...). 
+ # "logdir": "/somewhere/on/my/file/system/" + } + } + + stop = { + "training_iteration": args.stop_iters, + "timesteps_total": args.stop_timesteps, + "episode_reward_mean": args.stop_reward, + } + + results = tune.run( + args.run, config=config, stop=stop, verbose=2, loggers=[MyPrintLogger]) + + if args.as_test: + check_learning_achieved(results, args.stop_reward) + ray.shutdown() diff --git a/rllib/tests/test_dependency_tf.py b/rllib/tests/test_dependency_tf.py index bf2fdf153..dccc581b1 100644 --- a/rllib/tests/test_dependency_tf.py +++ b/rllib/tests/test_dependency_tf.py @@ -9,9 +9,9 @@ if __name__ == "__main__": from ray.rllib.agents.a3c import A2CTrainer assert "tensorflow" not in sys.modules, \ - "TF initially present, when it shouldn't." + "`tensorflow` initially present, when it shouldn't!" - # note: no ray.init(), to test it works without Ray + # Note: No ray.init(), to test it works without Ray trainer = A2CTrainer( env="CartPole-v0", config={ "framework": "torch", @@ -19,7 +19,9 @@ if __name__ == "__main__": }) trainer.train() - assert "tensorflow" not in sys.modules, "TF should not be imported" + assert "tensorflow" not in sys.modules, \ + "`tensorflow` should not be imported after creating and " \ + "training A2CTrainer!" # Clean up. del os.environ["RLLIB_TEST_NO_TF_IMPORT"] diff --git a/rllib/tests/test_dependency_torch.py b/rllib/tests/test_dependency_torch.py index 3a2651e29..05dcd519e 100755 --- a/rllib/tests/test_dependency_torch.py +++ b/rllib/tests/test_dependency_torch.py @@ -2,33 +2,34 @@ import os import sys -import pytest - -@pytest.mark.skip(reason="Upstream change make it failed.") -def test_dependency_torch(): +if __name__ == "__main__": # Do not import torch for testing purposes. os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1" from ray.rllib.agents.a3c import A2CTrainer assert "torch" not in sys.modules, \ - "Torch initially present, when it shouldn't." + "`torch` initially present, when it shouldn't!" 
- # note: no ray.init(), to test it works without Ray + # Note: No ray.init(), to test it works without Ray trainer = A2CTrainer( - env="CartPole-v0", config={ + env="CartPole-v0", + config={ "framework": "tf", - "num_workers": 0 + "num_workers": 0, + # Disable the logger due to a soft-import attempt of torch + # inside the tensorboardX.SummaryWriter class. + "logger_config": { + "type": "ray.tune.logger.NoopLogger", + }, }) trainer.train() - assert "torch" not in sys.modules, "Torch should not be imported" + assert "torch" not in sys.modules, \ + "`torch` should not be imported after creating and " \ + "training A2CTrainer!" # Clean up. del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] print("ok") - - -if __name__ == "__main__": - sys.exit(pytest.main(["-sv", __file__]))