[RLlib] Cast fcnet_hiddens to list for DQN models (list vs tuple mismatch error) (#14308)

This commit is contained in:
Kai Fricke 2021-02-25 08:06:08 +01:00 committed by GitHub
parent adbdacae58
commit d9e5d5f47a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 3 deletions

View file

@ -159,7 +159,7 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
if config["hiddens"]: if config["hiddens"]:
# try to infer the last layer size, otherwise fall back to 256 # try to infer the last layer size, otherwise fall back to 256
num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1] num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
config["model"]["no_final_linear"] = True config["model"]["no_final_linear"] = True
else: else:
num_outputs = action_space.n num_outputs = action_space.n

View file

@ -150,7 +150,7 @@ def build_q_model_and_distribution(
if config["hiddens"]: if config["hiddens"]:
# try to infer the last layer size, otherwise fall back to 256 # try to infer the last layer size, otherwise fall back to 256
num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1] num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
config["model"]["no_final_linear"] = True config["model"]["no_final_linear"] = True
else: else:
num_outputs = action_space.n num_outputs = action_space.n

View file

@ -24,7 +24,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
model_config, name) model_config, name)
nn.Module.__init__(self) nn.Module.__init__(self)
hiddens = model_config.get("fcnet_hiddens", []) + \ hiddens = list(model_config.get("fcnet_hiddens", [])) + \
model_config.get("post_fcnet_hiddens", []) model_config.get("post_fcnet_hiddens", [])
activation = model_config.get("fcnet_activation") activation = model_config.get("fcnet_activation")
if not model_config.get("fcnet_hiddens", []): if not model_config.get("fcnet_hiddens", []):