mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Cast fcnet_hiddens to list for DQN models (list vs tuple mismatch error) (#14308)
This commit is contained in:
parent
adbdacae58
commit
d9e5d5f47a
3 changed files with 3 additions and 3 deletions
|
@ -159,7 +159,7 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
|
||||||
|
|
||||||
if config["hiddens"]:
|
if config["hiddens"]:
|
||||||
# try to infer the last layer size, otherwise fall back to 256
|
# try to infer the last layer size, otherwise fall back to 256
|
||||||
num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1]
|
num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
|
||||||
config["model"]["no_final_linear"] = True
|
config["model"]["no_final_linear"] = True
|
||||||
else:
|
else:
|
||||||
num_outputs = action_space.n
|
num_outputs = action_space.n
|
||||||
|
|
|
@ -150,7 +150,7 @@ def build_q_model_and_distribution(
|
||||||
|
|
||||||
if config["hiddens"]:
|
if config["hiddens"]:
|
||||||
# try to infer the last layer size, otherwise fall back to 256
|
# try to infer the last layer size, otherwise fall back to 256
|
||||||
num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1]
|
num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
|
||||||
config["model"]["no_final_linear"] = True
|
config["model"]["no_final_linear"] = True
|
||||||
else:
|
else:
|
||||||
num_outputs = action_space.n
|
num_outputs = action_space.n
|
||||||
|
|
|
@ -24,7 +24,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
|
||||||
model_config, name)
|
model_config, name)
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
hiddens = model_config.get("fcnet_hiddens", []) + \
|
hiddens = list(model_config.get("fcnet_hiddens", [])) + \
|
||||||
model_config.get("post_fcnet_hiddens", [])
|
model_config.get("post_fcnet_hiddens", [])
|
||||||
activation = model_config.get("fcnet_activation")
|
activation = model_config.get("fcnet_activation")
|
||||||
if not model_config.get("fcnet_hiddens", []):
|
if not model_config.get("fcnet_hiddens", []):
|
||||||
|
|
Loading…
Add table
Reference in a new issue