[rllib/tune] Cache get_preprocessor() calls, default max_failur… (#6211)

This commit is contained in:
Eric Liang 2019-11-21 15:55:56 -08:00 committed by Richard Liaw
parent d3227f2f2d
commit 7559fdb141
2 changed files with 12 additions and 2 deletions

View file

@ -74,7 +74,7 @@ def run(run_or_experiment,
checkpoint_score_attr=None,
global_checkpoint_period=10,
export_formats=None,
max_failures=3,
max_failures=0,
restore=None,
search_alg=None,
scheduler=None,

View file

@ -232,6 +232,10 @@ def restore_original_dimensions(obs, obs_space, tensorlib=tf):
return obs
# Cache of preprocessors, for if the user is calling unpack obs often.
_cache = {}
def _unpack_obs(obs, space, tensorlib=tf):
"""Unpack a flattened Dict or Tuple observation array/tensor.
@ -243,7 +247,13 @@ def _unpack_obs(obs, space, tensorlib=tf):
if (isinstance(space, gym.spaces.Dict)
or isinstance(space, gym.spaces.Tuple)):
prep = get_preprocessor(space)(space)
if id(space) in _cache:
prep = _cache[id(space)]
else:
prep = get_preprocessor(space)(space)
# Make an attempt to cache the result, if enough space left.
if len(_cache) < 999:
_cache[id(space)] = prep
if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]:
raise ValueError(
"Expected flattened obs shape of [None, {}], got {}".format(