mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[rllib/tune] Cache get_preprocessor() calls, default max_failur… (#6211)
This commit is contained in:
parent
d3227f2f2d
commit
7559fdb141
2 changed files with 12 additions and 2 deletions
|
@ -74,7 +74,7 @@ def run(run_or_experiment,
|
|||
checkpoint_score_attr=None,
|
||||
global_checkpoint_period=10,
|
||||
export_formats=None,
|
||||
max_failures=3,
|
||||
max_failures=0,
|
||||
restore=None,
|
||||
search_alg=None,
|
||||
scheduler=None,
|
||||
|
|
|
@ -232,6 +232,10 @@ def restore_original_dimensions(obs, obs_space, tensorlib=tf):
|
|||
return obs
|
||||
|
||||
|
||||
# Cache of preprocessors, for if the user is calling unpack obs often.
|
||||
_cache = {}
|
||||
|
||||
|
||||
def _unpack_obs(obs, space, tensorlib=tf):
|
||||
"""Unpack a flattened Dict or Tuple observation array/tensor.
|
||||
|
||||
|
@ -243,7 +247,13 @@ def _unpack_obs(obs, space, tensorlib=tf):
|
|||
|
||||
if (isinstance(space, gym.spaces.Dict)
|
||||
or isinstance(space, gym.spaces.Tuple)):
|
||||
prep = get_preprocessor(space)(space)
|
||||
if id(space) in _cache:
|
||||
prep = _cache[id(space)]
|
||||
else:
|
||||
prep = get_preprocessor(space)(space)
|
||||
# Make an attempt to cache the result, if enough space left.
|
||||
if len(_cache) < 999:
|
||||
_cache[id(space)] = prep
|
||||
if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]:
|
||||
raise ValueError(
|
||||
"Expected flattened obs shape of [None, {}], got {}".format(
|
||||
|
|
Loading…
Add table
Reference in a new issue