mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Fix unflatten dict (#11948)
This commit is contained in:
parent
9920933e31
commit
07f401d99d
2 changed files with 47 additions and 3 deletions
|
@ -268,14 +268,15 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
|
|||
|
||||
def unflatten_dict(dt, delimiter="/"):
|
||||
"""Unflatten dict. Does not support unflattening lists."""
|
||||
out = defaultdict(dict)
|
||||
dict_type = type(dt)
|
||||
out = dict_type()
|
||||
for key, val in dt.items():
|
||||
path = key.split(delimiter)
|
||||
item = out
|
||||
for k in path[:-1]:
|
||||
item = item[k]
|
||||
item = item.setdefault(k, dict_type())
|
||||
item[path[-1]] = val
|
||||
return dict(out)
|
||||
return out
|
||||
|
||||
|
||||
def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):
|
||||
|
|
43
python/ray/tune/utils/util_test.py
Normal file
43
python/ray/tune/utils/util_test.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import unittest
|
||||
|
||||
from .util import unflatten_dict
|
||||
|
||||
|
||||
class UnflattenDictTest(unittest.TestCase):
|
||||
def test_output_type(self):
|
||||
in_ = OrderedDict({"a/b": 1, "c/d": 2, "e": 3})
|
||||
out = unflatten_dict(in_)
|
||||
assert type(in_) is type(out)
|
||||
|
||||
def test_one_level_nested(self):
|
||||
result = unflatten_dict({"a/b": 1, "c/d": 2, "e": 3})
|
||||
assert result == {"a": {"b": 1}, "c": {"d": 2}, "e": 3}
|
||||
|
||||
def test_multi_level_nested(self):
|
||||
result = unflatten_dict({"a/b/c/d": 1, "b/c/d": 2, "c/d": 3, "e": 4})
|
||||
assert result == {
|
||||
"a": {
|
||||
"b": {
|
||||
"c": {
|
||||
"d": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"b": {
|
||||
"c": {
|
||||
"d": 2,
|
||||
},
|
||||
},
|
||||
"c": {
|
||||
"d": 3,
|
||||
},
|
||||
"e": 4,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
Loading…
Add table
Reference in a new issue