[rllib] Better error message for unsupported non-atari image observation sizes (#3444)

Eric Liang 2018-12-03 01:24:36 -08:00 committed by GitHub
parent 4abafd7e62
commit 7abfbfd2f7
2 changed files with 12 additions and 9 deletions


@@ -16,7 +16,7 @@ class VisionNetwork(Model):
         inputs = input_dict["obs"]
         filters = options.get("conv_filters")
         if not filters:
-            filters = get_filter_config(options)
+            filters = get_filter_config(inputs)
         activation = get_activation_fn(options.get("conv_activation"))
@@ -47,7 +47,7 @@ class VisionNetwork(Model):
         return flatten(fc2), flatten(fc1)
 
 
-def get_filter_config(options):
+def get_filter_config(inputs):
     filters_84x84 = [
         [16, [8, 8], 4],
         [32, [4, 4], 2],
@@ -58,12 +58,15 @@ def get_filter_config(options):
         [32, [4, 4], 2],
         [256, [11, 11], 1],
     ]
-    dim = options.get("dim")
-    if dim == 84:
+    shape = inputs.shape.as_list()[1:]
+    if len(shape) == 3 and shape[:2] == [84, 84]:
         return filters_84x84
-    elif dim == 42:
+    elif len(shape) == 3 and shape[:2] == [42, 42]:
         return filters_42x42
     else:
         raise ValueError(
-            "No default configuration for image size={}".format(dim) +
-            ", you must specify `conv_filters` manually as a model option.")
+            "No default configuration for obs input {}".format(inputs) +
+            ", you must specify `conv_filters` manually as a model option. "
+            "Default configurations are only available for inputs of size "
+            "[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want "
+            "to use a custom model or preprocessor.")

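With this change, the default filter stacks are selected from the actual observation tensor's shape rather than the `dim` model option, and the error now tells users what to do when no default fits. Below is a minimal sketch of the escape hatch the new message points to: passing `conv_filters` explicitly as a model option. The 100x100 observation size and the exact filter values are illustrative, not part of this commit; each entry is [out_channels, kernel, stride], and the spatial-size comments assume the vision net's usual SAME padding on all but the last layer (which uses VALID).

# Sketch only: a hypothetical 100x100 RGB observation.
# "conv_filters" is the model option the new error message refers to;
# each entry is [out_channels, kernel, stride].
config = {
    "model": {
        "conv_filters": [
            [16, [8, 8], 4],     # roughly 100x100 -> 25x25
            [32, [4, 4], 2],     # roughly 25x25 -> 13x13
            [256, [13, 13], 1],  # final VALID conv reduces 13x13 -> 1x1
        ],
    },
}

This config dict would be merged into the agent's model options in place of the defaults that the lookup above would otherwise supply.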

@@ -72,13 +72,13 @@ class ModelCatalogTest(unittest.TestCase):
         with tf.variable_scope("test1"):
             p1 = ModelCatalog.get_model({
-                "obs": np.zeros((10, 3), dtype=np.float32)
+                "obs": tf.zeros((10, 3), dtype=tf.float32)
             }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {})
             self.assertEqual(type(p1), FullyConnectedNetwork)
         with tf.variable_scope("test2"):
             p2 = ModelCatalog.get_model({
-                "obs": np.zeros((10, 84, 84, 3), dtype=np.float32)
+                "obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32)
             }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {})
             self.assertEqual(type(p2), VisionNetwork)
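The test now feeds real tf tensors because get_filter_config reads inputs.shape instead of a dim option. The snippet below is a rough sketch of the new behaviour; the import path is an assumption about where the function lives and may differ in your checkout.

import tensorflow as tf

# Assumed import path for get_filter_config; adjust to your checkout.
from ray.rllib.models.visionnet import get_filter_config

# An 84x84xK observation picks up the built-in Atari-style filter stack.
filters = get_filter_config(tf.zeros((10, 84, 84, 3), dtype=tf.float32))
assert filters[0] == [16, [8, 8], 4]

# Any other spatial size now fails up front with the descriptive
# ValueError added above.
try:
    get_filter_config(tf.zeros((10, 100, 100, 3), dtype=tf.float32))
except ValueError as e:
    print(e)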