[RLlib] Fix test_dependency_torch and fix custom logger support for RLlib. (#15120)
commit 354c960fff (parent 57c0bd9912)
4 changed files with 160 additions and 25 deletions
rllib/agents/trainer.py

@@ -506,19 +506,36 @@ class Trainer(Trainable):
         # Create a default logger creator if no logger_creator is specified
         if logger_creator is None:
+            # Default logdir prefix containing the agent's name and the
+            # env id.
             timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
             logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
                                               timestr)
+            if not os.path.exists(DEFAULT_RESULTS_DIR):
+                os.makedirs(DEFAULT_RESULTS_DIR)
+            logdir = tempfile.mkdtemp(
+                prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)

-            def default_logger_creator(config):
-                """Creates a Unified logger with a default logdir prefix
-                containing the agent name and the env id
-                """
-                if not os.path.exists(DEFAULT_RESULTS_DIR):
-                    os.makedirs(DEFAULT_RESULTS_DIR)
-                logdir = tempfile.mkdtemp(
-                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
-                return UnifiedLogger(config, logdir, loggers=None)
+            # Allow users to more precisely configure the created logger
+            # via "logger_config.type".
+            if config.get(
+                    "logger_config") and "type" in config["logger_config"]:
+
+                def default_logger_creator(config):
+                    """Creates a custom logger with the default prefix."""
+                    cfg = config["logger_config"].copy()
+                    cls = cfg.pop("type")
+                    # Provide default for logdir, in case the user does
+                    # not specify this in the "logger_config" dict.
+                    logdir_ = cfg.pop("logdir", logdir)
+                    return from_config(cls=cls, _args=[cfg], logdir=logdir_)
+
+            # If no `type` given, use tune's UnifiedLogger as last resort.
+            else:
+
+                def default_logger_creator(config):
+                    """Creates a Unified logger with the default prefix."""
+                    return UnifiedLogger(config, logdir, loggers=None)

             logger_creator = default_logger_creator
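With this change, logger creation can be steered entirely from the Trainer's config dict: if config["logger_config"] carries a "type" key, that class is instantiated via from_config(); otherwise tune's UnifiedLogger remains the fallback. A minimal sketch of the new code path (assuming the Ray version of this commit, where PPOTrainer lives under ray.rllib.agents.ppo):

    from ray.rllib.agents.ppo import PPOTrainer

    # Because "logger_config" contains a "type" key, Trainer.__init__ now
    # builds the logger via from_config() instead of tune's UnifiedLogger.
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            "framework": "tf",
            "logger_config": {
                # Fully qualified class path; a class object also works.
                "type": "ray.tune.logger.NoopLogger",
            },
        })
    print(trainer.train()["episode_reward_mean"])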
rllib/examples/custom_logger.py (new file, 115 lines)

@@ -0,0 +1,115 @@
"""
This example script demonstrates how one can define a custom logger
object for any RLlib Trainer via the Trainer's config dict's
"logger_config" key.
By default (logger_config=None), RLlib will construct a tune
UnifiedLogger object, which logs JSON, CSV, and TBX output.

Below examples include:
- Disable logging entirely.
- Using only one of tune's Json, CSV, or TBX loggers.
- Defining a custom logger (by sub-classing tune.logger.py::Logger).
"""
import argparse
import os

from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import Logger

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=150.0)


class MyPrintLogger(Logger):
    """Logs results by simply printing out everything."""

    def _init(self):
        # Custom init function.
        print("Initializing ...")
        # Setting up our log-line prefix.
        self.prefix = self.config.get("prefix")

    def on_result(self, result: dict):
        # Define, what should happen on receiving a `result` (dict).
        print(f"{self.prefix}: {result}")

    def close(self):
        # Releases all resources used by this logger.
        print("Closing")

    def flush(self):
        # Flushing all possible disk writes to permanent storage.
        print("Flushing ;)", flush=True)


if __name__ == "__main__":
    import ray
    from ray import tune

    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    config = {
        "env": "CartPole-v0"
        if args.run not in ["DDPG", "TD3"] else "Pendulum-v0",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2.
        "eager_tracing": args.framework in ["tfe", "tf2"],

        # Setting up a custom logger config.
        # ----------------------------------
        # The following are different examples of custom logging setups:

        # 1) Disable logging entirely.
        # "logger_config": {
        #     # Use the tune.logger.NoopLogger class for no logging.
        #     "type": "ray.tune.logger.NoopLogger",
        # },

        # 2) Use tune's JsonLogger only.
        # Alternatively, use `CSVLogger` or `TBXLogger` instead of
        # `JsonLogger` in the "type" key below.
        # "logger_config": {
        #     "type": "ray.tune.logger.JsonLogger",
        #     # Optional: Custom logdir (do not define this here
        #     # for using ~/ray_results/...).
        #     "logdir": "/tmp",
        # },

        # 3) Custom logger (see `MyPrintLogger` class above).
        "logger_config": {
            # Provide the class directly or via fully qualified class
            # path.
            "type": MyPrintLogger,
            # `config` keys:
            "prefix": "ABC",
            # Optional: Custom logdir (do not define this here
            # for using ~/ray_results/...).
            # "logdir": "/somewhere/on/my/file/system/"
        }
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(
        args.run, config=config, stop=stop, verbose=2, loggers=[MyPrintLogger])

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
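For reference, the tune Logger lifecycle this example plugs into is: the constructor receives (config, logdir) and calls _init(), tune then invokes on_result() once per training iteration, and flush()/close() on shutdown. A standalone sketch driving these hooks by hand (the EchoLogger name is hypothetical, not part of this commit; the base-class flush()/close() are no-ops):

    from ray.tune.logger import Logger


    class EchoLogger(Logger):
        """Hypothetical minimal Logger: prints each result dict."""

        def _init(self):
            # Logger.__init__(config, logdir) calls this hook.
            self.prefix = self.config.get("prefix", "")

        def on_result(self, result: dict):
            print(f"{self.prefix}: {result}")


    # Normally tune.run() drives these calls; done by hand here for clarity.
    logger = EchoLogger(config={"prefix": "DEMO"}, logdir="/tmp")
    logger.on_result({"episode_reward_mean": 123.4})
    logger.flush()
    logger.close()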
@@ -9,9 +9,9 @@ if __name__ == "__main__":

     from ray.rllib.agents.a3c import A2CTrainer
     assert "tensorflow" not in sys.modules, \
-        "TF initially present, when it shouldn't."
+        "`tensorflow` initially present, when it shouldn't!"

-    # note: no ray.init(), to test it works without Ray
+    # Note: No ray.init(), to test it works without Ray
     trainer = A2CTrainer(
         env="CartPole-v0", config={
             "framework": "torch",

@@ -19,7 +19,9 @@ if __name__ == "__main__":
         })
     trainer.train()

-    assert "tensorflow" not in sys.modules, "TF should not be imported"
+    assert "tensorflow" not in sys.modules, \
+        "`tensorflow` should not be imported after creating and " \
+        "training A3CTrainer!"

     # Clean up.
     del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
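Both dependency tests rely on the same isolation pattern: set the RLLIB_TEST_NO_*_IMPORT env var before anything from RLlib is imported (RLlib's framework-import helpers honor these flags and skip the import), then assert against sys.modules. A distilled sketch of the TF variant (not part of the diff; run in a fresh interpreter):

    import os
    import sys

    # Must be set before any RLlib import.
    os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer  # noqa: E402,F401

    # If RLlib honored the flag, TensorFlow was never pulled in.
    assert "tensorflow" not in sys.modules

    del os.environ["RLLIB_TEST_NO_TF_IMPORT"]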
rllib/tests/test_dependency_torch.py

@@ -2,33 +2,34 @@

 import os
 import sys
-import pytest


-@pytest.mark.skip(reason="Upstream change make it failed.")
-def test_dependency_torch():
+if __name__ == "__main__":
     # Do not import torch for testing purposes.
     os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"

     from ray.rllib.agents.a3c import A2CTrainer
     assert "torch" not in sys.modules, \
-        "Torch initially present, when it shouldn't."
+        "`torch` initially present, when it shouldn't!"

-    # note: no ray.init(), to test it works without Ray
+    # Note: No ray.init(), to test it works without Ray
     trainer = A2CTrainer(
-        env="CartPole-v0", config={
+        env="CartPole-v0",
+        config={
             "framework": "tf",
-            "num_workers": 0
+            "num_workers": 0,
+            # Disable the logger due to a sort-import attempt of torch
+            # inside the tensorboardX.SummaryWriter class.
+            "logger_config": {
+                "type": "ray.tune.logger.NoopLogger",
+            },
         })
     trainer.train()

-    assert "torch" not in sys.modules, "Torch should not be imported"
+    assert "torch" not in sys.modules, \
+        "`torch` should not be imported after creating and " \
+        "training A3CTrainer!"

     # Clean up.
     del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]

     print("ok")
-
-
-if __name__ == "__main__":
-    sys.exit(pytest.main(["-sv", __file__]))
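The root cause this hunk works around: tune's default logger set includes TBXLogger, and constructing a tensorboardX.SummaryWriter makes what the diff comment calls a "sort-import" (i.e. soft-import) attempt of torch, so even a pure-TF run could pull torch in. Routing logger_config.type to NoopLogger sidesteps the writer entirely. A hedged sketch to observe this (assuming tensorboardX and torch are both installed; run in a fresh interpreter):

    import sys
    import tempfile

    assert "torch" not in sys.modules

    from tensorboardX import SummaryWriter  # noqa: E402

    # Constructing the writer is where the torch soft-import may happen.
    writer = SummaryWriter(logdir=tempfile.mkdtemp())
    print("torch imported:", "torch" in sys.modules)
    writer.close()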