mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Add test for flatten_dict. (#17241)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
362f7b7c56
commit
f3a31a3b94
2 changed files with 47 additions and 3 deletions
|
@ -1,3 +1,4 @@
|
|||
import copy
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import sys
|
||||
|
@ -9,7 +10,8 @@ import ray
|
|||
import ray._private.utils
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.utils.util import wait_for_gpu
|
||||
from ray.tune.utils.util import unflatten_dict, unflatten_list_dict
|
||||
from ray.tune.utils.util import (flatten_dict, unflatten_dict,
|
||||
unflatten_list_dict)
|
||||
from ray.tune.utils.trainable import TrainableUtil
|
||||
|
||||
|
||||
|
@ -70,6 +72,44 @@ class TrainableUtilTest(unittest.TestCase):
|
|||
self.assertEquals(loaded["data"][str(i)], open(path, "rb").read())
|
||||
|
||||
|
||||
class FlattenDictTest(unittest.TestCase):
|
||||
def test_output_type(self):
|
||||
in_ = OrderedDict({"a": {"b": 1}, "c": {"d": 2}, "e": 3})
|
||||
out = flatten_dict(in_)
|
||||
assert type(in_) is type(out)
|
||||
|
||||
def test_one_level_nested(self):
|
||||
ori_in = OrderedDict({"a": {"b": 1}, "c": {"d": 2}, "e": 3})
|
||||
in_ = copy.deepcopy(ori_in)
|
||||
result = flatten_dict(in_)
|
||||
assert in_ == ori_in
|
||||
assert result == {"a/b": 1, "c/d": 2, "e": 3}
|
||||
|
||||
def test_multi_level_nested(self):
|
||||
ori_in = OrderedDict({
|
||||
"a": {
|
||||
"b": {
|
||||
"c": {
|
||||
"d": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"b": {
|
||||
"c": {
|
||||
"d": 2,
|
||||
},
|
||||
},
|
||||
"c": {
|
||||
"d": 3,
|
||||
},
|
||||
"e": 4,
|
||||
})
|
||||
in_ = copy.deepcopy(ori_in)
|
||||
result = flatten_dict(in_)
|
||||
assert in_ == ori_in
|
||||
assert result == {"a/b/c/d": 1, "b/c/d": 2, "c/d": 3, "e": 4}
|
||||
|
||||
|
||||
class UnflattenDictTest(unittest.TestCase):
|
||||
def test_output_type(self):
|
||||
in_ = OrderedDict({"a/b": 1, "c/d": 2, "e": 3})
|
||||
|
|
|
@ -284,7 +284,11 @@ def deep_update(original,
|
|||
|
||||
|
||||
def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
|
||||
"""Flatten dict."""
|
||||
"""Flatten dict.
|
||||
|
||||
Output and input are of the same dict type.
|
||||
Input dict remains the same after the operation.
|
||||
"""
|
||||
dt = copy.copy(dt)
|
||||
if prevent_delimiter and any(delimiter in key for key in dt):
|
||||
# Raise if delimiter is any of the keys
|
||||
|
@ -298,7 +302,7 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
|
|||
if isinstance(value, dict):
|
||||
for subkey, v in value.items():
|
||||
if prevent_delimiter and delimiter in subkey:
|
||||
# Raise if delimiter is in any of the subkeys
|
||||
# Raise if delimiter is in any of the subkeys
|
||||
raise ValueError(
|
||||
"Found delimiter `{}` in key when trying to "
|
||||
"flatten array. Please avoid using the delimiter "
|
||||
|
|
Loading…
Add table
Reference in a new issue