mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Better error message for unsupported non-atari image observation sizes (#3444)
This commit is contained in:
parent
4abafd7e62
commit
7abfbfd2f7
2 changed files with 12 additions and 9 deletions
|
@ -16,7 +16,7 @@ class VisionNetwork(Model):
|
|||
inputs = input_dict["obs"]
|
||||
filters = options.get("conv_filters")
|
||||
if not filters:
|
||||
filters = get_filter_config(options)
|
||||
filters = get_filter_config(inputs)
|
||||
|
||||
activation = get_activation_fn(options.get("conv_activation"))
|
||||
|
||||
|
@ -47,7 +47,7 @@ class VisionNetwork(Model):
|
|||
return flatten(fc2), flatten(fc1)
|
||||
|
||||
|
||||
def get_filter_config(options):
|
||||
def get_filter_config(inputs):
|
||||
filters_84x84 = [
|
||||
[16, [8, 8], 4],
|
||||
[32, [4, 4], 2],
|
||||
|
@ -58,12 +58,15 @@ def get_filter_config(options):
|
|||
[32, [4, 4], 2],
|
||||
[256, [11, 11], 1],
|
||||
]
|
||||
dim = options.get("dim")
|
||||
if dim == 84:
|
||||
shape = inputs.shape.as_list()[1:]
|
||||
if len(shape) == 3 and shape[:2] == [84, 84]:
|
||||
return filters_84x84
|
||||
elif dim == 42:
|
||||
elif len(shape) == 3 and shape[:2] == [42, 42]:
|
||||
return filters_42x42
|
||||
else:
|
||||
raise ValueError(
|
||||
"No default configuration for image size={}".format(dim) +
|
||||
", you must specify `conv_filters` manually as a model option.")
|
||||
"No default configuration for obs input {}".format(inputs) +
|
||||
", you must specify `conv_filters` manually as a model option. "
|
||||
"Default configurations are only available for inputs of size "
|
||||
"[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want "
|
||||
"to use a custom model or preprocessor.")
|
||||
|
|
|
@ -72,13 +72,13 @@ class ModelCatalogTest(unittest.TestCase):
|
|||
|
||||
with tf.variable_scope("test1"):
|
||||
p1 = ModelCatalog.get_model({
|
||||
"obs": np.zeros((10, 3), dtype=np.float32)
|
||||
"obs": tf.zeros((10, 3), dtype=tf.float32)
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {})
|
||||
self.assertEqual(type(p1), FullyConnectedNetwork)
|
||||
|
||||
with tf.variable_scope("test2"):
|
||||
p2 = ModelCatalog.get_model({
|
||||
"obs": np.zeros((10, 84, 84, 3), dtype=np.float32)
|
||||
"obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32)
|
||||
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {})
|
||||
self.assertEqual(type(p2), VisionNetwork)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue