ray/rllib/models/utils.py
2020-10-12 15:00:42 +02:00

67 lines
2.2 KiB
Python

from ray.rllib.utils.framework import try_import_tf, try_import_torch
def get_filter_config(shape):
    """Look up the default Conv2D filter stack for an image observation shape.

    Only two spatial resolutions have built-in presets: 84x84 and 42x42. In
    both cases the shape must be 3-dimensional, e.g. (84, 84, K) or
    (42, 42, K). The last (channel) dimension is not inspected.

    Args:
        shape (Tuple[int]): The input (image) shape, e.g. (84, 84, 3).

    Returns:
        List[list]: The Conv2D filter configuration, one
        ``[int, [int, int], int]`` entry per layer. It can be used as
        `conv_filters` inside a model config dict. A new list is built on
        every call, so callers may mutate the result freely.

    Raises:
        ValueError: If `shape` does not match [84, 84, K] or [42, 42, K].
    """
    dims = list(shape)

    if len(dims) == 3:
        height_width = dims[:2]
        if height_width == [84, 84]:
            return [
                [16, [8, 8], 4],
                [32, [4, 4], 2],
                [256, [11, 11], 1],
            ]
        if height_width == [42, 42]:
            return [
                [16, [4, 4], 2],
                [32, [4, 4], 2],
                [256, [11, 11], 1],
            ]

    # Any other rank or resolution has no preset.
    raise ValueError(
        "No default configuration for obs shape {}".format(dims) +
        ", you must specify `conv_filters` manually as a model option. "
        "Default configurations are only available for inputs of shape "
        "[42, 42, K] and [84, 84, K]. You may alternatively want "
        "to use a custom model or preprocessor.")
def get_initializer(name, framework="tf"):
    """Returns a framework specific initializer, given a name string.

    Args:
        name (Optional[str]): One of "xavier_uniform" (default) or
            "xavier_normal". `None` and "default" are both treated as
            "xavier_uniform".
        framework (str): One of "tf" or "torch". Any value other than
            "torch" selects the TensorFlow initializers.

    Returns:
        A framework-specific initializer. For torch this is an in-place init
        function, e.g. torch.nn.init.xavier_uniform_. For tf this is the
        initializer class itself (not an instance), e.g.
        tf.keras.initializers.GlorotUniform.

    Raises:
        ValueError: If name is an unknown initializer.
    """
    if framework == "torch":
        _, nn = try_import_torch()
        if name in [None, "default", "xavier_uniform"]:
            return nn.init.xavier_uniform_
        elif name == "xavier_normal":
            return nn.init.xavier_normal_
    else:
        # Only the tf2-style module is needed here; tf1/version are unused.
        _, tf, _ = try_import_tf()
        if name in [None, "default", "xavier_uniform"]:
            return tf.keras.initializers.GlorotUniform
        elif name == "xavier_normal":
            return tf.keras.initializers.GlorotNormal

    # Reached only when `name` matched none of the branches above.
    raise ValueError("Unknown initializer ({}) for framework={}!".format(
        name, framework))