mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] Fix edge case in n-step calculation and non-apex replay prioritization (#2929)
* fix
* lint

This commit is contained in:
parent 4ffe1e3556
commit f1c55497ce
3 changed files with 23 additions and 25 deletions
@@ -453,21 +453,18 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):

     The ith new_obs is also adjusted to point to the (i+n_step-1)'th new obs.

-    If the episode finishes, the reward will be truncated. After this rewrite,
-    all the arrays will be shortened by (n_step - 1).
+    At the end of the trajectory, n is truncated to fit in the traj length.
     """
-    for i in range(len(rewards) - n_step + 1):
-        if dones[i]:
-            continue  # episode end
+
+    assert not any(dones[:-1]), "Unexpected done in middle of trajectory"
+
+    traj_length = len(rewards)
+    for i in range(traj_length):
         for j in range(1, n_step):
-            new_obs[i] = new_obs[i + j]
-            rewards[i] += gamma**j * rewards[i + j]
-            if dones[i + j]:
-                break  # episode end
-    # truncate ends of the trajectory
-    new_len = len(obs) - n_step + 1
-    for arr in [obs, actions, rewards, new_obs, dones]:
-        del arr[new_len:]
+            if i + j < traj_length:
+                new_obs[i] = new_obs[i + j]
+                dones[i] = dones[i + j]
+                rewards[i] += gamma**j * rewards[i + j]


 def _postprocess_dqn(policy_graph, sample_batch):
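
For readers scanning the diff, here is a minimal standalone sketch of the rewritten n-step adjustment together with a tiny worked example. The helper name nstep_sketch and the sample values are illustrative only; they are not part of the commit.

# Sketch of the new in-place n-step folding (illustrative, mirrors the added lines above).
# Instead of shortening the arrays by (n_step - 1), the horizon is simply truncated at the
# end of the trajectory, so all arrays keep their original length.
def nstep_sketch(n_step, gamma, rewards, new_obs, dones):
    assert not any(dones[:-1]), "Unexpected done in middle of trajectory"
    traj_length = len(rewards)
    for i in range(traj_length):
        for j in range(1, n_step):
            if i + j < traj_length:
                new_obs[i] = new_obs[i + j]              # point to the obs up to n-1 steps ahead
                dones[i] = dones[i + j]                  # propagate the terminal flag backwards
                rewards[i] += gamma**j * rewards[i + j]  # fold in the discounted future reward


if __name__ == "__main__":
    rewards, new_obs, dones = [10.0, 0.0, 100.0], [2, 3, 4], [0, 0, 1]
    nstep_sketch(3, 0.9, rewards, new_obs, dones)
    # rewards[0] = 10 + 0.9 * 0 + 0.81 * 100 = 91.0
    print(rewards, new_obs, dones)  # [91.0, 90.0, 100.0] [4, 4, 4] [1, 1, 1]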

@@ -92,13 +92,13 @@ class SyncReplayOptimizer(PolicyOptimizer):

             for policy_id, s in batch.policy_batches.items():
                 for row in s.rows():
-                    if "weights" not in row:
-                        row["weights"] = np.ones_like(row["rewards"])
                     self.replay_buffers[policy_id].add(
                         pack_if_needed(row["obs"]),
-                        row["actions"], row["rewards"],
-                        pack_if_needed(row["new_obs"]), row["dones"],
-                        row["weights"])
+                        row["actions"],
+                        row["rewards"],
+                        pack_if_needed(row["new_obs"]),
+                        row["dones"],
+                        weight=None)

         if self.num_steps_sampled >= self.replay_starts:
             self._optimize()
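
This hunk stops synthesizing a weights column of ones and passes weight=None to the replay buffer instead, presumably so the buffer can pick its own default priority for new samples. Below is a minimal toy sketch of that general idea; it is not RLlib's actual PrioritizedReplayBuffer, and the class and attribute names are made up for illustration.

import random


class ToyPrioritizedBuffer:
    """Toy prioritized buffer: weight=None falls back to the running max priority."""

    def __init__(self, alpha=0.6):
        self.alpha = alpha        # how strongly priorities skew sampling
        self.storage = []         # list of (item, priority ** alpha) pairs
        self.max_priority = 1.0   # largest priority seen so far

    def add(self, item, weight=None):
        if weight is None:        # caller defers the choice of priority to the buffer
            weight = self.max_priority
        self.max_priority = max(self.max_priority, weight)
        self.storage.append((item, weight ** self.alpha))

    def sample(self, k):
        items, prios = zip(*self.storage)
        return list(random.choices(items, weights=prios, k=k))


buf = ToyPrioritizedBuffer()
for t in range(5):
    buf.add({"obs": t, "reward": float(t)})  # weight left as None on purpose
print(buf.sample(3))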

@@ -11,15 +11,16 @@ class DQNTest(unittest.TestCase):
     def testNStep(self):
         obs = [1, 2, 3, 4, 5, 6, 7]
         actions = ["a", "b", "a", "a", "a", "b", "a"]
-        rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100000.0]
+        rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
         new_obs = [2, 3, 4, 5, 6, 7, 8]
-        dones = [1, 0, 0, 0, 0, 1, 0]
+        dones = [0, 0, 0, 0, 0, 0, 1]
         adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones)
-        self.assertEqual(obs, [1, 2, 3, 4, 5])
-        self.assertEqual(actions, ["a", "b", "a", "a", "a"])
-        self.assertEqual(rewards, [10.0, 171.0, 271.0, 271.0, 190.0])
-        self.assertEqual(new_obs, [2, 5, 6, 7, 7])
-        self.assertEqual(dones, [1, 0, 0, 0, 0])
+        self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7])
+        self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"])
+        self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8])
+        self.assertEqual(dones, [0, 0, 0, 0, 1, 1, 1])
+        self.assertEqual(rewards,
+                         [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0])


 if __name__ == '__main__':
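
The updated expected rewards in testNStep can be re-derived by hand. The quick check below is not part of the test file; it just replays the same discounted sums with gamma = 0.9 and n_step = 3 over the full trajectory.

# Recompute the asserted n-step rewards outside the test.
gamma, n_step = 0.9, 3
rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
expected = []
for i in range(len(rewards)):
    r = rewards[i]
    for j in range(1, n_step):
        if i + j < len(rewards):
            r += gamma**j * rewards[i + j]  # e.g. i=0: 10 + 0.9*0 + 0.81*100 = 91.0
    expected.append(r)
print(expected)  # [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0]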