Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
33 lines · 1.1 KiB · Python
def get_filter_config(shape):
    """Look up the built-in Conv2D filter stack for an image input shape.

    Only rank-3 shapes whose two leading dimensions are exactly 84x84 or
    42x42 have a default; the trailing dimension is never inspected, so any
    value is accepted there.

    Args:
        shape (Tuple[int]): The input (image) shape, e.g. (84,84,3).

    Returns:
        List[list]: A newly built Conv2D filter configuration usable as
            `conv_filters` inside a model config dict. Each entry has the
            form ``[int, [int, int], int]`` (assumed to be RLlib's
            out-channels / kernel size / stride convention -- confirm
            against the model catalog).

    Raises:
        ValueError: If ``shape`` is not rank 3, or its two leading
            dimensions are neither [84, 84] nor [42, 42].
    """
    dims = list(shape)

    # Checked in order; the stacks differ only in their first layer.
    # They are built inside the function so each call hands back its own
    # lists and a caller mutating the result cannot affect later calls.
    candidates = (
        ([84, 84], [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]]),
        ([42, 42], [[16, [4, 4], 2], [32, [4, 4], 2], [256, [11, 11], 1]]),
    )

    if len(dims) == 3:
        for spatial, filters in candidates:
            if dims[:2] == spatial:
                return filters

    template = (
        "No default configuration for obs shape {}"
        ", you must specify `conv_filters` manually as a model option. "
        "Default configurations are only available for inputs of shape "
        "[42, 42, K] and [84, 84, K]. You may alternatively want "
        "to use a custom model or preprocessor."
    )
    raise ValueError(template.format(dims))