From 07f401d99d38a73164072fbbb7208a63d6460590 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Thu, 12 Nov 2020 16:43:15 +0000 Subject: [PATCH] [tune] Fix unflatten dict (#11948) --- python/ray/tune/utils/util.py | 7 ++--- python/ray/tune/utils/util_test.py | 43 ++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 python/ray/tune/utils/util_test.py diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index af6e23ea3..1048f0e0e 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -268,14 +268,15 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False): def unflatten_dict(dt, delimiter="/"): """Unflatten dict. Does not support unflattening lists.""" - out = defaultdict(dict) + dict_type = type(dt) + out = dict_type() for key, val in dt.items(): path = key.split(delimiter) item = out for k in path[:-1]: - item = item[k] + item = item.setdefault(k, dict_type()) item[path[-1]] = val - return dict(out) + return out def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs): diff --git a/python/ray/tune/utils/util_test.py b/python/ray/tune/utils/util_test.py new file mode 100644 index 000000000..534061f68 --- /dev/null +++ b/python/ray/tune/utils/util_test.py @@ -0,0 +1,43 @@ +from collections import OrderedDict + +import unittest + +from .util import unflatten_dict + + +class UnflattenDictTest(unittest.TestCase): + def test_output_type(self): + in_ = OrderedDict({"a/b": 1, "c/d": 2, "e": 3}) + out = unflatten_dict(in_) + assert type(in_) is type(out) + + def test_one_level_nested(self): + result = unflatten_dict({"a/b": 1, "c/d": 2, "e": 3}) + assert result == {"a": {"b": 1}, "c": {"d": 2}, "e": 3} + + def test_multi_level_nested(self): + result = unflatten_dict({"a/b/c/d": 1, "b/c/d": 2, "c/d": 3, "e": 4}) + assert result == { + "a": { + "b": { + "c": { + "d": 1, + }, + }, + }, + "b": { + "c": { + "d": 2, + }, + }, + "c": { + "d": 3, + }, + "e": 4, + } + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))