[tune] Fix unflatten dict (#11948)

This commit is contained in:
Kristian Hartikainen 2020-11-12 16:43:15 +00:00 committed by GitHub
parent 9920933e31
commit 07f401d99d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 3 deletions

View file

@ -268,14 +268,15 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
def unflatten_dict(dt, delimiter="/"):
"""Unflatten dict. Does not support unflattening lists."""
out = defaultdict(dict)
dict_type = type(dt)
out = dict_type()
for key, val in dt.items():
path = key.split(delimiter)
item = out
for k in path[:-1]:
item = item[k]
item = item.setdefault(k, dict_type())
item[path[-1]] = val
return dict(out)
return out
def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):

View file

@ -0,0 +1,43 @@
from collections import OrderedDict
import unittest
from .util import unflatten_dict
class UnflattenDictTest(unittest.TestCase):
def test_output_type(self):
in_ = OrderedDict({"a/b": 1, "c/d": 2, "e": 3})
out = unflatten_dict(in_)
assert type(in_) is type(out)
def test_one_level_nested(self):
result = unflatten_dict({"a/b": 1, "c/d": 2, "e": 3})
assert result == {"a": {"b": 1}, "c": {"d": 2}, "e": 3}
def test_multi_level_nested(self):
result = unflatten_dict({"a/b/c/d": 1, "b/c/d": 2, "c/d": 3, "e": 4})
assert result == {
"a": {
"b": {
"c": {
"d": 1,
},
},
},
"b": {
"c": {
"d": 2,
},
},
"c": {
"d": 3,
},
"e": 4,
}
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))