[RLlib] DQN (Rainbow): Fix torch noisy layer support and loss (#16716)

This commit is contained in:
Grzegorz Bartyzel 2021-07-13 22:48:06 +02:00 committed by GitHub
parent 1fd0eb805e
commit d553d4da6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View file

@@ -122,7 +122,13 @@ class DQNTorchModel(TorchModelV2, nn.Module):
# Value layer (nodes=1).
if self.dueling:
value_module.add_module("V", SlimFC(ins, 1, activation_fn=None))
if use_noisy:
value_module.add_module(
"V",
NoisyLayer(ins, self.num_atoms, sigma0, activation=None))
elif q_hiddens:
value_module.add_module(
"V", SlimFC(ins, self.num_atoms, activation_fn=None))
self.value_module = value_module
def get_q_value_distributions(self, model_out):

View file

@@ -62,7 +62,7 @@ class QLoss:
# Indispensable judgement which is missed in most implementations
# when b happens to be an integer, lb == ub, so pr_j(s', a*) will
# be discarded because (ub-b) == (b-lb) == 0.
floor_equal_ceil = (ub - lb < 0.5).float()
floor_equal_ceil = ((ub - lb) < 0.5).float()
# (batch_size, num_atoms, num_atoms)
l_project = F.one_hot(lb.long(), num_atoms)
@@ -79,7 +79,7 @@ class QLoss:
# Rainbow paper claims that using this cross entropy loss for
# priority is robust and insensitive to `prioritized_replay_alpha`
self.td_error = softmax_cross_entropy_with_logits(
logits=q_logits_t_selected, labels=m)
logits=q_logits_t_selected, labels=m.detach())
self.loss = torch.mean(self.td_error * importance_weights)
self.stats = {
# TODO: better Q stats for dist dqn