From 7559fdb1418e4c35f09e2edbc7c6762b3889f278 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 21 Nov 2019 15:55:56 -0800 Subject: [PATCH] =?UTF-8?q?[rllib/tune]=20Cache=20get=5Fpreprocessor()=20c?= =?UTF-8?q?alls,=20default=20max=5Ffailur=E2=80=A6=20(#6211)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/ray/tune/tune.py | 2 +- rllib/models/model.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 08a8f7bb7..24ab4d58b 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -74,7 +74,7 @@ def run(run_or_experiment, checkpoint_score_attr=None, global_checkpoint_period=10, export_formats=None, - max_failures=3, + max_failures=0, restore=None, search_alg=None, scheduler=None, diff --git a/rllib/models/model.py b/rllib/models/model.py index 309caa460..242fb2ea2 100644 --- a/rllib/models/model.py +++ b/rllib/models/model.py @@ -232,6 +232,10 @@ def restore_original_dimensions(obs, obs_space, tensorlib=tf): return obs +# Cache of preprocessors, for if the user is calling unpack obs often. +_cache = {} + + def _unpack_obs(obs, space, tensorlib=tf): """Unpack a flattened Dict or Tuple observation array/tensor. @@ -243,7 +247,13 @@ def _unpack_obs(obs, space, tensorlib=tf): if (isinstance(space, gym.spaces.Dict) or isinstance(space, gym.spaces.Tuple)): - prep = get_preprocessor(space)(space) + if id(space) in _cache: + prep = _cache[id(space)] + else: + prep = get_preprocessor(space)(space) + # Make an attempt to cache the result, if enough space left. + if len(_cache) < 999: + _cache[id(space)] = prep if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]: raise ValueError( "Expected flattened obs shape of [None, {}], got {}".format(