[rllib] Fix error in shape calculation. (#7301)

This commit is contained in:
Matthew Brulhardt 2020-02-25 17:16:29 -05:00 committed by GitHub
parent f14b6e477b
commit 75f683eec6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -229,6 +229,7 @@ class ModelCatalog:
else: else:
all_discrete = False all_discrete = False
size += np.product(action_space.spaces[i].shape) size += np.product(action_space.spaces[i].shape)
size = int(size)
return (tf.int64 if all_discrete else tf.float32, (None, size)) return (tf.int64 if all_discrete else tf.float32, (None, size))
elif isinstance(action_space, gym.spaces.Dict): elif isinstance(action_space, gym.spaces.Dict):
raise NotImplementedError( raise NotImplementedError(