mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Fixing conv filters config for ComplexInputNetwork (#14749)
This commit is contained in:
parent
8874ccec2d
commit
3df7a010b1
2 changed files with 6 additions and 4 deletions
|
@ -49,8 +49,9 @@ class ComplexInputNetwork(TFModelV2):
|
|||
# Image space.
|
||||
if len(component.shape) == 3:
|
||||
config = {
|
||||
"conv_filters": model_config.get(
|
||||
"conv_filters", get_filter_config(component.shape)),
|
||||
"conv_filters": model_config["conv_filters"]
|
||||
if "conv_filters" in model_config else
|
||||
get_filter_config(obs_space.shape),
|
||||
"conv_activation": model_config.get("conv_activation"),
|
||||
"post_fcnet_hiddens": [],
|
||||
}
|
||||
|
|
|
@ -55,8 +55,9 @@ class ComplexInputNetwork(TorchModelV2, nn.Module):
|
|||
# Image space.
|
||||
if len(component.shape) == 3:
|
||||
config = {
|
||||
"conv_filters": model_config.get(
|
||||
"conv_filters", get_filter_config(component.shape)),
|
||||
"conv_filters": model_config["conv_filters"]
|
||||
if "conv_filters" in model_config else
|
||||
get_filter_config(obs_space.shape),
|
||||
"conv_activation": model_config.get("conv_activation"),
|
||||
"post_fcnet_hiddens": [],
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue