[RLlib] Fix AsyncReplayOptimizer bug where it swallows all good worker tasks … (#7111)

Sven Mika 2020-02-11 21:51:44 +01:00 committed by GitHub
parent fea54ab97f
commit 2a0e4d94aa

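In short, the optimizer used to fetch the sample counts of all completed worker tasks with a single batched get; if any one worker had died, that call raised and the results of every healthy worker were thrown away. The patch keeps the batched fast path, falls back to per-task gets on failure, keeps whatever is still retrievable, and re-raises the first error only after the healthy workers have been handed new sample tasks. A minimal standalone sketch of that pattern (hypothetical helper name, plain ray.get() in place of RLlib's internal ray_get_and_free()):

    import logging

    import ray
    from ray.exceptions import RayError

    logger = logging.getLogger(__name__)


    def get_healthy_results(object_ids):
        """Return ({index: result} for reachable tasks, first RayError or None)."""
        first_error = None
        try:
            # Fast path: one batched ray.get() over all completed tasks.
            return {i: v for i, v in enumerate(ray.get(object_ids))}, None
        except RayError:
            # Slow path: fetch results one by one so a single dead worker
            # does not swallow the results of all the healthy ones.
            results = {}
            for i, oid in enumerate(object_ids):
                try:
                    results[i] = ray.get(oid)
                except RayError as e:
                    logger.exception("Error in completed task: {}".format(e))
                    first_error = first_error if first_error is not None else e
            return results, first_error

In the optimizer itself, indices missing from the recovered mapping are simply skipped, every reachable worker gets a fresh sample_with_count request, and only then is the stored error raised.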
@@ -3,15 +3,16 @@
 https://arxiv.org/abs/1803.00933"""
 
 import collections
+import logging
+import numpy as np
 import os
 import random
-import time
-import threading
-
-import numpy as np
 from six.moves import queue
+import threading
+import time
 
 import ray
+from ray.exceptions import RayError
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
@@ -27,6 +28,8 @@ SAMPLE_QUEUE_DEPTH = 2
 REPLAY_QUEUE_DEPTH = 4
 LEARNER_QUEUE_MAX_SIZE = 16
 
+logger = logging.getLogger(__name__)
+
 
 class AsyncReplayOptimizer(PolicyOptimizer):
     """Main event loop of the Ape-X optimizer (async sampling with replay).
@@ -206,19 +209,42 @@ class AsyncReplayOptimizer(PolicyOptimizer):
 
         with self.timers["sample_processing"]:
             completed = list(self.sample_tasks.completed())
-            counts = ray_get_and_free([c[1][1] for c in completed])
-            for i, (ev, (sample_batch, count)) in enumerate(completed):
-                sample_timesteps += counts[i]
+            # First try a batched ray.get().
+            ray_error = None
+            try:
+                counts = {
+                    i: v
+                    for i, v in enumerate(
+                        ray_get_and_free([c[1][1] for c in completed]))
+                }
+            # If there are failed workers, try to recover the still good ones
+            # (via non-batched ray.get()) and store the first error (to raise
+            # later).
+            except RayError:
+                counts = {}
+                for i, c in enumerate(completed):
+                    try:
+                        counts[i] = ray_get_and_free(c[1][1])
+                    except RayError as e:
+                        logger.exception(
+                            "Error in completed task: {}".format(e))
+                        ray_error = ray_error if ray_error is not None else e
+
+            for i, (ev, (sample_batch, count)) in enumerate(completed):
+                # Skip failed tasks.
+                if i not in counts:
+                    continue
+                sample_timesteps += counts[i]
 
                 # Send the data to the replay buffer
                 random.choice(
                     self.replay_actors).add_batch.remote(sample_batch)
 
-                # Update weights if needed
+                # Update weights if needed.
                 self.steps_since_update[ev] += counts[i]
                 if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                     # Note that it's important to pull new weights once
-                    # updated to avoid excessive correlation between actors
+                    # updated to avoid excessive correlation between actors.
                     if weights is None or self.learner.weights_updated:
                         self.learner.weights_updated = False
                         with self.timers["put_weights"]:
@@ -228,9 +254,14 @@ class AsyncReplayOptimizer(PolicyOptimizer):
                     self.num_weight_syncs += 1
                     self.steps_since_update[ev] = 0
 
-                # Kick off another sample request
+                # Kick off another sample request.
                 self.sample_tasks.add(ev, ev.sample_with_count.remote())
 
+            # Now that all still good tasks have been kicked off again,
+            # we can throw the error.
+            if ray_error:
+                raise ray_error
+
         with self.timers["replay_processing"]:
             for ra, replay in self.replay_tasks.completed():
                 self.replay_tasks.add(ra, ra.replay.remote())
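Because the error now surfaces only after the surviving workers have fresh sample requests in flight, a driver loop can treat a failed iteration as recoverable. A hypothetical caller-side sketch (not part of this patch; RLlib's actual handling of worker failures lives in the Trainer and is governed by settings such as ignore_worker_failures):

    import logging

    from ray.exceptions import RayError

    logger = logging.getLogger(__name__)


    def run_optimizer(optimizer, num_iterations):
        for _ in range(num_iterations):
            try:
                optimizer.step()
            except RayError as e:
                # A sampler actor died; the remaining workers already have
                # fresh sample tasks in flight, so it is safe to move on.
                logger.warning("Ignoring failed sample task: %s", e)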