[tune] Add test for flatten_dict. (#17241)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
xwjiang2010 2021-07-21 22:01:01 -07:00 committed by GitHub
parent 362f7b7c56
commit f3a31a3b94
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 3 deletions

View file

@ -1,3 +1,4 @@
import copy
from collections import OrderedDict
import os
import sys
@ -9,7 +10,8 @@ import ray
import ray._private.utils
import ray.cloudpickle as cloudpickle
from ray.tune.utils.util import wait_for_gpu
from ray.tune.utils.util import unflatten_dict, unflatten_list_dict
from ray.tune.utils.util import (flatten_dict, unflatten_dict,
unflatten_list_dict)
from ray.tune.utils.trainable import TrainableUtil
@ -70,6 +72,44 @@ class TrainableUtilTest(unittest.TestCase):
self.assertEquals(loaded["data"][str(i)], open(path, "rb").read())
class FlattenDictTest(unittest.TestCase):
def test_output_type(self):
in_ = OrderedDict({"a": {"b": 1}, "c": {"d": 2}, "e": 3})
out = flatten_dict(in_)
assert type(in_) is type(out)
def test_one_level_nested(self):
ori_in = OrderedDict({"a": {"b": 1}, "c": {"d": 2}, "e": 3})
in_ = copy.deepcopy(ori_in)
result = flatten_dict(in_)
assert in_ == ori_in
assert result == {"a/b": 1, "c/d": 2, "e": 3}
def test_multi_level_nested(self):
ori_in = OrderedDict({
"a": {
"b": {
"c": {
"d": 1,
},
},
},
"b": {
"c": {
"d": 2,
},
},
"c": {
"d": 3,
},
"e": 4,
})
in_ = copy.deepcopy(ori_in)
result = flatten_dict(in_)
assert in_ == ori_in
assert result == {"a/b/c/d": 1, "b/c/d": 2, "c/d": 3, "e": 4}
class UnflattenDictTest(unittest.TestCase):
def test_output_type(self):
in_ = OrderedDict({"a/b": 1, "c/d": 2, "e": 3})

View file

@ -284,7 +284,11 @@ def deep_update(original,
def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
"""Flatten dict."""
"""Flatten dict.
Output and input are of the same dict type.
Input dict remains the same after the operation.
"""
dt = copy.copy(dt)
if prevent_delimiter and any(delimiter in key for key in dt):
# Raise if delimiter is any of the keys
@ -298,7 +302,7 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
if isinstance(value, dict):
for subkey, v in value.items():
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
# Raise if delimiter is in any of the subkeys
raise ValueError(
"Found delimiter `{}` in key when trying to "
"flatten array. Please avoid using the delimiter "