[tune] support nested dictionaries for CSVLogger (#5295)

This commit is contained in:
lanlin 2019-07-28 05:44:34 +08:00 committed by Richard Liaw
parent b4823d63c6
commit 341dbf6c45
2 changed files with 18 additions and 15 deletions

View file

@ -13,6 +13,7 @@ import numbers
import numpy as np import numpy as np
import ray.cloudpickle as cloudpickle import ray.cloudpickle as cloudpickle
from ray.tune.util import flatten_dict
from ray.tune.syncer import get_log_syncer from ray.tune.syncer import get_log_syncer
from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
TIMESTEPS_TOTAL, EXPR_PARAM_FILE, TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
@ -107,19 +108,15 @@ class JsonLogger(Logger):
def to_tf_values(result, path): def to_tf_values(result, path):
values = [] if use_tf150_api:
for attr, value in result.items(): type_list = [int, float, np.float32, np.float64, np.int32]
if value is not None: else:
if use_tf150_api: type_list = [int, float]
type_list = [int, float, np.float32, np.float64, np.int32] flat_result = flatten_dict(result, delimiter="/")
else: values = [
type_list = [int, float] tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
if type(value) in type_list: for attr, value in flat_result.items() if type(value) in type_list
values.append( ]
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))
elif type(value) is dict:
values.extend(to_tf_values(value, path + [attr]))
return values return values
@ -175,6 +172,10 @@ class CSVLogger(Logger):
self._csv_out = None self._csv_out = None
def on_result(self, result): def on_result(self, result):
tmp = result.copy()
if "config" in tmp:
del tmp["config"]
result = flatten_dict(tmp, delimiter="/")
if self._csv_out is None: if self._csv_out is None:
self._csv_out = csv.DictWriter(self._file, result.keys()) self._csv_out = csv.DictWriter(self._file, result.keys())
if not self._continuing: if not self._continuing:
@ -182,6 +183,7 @@ class CSVLogger(Logger):
self._csv_out.writerow( self._csv_out.writerow(
{k: v {k: v
for k, v in result.items() if k in self._csv_out.fieldnames}) for k, v in result.items() if k in self._csv_out.fieldnames})
self._file.flush()
def flush(self): def flush(self):
self._file.flush() self._file.flush()

View file

@ -180,14 +180,15 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
return original return original
def flatten_dict(dt): def flatten_dict(dt, delimiter=":"):
dt = copy.deepcopy(dt)
while any(isinstance(v, dict) for v in dt.values()): while any(isinstance(v, dict) for v in dt.values()):
remove = [] remove = []
add = {} add = {}
for key, value in dt.items(): for key, value in dt.items():
if isinstance(value, dict): if isinstance(value, dict):
for subkey, v in value.items(): for subkey, v in value.items():
add[":".join([key, subkey])] = v add[delimiter.join([key, subkey])] = v
remove.append(key) remove.append(key)
dt.update(add) dt.update(add)
for k in remove: for k in remove: