From 2a0e4d94aa162feadbd2d4d7ca6b0693d21aa12f Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Tue, 11 Feb 2020 21:51:44 +0100
Subject: [PATCH] [RLlib] Fix AsyncReplayOptimizer bug where it swallows all
 good worker tasks … (#7111)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 rllib/optimizers/async_replay_optimizer.py | 51 +++++++++++++++++-----
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/rllib/optimizers/async_replay_optimizer.py b/rllib/optimizers/async_replay_optimizer.py
index a1a0d00dc..59a5ade1e 100644
--- a/rllib/optimizers/async_replay_optimizer.py
+++ b/rllib/optimizers/async_replay_optimizer.py
@@ -3,15 +3,16 @@
 https://arxiv.org/abs/1803.00933"""
 
 import collections
+import logging
+import numpy as np
 import os
 import random
-import time
-import threading
-
-import numpy as np
 from six.moves import queue
+import threading
+import time
 
 import ray
+from ray.exceptions import RayError
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
@@ -27,6 +28,8 @@ SAMPLE_QUEUE_DEPTH = 2
 REPLAY_QUEUE_DEPTH = 4
 LEARNER_QUEUE_MAX_SIZE = 16
 
+logger = logging.getLogger(__name__)
+
 
 class AsyncReplayOptimizer(PolicyOptimizer):
     """Main event loop of the Ape-X optimizer (async sampling with replay).
@@ -206,19 +209,42 @@ class AsyncReplayOptimizer(PolicyOptimizer):
 
         with self.timers["sample_processing"]:
             completed = list(self.sample_tasks.completed())
-            counts = ray_get_and_free([c[1][1] for c in completed])
-            for i, (ev, (sample_batch, count)) in enumerate(completed):
-                sample_timesteps += counts[i]
+            # First try a batched ray.get().
+            ray_error = None
+            try:
+                counts = {
+                    i: v
+                    for i, v in enumerate(
+                        ray_get_and_free([c[1][1] for c in completed]))
+                }
+            # If there are failed workers, try to recover the still good ones
+            # (via non-batched ray.get()) and store the first error (to raise
+            # later).
+            except RayError:
+                counts = {}
+                for i, c in enumerate(completed):
+                    try:
+                        counts[i] = ray_get_and_free(c[1][1])
+                    except RayError as e:
+                        logger.exception(
+                            "Error in completed task: {}".format(e))
+                        ray_error = ray_error if ray_error is not None else e
+            for i, (ev, (sample_batch, count)) in enumerate(completed):
+                # Skip failed tasks.
+                if i not in counts:
+                    continue
+
+                sample_timesteps += counts[i]
 
                 # Send the data to the replay buffer
                 random.choice(
                     self.replay_actors).add_batch.remote(sample_batch)
 
-                # Update weights if needed
+                # Update weights if needed.
                 self.steps_since_update[ev] += counts[i]
                 if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                     # Note that it's important to pull new weights once
-                    # updated to avoid excessive correlation between actors
+                    # updated to avoid excessive correlation between actors.
                     if weights is None or self.learner.weights_updated:
                         self.learner.weights_updated = False
                         with self.timers["put_weights"]:
@@ -228,9 +254,14 @@ class AsyncReplayOptimizer(PolicyOptimizer):
                     self.num_weight_syncs += 1
                     self.steps_since_update[ev] = 0
 
-                # Kick off another sample request
+                # Kick off another sample request.
                 self.sample_tasks.add(ev, ev.sample_with_count.remote())
 
+            # Now that all still good tasks have been kicked off again,
+            # we can throw the error.
+            if ray_error:
+                raise ray_error
+
         with self.timers["replay_processing"]:
             for ra, replay in self.replay_tasks.completed():
                 self.replay_tasks.add(ra, ra.replay.remote())
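
For context on the pattern this patch applies: attempt one batched ray.get() over all completed sample tasks, and only when that batched call fails fall back to per-object gets, so results from healthy workers are kept, each failure is logged, and the first error is surfaced only after the good tasks have been handled. The sketch below is not part of the commit; it is a minimal, self-contained illustration of that pattern against the public ray.get() / ray.exceptions.RayError API, and the flaky task and collect() helper are made-up names for the example, not RLlib code.

    import random

    import ray
    from ray.exceptions import RayError


    @ray.remote
    def flaky(i):
        # Stand-in for a sampling worker; roughly one in four calls dies.
        if random.random() < 0.25:
            raise RuntimeError("worker {} died".format(i))
        return i * i


    def collect(refs):
        """Return ({index: result for successful refs}, first_error_or_None)."""
        try:
            # Fast path: a single batched get when no task failed.
            return dict(enumerate(ray.get(refs))), None
        except RayError:
            # Slow path: recover the still good results one by one and
            # remember the first error so the caller can re-raise it later.
            results, first_error = {}, None
            for i, ref in enumerate(refs):
                try:
                    results[i] = ray.get(ref)
                except RayError as e:
                    first_error = first_error or e
            return results, first_error


    if __name__ == "__main__":
        ray.init()
        results, err = collect([flaky.remote(i) for i in range(8)])
        print("recovered {} of 8 results".format(len(results)))
        if err:
            # Surface the failure only after the good results were consumed.
            raise err

In the patch itself the same shape appears around ray_get_and_free(), with the extra twist that the re-raise is deferred until every still healthy worker has been given a new sample task.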