diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 7a3bf97a3..9118663ec 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -13,6 +13,7 @@ import numbers import numpy as np import ray.cloudpickle as cloudpickle +from ray.tune.util import flatten_dict from ray.tune.syncer import get_log_syncer from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, EXPR_PARAM_FILE, @@ -107,19 +108,15 @@ class JsonLogger(Logger): def to_tf_values(result, path): - values = [] - for attr, value in result.items(): - if value is not None: - if use_tf150_api: - type_list = [int, float, np.float32, np.float64, np.int32] - else: - type_list = [int, float] - if type(value) in type_list: - values.append( - tf.Summary.Value( - tag="/".join(path + [attr]), simple_value=value)) - elif type(value) is dict: - values.extend(to_tf_values(value, path + [attr])) + if use_tf150_api: + type_list = [int, float, np.float32, np.float64, np.int32] + else: + type_list = [int, float] + flat_result = flatten_dict(result, delimiter="/") + values = [ + tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value) + for attr, value in flat_result.items() if type(value) in type_list + ] return values @@ -175,6 +172,10 @@ class CSVLogger(Logger): self._csv_out = None def on_result(self, result): + tmp = result.copy() + if "config" in tmp: + del tmp["config"] + result = flatten_dict(tmp, delimiter="/") if self._csv_out is None: self._csv_out = csv.DictWriter(self._file, result.keys()) if not self._continuing: @@ -182,6 +183,7 @@ class CSVLogger(Logger): self._csv_out.writerow( {k: v for k, v in result.items() if k in self._csv_out.fieldnames}) + self._file.flush() def flush(self): self._file.flush() diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index a6c122428..06cb4f0eb 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -180,14 +180,15 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist): return original -def flatten_dict(dt): +def flatten_dict(dt, delimiter=":"): + dt = copy.deepcopy(dt) while any(isinstance(v, dict) for v in dt.values()): remove = [] add = {} for key, value in dt.items(): if isinstance(value, dict): for subkey, v in value.items(): - add[":".join([key, subkey])] = v + add[delimiter.join([key, subkey])] = v remove.append(key) dt.update(add) for k in remove: