mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Fix dnc input shape (#15939)
Co-authored-by: Steven Morad <sm2558@cam.ac.uk>
This commit is contained in:
parent
64fdac83a7
commit
581d63e607
1 changed files with 10 additions and 10 deletions
|
@ -38,10 +38,11 @@ class DNCMemory(TorchModelV2, nn.Module):
|
|||
# feeding to the DNC
|
||||
"preprocessor": torch.nn.Sequential(
|
||||
torch.nn.Linear(64, 64), torch.nn.Tanh()),
|
||||
# The input and output sizes of the
|
||||
# preprocessor module
|
||||
# Input size to the preprocessor
|
||||
"preprocessor_input_size": 64,
|
||||
"preprocessor_output_size": 64
|
||||
# The output size of the preprocessor
|
||||
# and the input size of the dnc
|
||||
"preprocessor_output_size": 64,
|
||||
}
|
||||
|
||||
MEMORY_KEYS = [
|
||||
|
@ -77,18 +78,17 @@ class DNCMemory(TorchModelV2, nn.Module):
|
|||
self.preprocessor = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
|
||||
self.cfg["preprocessor"],
|
||||
torch.nn.Linear(self.cfg["preprocessor_output_size"],
|
||||
self.obs_dim))
|
||||
)
|
||||
|
||||
self.logit_branch = SlimFC(
|
||||
in_size=self.obs_dim,
|
||||
in_size=self.cfg["hidden_size"],
|
||||
out_size=self.num_outputs,
|
||||
activation_fn=None,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
)
|
||||
|
||||
self.value_branch = SlimFC(
|
||||
in_size=self.obs_dim,
|
||||
in_size=self.cfg["hidden_size"],
|
||||
out_size=1,
|
||||
activation_fn=None,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
|
@ -189,7 +189,7 @@ class DNCMemory(TorchModelV2, nn.Module):
|
|||
|
||||
def build_dnc(self, device_idx: Union[int, None]) -> None:
|
||||
self.dnc = self.cfg["dnc_model"](
|
||||
input_size=self.obs_dim,
|
||||
input_size=self.cfg["preprocessor_output_size"],
|
||||
hidden_size=self.cfg["hidden_size"],
|
||||
num_layers=self.cfg["num_layers"],
|
||||
num_hidden_layers=self.cfg["num_hidden_layers"],
|
||||
|
@ -226,8 +226,8 @@ class DNCMemory(TorchModelV2, nn.Module):
|
|||
hidden = self.unpack_state(state) # type: ignore
|
||||
|
||||
# Run thru preprocessor before DNC
|
||||
z = self.preprocessor(flat.reshape(B * T, flat.shape[-1]))
|
||||
z = z.reshape(B, T, flat.shape[-1])
|
||||
z = self.preprocessor(flat.reshape(B * T, self.obs_dim))
|
||||
z = z.reshape(B, T, self.cfg["preprocessor_output_size"])
|
||||
output, hidden = self.dnc(z, hidden)
|
||||
packed_state = self.pack_state(*hidden)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue