[rllib] Fix edge case in n-step calculation and non-apex replay prioritization (#2929)

* fix

* lint
Eric Liang 2018-09-28 15:22:33 -07:00 committed by GitHub
parent 4ffe1e3556
commit f1c55497ce
3 changed files with 23 additions and 25 deletions


@@ -453,21 +453,18 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
     The ith new_obs is also adjusted to point to the (i+n_step-1)'th new obs.
 
-    If the episode finishes, the reward will be truncated. After this rewrite,
-    all the arrays will be shortened by (n_step - 1).
+    At the end of the trajectory, n is truncated to fit in the traj length.
     """
-    for i in range(len(rewards) - n_step + 1):
-        if dones[i]:
-            continue  # episode end
+
+    assert not any(dones[:-1]), "Unexpected done in middle of trajectory"
+
+    traj_length = len(rewards)
+    for i in range(traj_length):
         for j in range(1, n_step):
-            new_obs[i] = new_obs[i + j]
-            rewards[i] += gamma**j * rewards[i + j]
-            if dones[i + j]:
-                break  # episode end
-    # truncate ends of the trajectory
-    new_len = len(obs) - n_step + 1
-    for arr in [obs, actions, rewards, new_obs, dones]:
-        del arr[new_len:]
+            if i + j < traj_length:
+                new_obs[i] = new_obs[i + j]
+                dones[i] = dones[i + j]
+                rewards[i] += gamma**j * rewards[i + j]
 
 
 def _postprocess_dqn(policy_graph, sample_batch):
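
For reference, the rewritten loop folds up to n_step - 1 future rewards into each index i, repoints new_obs[i] and dones[i] at the furthest step actually reached, and simply shortens the lookahead near the end of the trajectory instead of dropping the last n_step - 1 transitions as the old code did. A minimal standalone sketch of that behavior (the helper name demo_adjust_nstep and the sample values are illustrative only and mirror the updated unit test below; obs and actions are omitted because the new logic no longer touches them):

# Sketch of the in-place n-step rewrite applied by the patched adjust_nstep.
# Values mirror the updated test: gamma = 0.9, n_step = 3.
def demo_adjust_nstep(n_step, gamma, rewards, new_obs, dones):
    traj_length = len(rewards)
    assert not any(dones[:-1]), "Unexpected done in middle of trajectory"
    for i in range(traj_length):
        for j in range(1, n_step):
            if i + j < traj_length:
                # Fold the j-step-ahead reward into position i and repoint
                # new_obs/dones to the furthest step actually summed.
                new_obs[i] = new_obs[i + j]
                dones[i] = dones[i + j]
                rewards[i] += gamma ** j * rewards[i + j]


rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
new_obs = [2, 3, 4, 5, 6, 7, 8]
dones = [0, 0, 0, 0, 0, 0, 1]
demo_adjust_nstep(3, 0.9, rewards, new_obs, dones)
print(rewards)  # [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0]
print(new_obs)  # [4, 5, 6, 7, 8, 8, 8]
print(dones)    # [0, 0, 0, 0, 1, 1, 1]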


@@ -92,13 +92,13 @@ class SyncReplayOptimizer(PolicyOptimizer):
             for policy_id, s in batch.policy_batches.items():
                 for row in s.rows():
-                    if "weights" not in row:
-                        row["weights"] = np.ones_like(row["rewards"])
                     self.replay_buffers[policy_id].add(
                         pack_if_needed(row["obs"]),
-                        row["actions"], row["rewards"],
-                        pack_if_needed(row["new_obs"]), row["dones"],
-                        row["weights"])
+                        row["actions"],
+                        row["rewards"],
+                        pack_if_needed(row["new_obs"]),
+                        row["dones"],
+                        weight=None)
 
         if self.num_steps_sampled >= self.replay_starts:
             self._optimize()
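
The prioritization half of the fix is in how transitions enter the buffer: rather than stamping fresh rows with a flat weight of 1.0 (the old np.ones_like fallback) and passing that as the priority, the optimizer now passes weight=None. In standard prioritized experience replay, a missing weight means the buffer assigns its current maximum priority to the new transition, so fresh samples are strongly favored for replay until their priorities are updated from TD errors. A rough sketch of that convention (this toy buffer is an assumption for illustration, not the actual ray.rllib replay buffer):

# Illustrative-only sketch of the "weight=None => use max priority" convention
# from prioritized experience replay; not the rllib buffer implementation.
class TinyPrioritizedBuffer:
    def __init__(self, alpha=0.6):
        self._alpha = alpha
        self._storage = []
        self._priorities = []
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done, weight=None):
        if weight is None:
            # New transitions default to the highest priority seen so far,
            # keeping them likely to be sampled at least once.
            weight = self._max_priority
        self._storage.append((obs_t, action, reward, obs_tp1, done))
        self._priorities.append(weight ** self._alpha)


buf = TinyPrioritizedBuffer()
buf.add(obs_t=1, action="a", reward=10.0, obs_tp1=2, done=False)              # defaults to max priority
buf.add(obs_t=2, action="b", reward=0.0, obs_tp1=3, done=False, weight=0.5)   # explicit priority

For the non-Ape-X path the rows carry no precomputed priorities, so the old 1.0 fallback effectively flattened prioritization of fresh samples; deferring to the buffer default appears to be the non-Ape-X prioritization bug the commit title refers to.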


@@ -11,15 +11,16 @@ class DQNTest(unittest.TestCase):
     def testNStep(self):
         obs = [1, 2, 3, 4, 5, 6, 7]
         actions = ["a", "b", "a", "a", "a", "b", "a"]
-        rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100000.0]
+        rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
         new_obs = [2, 3, 4, 5, 6, 7, 8]
-        dones = [1, 0, 0, 0, 0, 1, 0]
+        dones = [0, 0, 0, 0, 0, 0, 1]
         adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones)
-        self.assertEqual(obs, [1, 2, 3, 4, 5])
-        self.assertEqual(actions, ["a", "b", "a", "a", "a"])
-        self.assertEqual(rewards, [10.0, 171.0, 271.0, 271.0, 190.0])
-        self.assertEqual(new_obs, [2, 5, 6, 7, 7])
-        self.assertEqual(dones, [1, 0, 0, 0, 0])
+        self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7])
+        self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"])
+        self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8])
+        self.assertEqual(dones, [0, 0, 0, 0, 1, 1, 1])
+        self.assertEqual(rewards,
+                         [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0])
 
 
 if __name__ == '__main__':
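
As a quick sanity check on the updated expectations (not part of the patch), each new reward is just a discounted sum of at most three of the original rewards with gamma = 0.9, truncated at the end of the trajectory:

# Hand-checking a few of the expected n-step returns from testNStep (gamma = 0.9).
gamma = 0.9
assert 10.0 + gamma * 0.0 + gamma ** 2 * 100.0 == 91.0    # index 0: full 3-step return
assert 0.0 + gamma * 100.0 + gamma ** 2 * 100.0 == 171.0  # index 1: full 3-step return
assert 100.0 + gamma * 100.0 == 190.0                     # index 5: only 2 steps remain before the end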