[tune] support nested dictionaries for CSVLogger (#5295)

This commit is contained in:
lanlin 2019-07-28 05:44:34 +08:00 committed by Richard Liaw
parent b4823d63c6
commit 341dbf6c45
2 changed files with 18 additions and 15 deletions

View file

@ -13,6 +13,7 @@ import numbers
import numpy as np
import ray.cloudpickle as cloudpickle
from ray.tune.util import flatten_dict
from ray.tune.syncer import get_log_syncer
from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
@ -107,19 +108,15 @@ class JsonLogger(Logger):
def to_tf_values(result, path):
values = []
for attr, value in result.items():
if value is not None:
if use_tf150_api:
type_list = [int, float, np.float32, np.float64, np.int32]
else:
type_list = [int, float]
if type(value) in type_list:
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))
elif type(value) is dict:
values.extend(to_tf_values(value, path + [attr]))
if use_tf150_api:
type_list = [int, float, np.float32, np.float64, np.int32]
else:
type_list = [int, float]
flat_result = flatten_dict(result, delimiter="/")
values = [
tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
for attr, value in flat_result.items() if type(value) in type_list
]
return values
@ -175,6 +172,10 @@ class CSVLogger(Logger):
self._csv_out = None
def on_result(self, result):
tmp = result.copy()
if "config" in tmp:
del tmp["config"]
result = flatten_dict(tmp, delimiter="/")
if self._csv_out is None:
self._csv_out = csv.DictWriter(self._file, result.keys())
if not self._continuing:
@ -182,6 +183,7 @@ class CSVLogger(Logger):
self._csv_out.writerow(
{k: v
for k, v in result.items() if k in self._csv_out.fieldnames})
self._file.flush()
def flush(self):
self._file.flush()

View file

@ -180,14 +180,15 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
return original
def flatten_dict(dt):
def flatten_dict(dt, delimiter=":"):
dt = copy.deepcopy(dt)
while any(isinstance(v, dict) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
add[":".join([key, subkey])] = v
add[delimiter.join([key, subkey])] = v
remove.append(key)
dt.update(add)
for k in remove: