[RLlib] Fix dnc input shape (#15939)

Co-authored-by: Steven Morad <sm2558@cam.ac.uk>
This commit is contained in:
Steven Morad 2021-05-20 19:06:02 -07:00 committed by GitHub
parent 64fdac83a7
commit 581d63e607
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -38,10 +38,11 @@ class DNCMemory(TorchModelV2, nn.Module):
# feeding to the DNC
"preprocessor": torch.nn.Sequential(
torch.nn.Linear(64, 64), torch.nn.Tanh()),
# The input and output sizes of the
# preprocessor module
# Input size to the preprocessor
"preprocessor_input_size": 64,
"preprocessor_output_size": 64
# The output size of the preprocessor,
# which is also the input size of the DNC
"preprocessor_output_size": 64,
}
MEMORY_KEYS = [
@ -77,18 +78,17 @@ class DNCMemory(TorchModelV2, nn.Module):
self.preprocessor = torch.nn.Sequential(
torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
self.cfg["preprocessor"],
torch.nn.Linear(self.cfg["preprocessor_output_size"],
self.obs_dim))
)
self.logit_branch = SlimFC(
in_size=self.obs_dim,
in_size=self.cfg["hidden_size"],
out_size=self.num_outputs,
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_,
)
self.value_branch = SlimFC(
in_size=self.obs_dim,
in_size=self.cfg["hidden_size"],
out_size=1,
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_,
@ -189,7 +189,7 @@ class DNCMemory(TorchModelV2, nn.Module):
def build_dnc(self, device_idx: Union[int, None]) -> None:
self.dnc = self.cfg["dnc_model"](
input_size=self.obs_dim,
input_size=self.cfg["preprocessor_output_size"],
hidden_size=self.cfg["hidden_size"],
num_layers=self.cfg["num_layers"],
num_hidden_layers=self.cfg["num_hidden_layers"],
@ -226,8 +226,8 @@ class DNCMemory(TorchModelV2, nn.Module):
hidden = self.unpack_state(state) # type: ignore
# Run through the preprocessor before the DNC
z = self.preprocessor(flat.reshape(B * T, flat.shape[-1]))
z = z.reshape(B, T, flat.shape[-1])
z = self.preprocessor(flat.reshape(B * T, self.obs_dim))
z = z.reshape(B, T, self.cfg["preprocessor_output_size"])
output, hidden = self.dnc(z, hidden)
packed_state = self.pack_state(*hidden)