mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] DQN (Rainbow): Fix torch noisy layer support and loss (#16716)
This commit is contained in:
parent
1fd0eb805e
commit
d553d4da6c
2 changed files with 9 additions and 3 deletions
|
@@ -122,7 +122,13 @@ class DQNTorchModel(TorchModelV2, nn.Module):
             # Value layer (nodes=1).
         if self.dueling:
-            value_module.add_module("V", SlimFC(ins, 1, activation_fn=None))
+            if use_noisy:
+                value_module.add_module(
+                    "V",
+                    NoisyLayer(ins, self.num_atoms, sigma0, activation=None))
+            elif q_hiddens:
+                value_module.add_module(
+                    "V", SlimFC(ins, self.num_atoms, activation_fn=None))
         self.value_module = value_module

     def get_q_value_distributions(self, model_out):
@@ -62,7 +62,7 @@ class QLoss:
         # Indispensable judgement which is missed in most implementations
         # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
         # be discarded because (ub-b) == (b-lb) == 0.
-        floor_equal_ceil = (ub - lb < 0.5).float()
+        floor_equal_ceil = ((ub - lb) < 0.5).float()

         # (batch_size, num_atoms, num_atoms)
         l_project = F.one_hot(lb.long(), num_atoms)
@@ -79,7 +79,7 @@ class QLoss:
         # Rainbow paper claims that using this cross entropy loss for
         # priority is robust and insensitive to `prioritized_replay_alpha`
         self.td_error = softmax_cross_entropy_with_logits(
-            logits=q_logits_t_selected, labels=m)
+            logits=q_logits_t_selected, labels=m.detach())
         self.loss = torch.mean(self.td_error * importance_weights)
         self.stats = {
             # TODO: better Q stats for dist dqn
Loading…
Add table
Reference in a new issue