[RLlib] Fixing conv filters config for ComplexInputNetwork (#14749)

This commit is contained in:
Jack Parsons 2021-03-24 15:15:36 +00:00 committed by GitHub
parent 8874ccec2d
commit 3df7a010b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 4 deletions

View file

@ -49,8 +49,9 @@ class ComplexInputNetwork(TFModelV2):
# Image space. # Image space.
if len(component.shape) == 3: if len(component.shape) == 3:
config = { config = {
"conv_filters": model_config.get( "conv_filters": model_config["conv_filters"]
"conv_filters", get_filter_config(component.shape)), if "conv_filters" in model_config else
get_filter_config(obs_space.shape),
"conv_activation": model_config.get("conv_activation"), "conv_activation": model_config.get("conv_activation"),
"post_fcnet_hiddens": [], "post_fcnet_hiddens": [],
} }

View file

@ -55,8 +55,9 @@ class ComplexInputNetwork(TorchModelV2, nn.Module):
# Image space. # Image space.
if len(component.shape) == 3: if len(component.shape) == 3:
config = { config = {
"conv_filters": model_config.get( "conv_filters": model_config["conv_filters"]
"conv_filters", get_filter_config(component.shape)), if "conv_filters" in model_config else
get_filter_config(obs_space.shape),
"conv_activation": model_config.get("conv_activation"), "conv_activation": model_config.get("conv_activation"),
"post_fcnet_hiddens": [], "post_fcnet_hiddens": [],
} }