[tune] support nested dictionaries for CSVLogger (#5295)
commit 341dbf6c45
parent b4823d63c6
2 changed files with 18 additions and 15 deletions
python/ray/tune/logger.py
@@ -13,6 +13,7 @@ import numbers
 import numpy as np
 
 import ray.cloudpickle as cloudpickle
+from ray.tune.util import flatten_dict
 from ray.tune.syncer import get_log_syncer
 from ray.tune.result import (NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S,
                              TIMESTEPS_TOTAL, EXPR_PARAM_FILE,
@@ -107,19 +108,15 @@ class JsonLogger(Logger):
 
 
 def to_tf_values(result, path):
-    values = []
-    for attr, value in result.items():
-        if value is not None:
-            if use_tf150_api:
-                type_list = [int, float, np.float32, np.float64, np.int32]
-            else:
-                type_list = [int, float]
-            if type(value) in type_list:
-                values.append(
-                    tf.Summary.Value(
-                        tag="/".join(path + [attr]), simple_value=value))
-            elif type(value) is dict:
-                values.extend(to_tf_values(value, path + [attr]))
+    if use_tf150_api:
+        type_list = [int, float, np.float32, np.float64, np.int32]
+    else:
+        type_list = [int, float]
+    flat_result = flatten_dict(result, delimiter="/")
+    values = [
+        tf.Summary.Value(tag="/".join(path + [attr]), simple_value=value)
+        for attr, value in flat_result.items() if type(value) in type_list
+    ]
     return values
 
 
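The rewritten to_tf_values no longer recurses into nested dicts itself; it flattens the result up front and emits one scalar summary per flattened key. A minimal standalone sketch of that flattening (flatten_dict is copied from this commit's util.py hunk below so the snippet runs without ray or TensorFlow; the sample result dict is hypothetical):

import copy


def flatten_dict(dt, delimiter=":"):
    # Inlined copy of the helper added further down in this commit,
    # so the sketch is self-contained.
    dt = copy.deepcopy(dt)
    while any(isinstance(v, dict) for v in dt.values()):
        remove = []
        add = {}
        for key, value in dt.items():
            if isinstance(value, dict):
                for subkey, v in value.items():
                    add[delimiter.join([key, subkey])] = v
                remove.append(key)
        dt.update(add)
        for k in remove:
            del dt[k]
    return dt


# Hypothetical nested result as a Tune trainable might report it.
result = {"episode_reward_mean": 10.5, "info": {"learner": {"loss": 0.2}}}
print(flatten_dict(result, delimiter="/"))
# {'episode_reward_mean': 10.5, 'info/learner/loss': 0.2}
# to_tf_values then emits one tf.Summary.Value per key, using these
# slash-joined keys as the TensorBoard tags.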
@@ -175,6 +172,10 @@ class CSVLogger(Logger):
         self._csv_out = None
 
     def on_result(self, result):
+        tmp = result.copy()
+        if "config" in tmp:
+            del tmp["config"]
+        result = flatten_dict(tmp, delimiter="/")
         if self._csv_out is None:
             self._csv_out = csv.DictWriter(self._file, result.keys())
             if not self._continuing:
@@ -182,6 +183,7 @@ class CSVLogger(Logger):
         self._csv_out.writerow(
             {k: v
              for k, v in result.items() if k in self._csv_out.fieldnames})
+        self._file.flush()
 
     def flush(self):
         self._file.flush()
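Taken together, the two CSVLogger hunks mean every incoming result is flattened before the csv.DictWriter sees it, so nested metrics become slash-delimited column names, and each row is now flushed as soon as it is written. A sketch of the preprocessing (reusing the flatten_dict sketch above; the result dict is hypothetical, and as in the diff the column set stays fixed by whatever keys the first logged result contains):

# Hypothetical result; "config" is dropped before flattening,
# exactly as on_result does.
result = {
    "training_iteration": 1,
    "custom_metrics": {"accuracy": 0.91, "loss": 0.3},
    "config": {"lr": 0.01},
}
tmp = result.copy()
if "config" in tmp:
    del tmp["config"]
row = flatten_dict(tmp, delimiter="/")
print(sorted(row))
# ['custom_metrics/accuracy', 'custom_metrics/loss', 'training_iteration']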
python/ray/tune/util.py
@@ -180,14 +180,15 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
     return original
 
 
-def flatten_dict(dt):
+def flatten_dict(dt, delimiter=":"):
+    dt = copy.deepcopy(dt)
     while any(isinstance(v, dict) for v in dt.values()):
         remove = []
         add = {}
         for key, value in dt.items():
             if isinstance(value, dict):
                 for subkey, v in value.items():
-                    add[":".join([key, subkey])] = v
+                    add[delimiter.join([key, subkey])] = v
                 remove.append(key)
         dt.update(add)
         for k in remove:
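Two behavioral changes land in flatten_dict: the join character becomes a delimiter parameter (defaulting to the old ":"), and the copy.deepcopy means the caller's dict is no longer mutated while being flattened. A quick usage sketch:

nested = {"a": {"b": {"c": 1}}, "d": 2}
print(flatten_dict(nested))                 # {'d': 2, 'a:b:c': 1}
print(flatten_dict(nested, delimiter="/"))  # {'d': 2, 'a/b/c': 1}
print(nested)  # unchanged: {'a': {'b': {'c': 1}}, 'd': 2}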