ray/rllib/utils/tracking_dict.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00

32 lines
1.1 KiB
Python

class UsageTrackingDict(dict):
    """Dict that tracks which keys have been accessed.

    It can also intercept gets and allow an arbitrary callback to be applied
    (e.g., to lazily convert numpy arrays to Tensors). The callback result is
    cached per key, so the callback runs at most once per key until either
    the key is re-assigned or the callback is replaced.

    We make the simplifying assumption only __getitem__ is used to access
    values: other dict read paths (``get``, ``items``, ``values``, ...) are
    not overridden here, so they neither record accesses nor apply the
    interceptor.
    """

    def __init__(self, *args, **kwargs):
        """Build the dict; accepts the same arguments as ``dict``."""
        dict.__init__(self, *args, **kwargs)
        # Keys that have been passed to __getitem__ so far.
        self.accessed_keys = set()
        # Per-key cache of values handed out while an interceptor is active.
        self.intercepted_values = {}
        # Optional callable applied to raw values on __getitem__.
        self.get_interceptor = None

    def set_get_interceptor(self, fn):
        """Set the callable applied to values returned by __getitem__.

        Args:
            fn: Callable taking the stored value and returning the value to
                hand out, or None to disable interception.
        """
        # Cached results were produced by the previous interceptor. If the
        # interceptor changes they are stale and must not be served again.
        if fn is not self.get_interceptor:
            self.intercepted_values = {}
        self.get_interceptor = fn

    def __getitem__(self, key):
        """Return the (possibly intercepted) value for `key`.

        Raises:
            KeyError: If `key` is not in the dict.
        """
        # The key is recorded before the lookup, so it is tracked even when
        # the lookup below raises KeyError.
        self.accessed_keys.add(key)
        value = dict.__getitem__(self, key)
        if self.get_interceptor:
            # Apply the interceptor lazily and only once per key.
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value

    def __setitem__(self, key, value):
        """Store `value` under `key`.

        If `key` already has a cached intercepted value, the cache entry is
        overwritten with `value` as given (the interceptor is not re-applied),
        so later reads return exactly what was assigned.
        """
        dict.__setitem__(self, key, value)
        if key in self.intercepted_values:
            self.intercepted_values[key] = value