[tune] Safer try-catch for TensorboardX (#8174)

Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>
This commit is contained in:
Richard Liaw 2020-04-25 13:08:37 -07:00 committed by GitHub
parent 13c14eac07
commit 9dd3490c38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 7 deletions

View file

@ -186,6 +186,9 @@ class TBXLogger(Logger):
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
"""
# NoneType is not supported on the last TBX release yet.
VALID_HPARAMS = (str, bool, int, float, list)
def _init(self):
try:
from tensorboardX import SummaryWriter
@ -253,14 +256,31 @@ class TBXLogger(Logger):
flat_params = flatten_dict(self.trial.evaluated_params)
scrubbed_params = {
k: v
for k, v in flat_params.items() if v is not None
for k, v in flat_params.items()
if isinstance(v, self.VALID_HPARAMS)
}
removed = {
k: v
for k, v in flat_params.items()
if not isinstance(v, self.VALID_HPARAMS)
}
if removed:
logger.info(
"Removed the following hyperparameter values when "
"logging to tensorboard: %s", str(removed))
from tensorboardX.summary import hparams
try:
experiment_tag, session_start_tag, session_end_tag = hparams(
hparam_dict=scrubbed_params, metric_dict=result)
self._file_writer.file_writer.add_summary(experiment_tag)
self._file_writer.file_writer.add_summary(session_start_tag)
self._file_writer.file_writer.add_summary(session_end_tag)
except Exception:
logger.exception("TensorboardX failed to log hparams. "
"This may be due to an unsupported type "
"in the hyperparameter values.")
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger)

View file

@ -46,7 +46,7 @@ class LoggerSuite(unittest.TestCase):
logger.close()
def testTBX(self):
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
config = {"a": 2, "b": [1, 2], "c": {"c": {"D": 123}}}
t = Trial(evaluated_params=config, trial_id="tbx")
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
@ -54,6 +54,25 @@ class LoggerSuite(unittest.TestCase):
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
logger.close()
def testBadTBX(self):
config = {"b": (1, 2, 3)}
t = Trial(evaluated_params=config, trial_id="tbx")
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
logger.close()
assert "INFO" in cm.output[0]
config = {"None": None}
t = Trial(evaluated_params=config, trial_id="tbx")
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
logger.on_result(result(0, 4))
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
logger.close()
assert "INFO" in cm.output[0]
if __name__ == "__main__":
import pytest