mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Safer try-catch for TensorboardX (#8174)
Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>
This commit is contained in:
parent
13c14eac07
commit
9dd3490c38
2 changed files with 46 additions and 7 deletions
|
@ -186,6 +186,9 @@ class TBXLogger(Logger):
|
|||
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
||||
"""
|
||||
|
||||
# NoneType is not supported on the last TBX release yet.
|
||||
VALID_HPARAMS = (str, bool, int, float, list)
|
||||
|
||||
def _init(self):
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
@ -253,14 +256,31 @@ class TBXLogger(Logger):
|
|||
flat_params = flatten_dict(self.trial.evaluated_params)
|
||||
scrubbed_params = {
|
||||
k: v
|
||||
for k, v in flat_params.items() if v is not None
|
||||
for k, v in flat_params.items()
|
||||
if isinstance(v, self.VALID_HPARAMS)
|
||||
}
|
||||
|
||||
removed = {
|
||||
k: v
|
||||
for k, v in flat_params.items()
|
||||
if not isinstance(v, self.VALID_HPARAMS)
|
||||
}
|
||||
if removed:
|
||||
logger.info(
|
||||
"Removed the following hyperparameter values when "
|
||||
"logging to tensorboard: %s", str(removed))
|
||||
|
||||
from tensorboardX.summary import hparams
|
||||
try:
|
||||
experiment_tag, session_start_tag, session_end_tag = hparams(
|
||||
hparam_dict=scrubbed_params, metric_dict=result)
|
||||
self._file_writer.file_writer.add_summary(experiment_tag)
|
||||
self._file_writer.file_writer.add_summary(session_start_tag)
|
||||
self._file_writer.file_writer.add_summary(session_end_tag)
|
||||
except Exception:
|
||||
logger.exception("TensorboardX failed to log hparams. "
|
||||
"This may be due to an unsupported type "
|
||||
"in the hyperparameter values.")
|
||||
|
||||
|
||||
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger)
|
||||
|
|
|
@ -46,7 +46,7 @@ class LoggerSuite(unittest.TestCase):
|
|||
logger.close()
|
||||
|
||||
def testTBX(self):
|
||||
config = {"a": 2, "b": 5, "c": {"c": {"D": 123}, "e": None}}
|
||||
config = {"a": 2, "b": [1, 2], "c": {"c": {"D": 123}}}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
|
@ -54,6 +54,25 @@ class LoggerSuite(unittest.TestCase):
|
|||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
logger.close()
|
||||
|
||||
def testBadTBX(self):
|
||||
config = {"b": (1, 2, 3)}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
|
||||
logger.close()
|
||||
assert "INFO" in cm.output[0]
|
||||
|
||||
config = {"None": None}
|
||||
t = Trial(evaluated_params=config, trial_id="tbx")
|
||||
logger = TBXLogger(config=config, logdir=self.test_dir, trial=t)
|
||||
logger.on_result(result(0, 4))
|
||||
logger.on_result(result(2, 4, score=[1, 2, 3], hello={"world": 1}))
|
||||
with self.assertLogs("ray.tune.logger", level="INFO") as cm:
|
||||
logger.close()
|
||||
assert "INFO" in cm.output[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
|
Loading…
Add table
Reference in a new issue