diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index 8bc691e24..235701854 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -49,8 +49,9 @@ class ComplexInputNetwork(TFModelV2): # Image space. if len(component.shape) == 3: config = { - "conv_filters": model_config.get( - "conv_filters", get_filter_config(component.shape)), + "conv_filters": model_config["conv_filters"] + if "conv_filters" in model_config else + get_filter_config(obs_space.shape), "conv_activation": model_config.get("conv_activation"), "post_fcnet_hiddens": [], } diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py index 2b9601947..db72d6c9e 100644 --- a/rllib/models/torch/complex_input_net.py +++ b/rllib/models/torch/complex_input_net.py @@ -55,8 +55,9 @@ class ComplexInputNetwork(TorchModelV2, nn.Module): # Image space. if len(component.shape) == 3: config = { - "conv_filters": model_config.get( - "conv_filters", get_filter_config(component.shape)), + "conv_filters": model_config["conv_filters"] + if "conv_filters" in model_config else + get_filter_config(obs_space.shape), "conv_activation": model_config.get("conv_activation"), "post_fcnet_hiddens": [], }