Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[RLlib] Fix AsyncReplayOptimizer bug where it swallows all good worker tasks … (#7111)
This commit is contained in:
parent fea54ab97f
commit 2a0e4d94aa
1 changed file with 41 additions and 10 deletions
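As context for the diff below: the optimizer previously fetched the results of all completed sample tasks with one batched get, so a single failed worker (for example, one on a lost node) raised an error that also discarded every healthy worker's result. The snippet that follows is a minimal standalone sketch of that failure mode and of the fallback the patch introduces; it is not RLlib code. It uses plain ray.get() in place of RLlib's ray_get_and_free() helper, and the remote task only mimics the real sample_with_count() worker call.

# Minimal standalone sketch of the bug and the fix (illustrative names only).
import ray
from ray.exceptions import RayError

ray.init(ignore_reinit_error=True)

@ray.remote
def sample_with_count(worker_id):
    if worker_id == 2:
        raise RuntimeError("worker {} died".format(worker_id))
    return worker_id * 10  # stand-in for a (sample_batch, count) result

refs = [sample_with_count.remote(i) for i in range(4)]

# Old behavior: one batched get -- the first failed task raises, and the
# results of every healthy worker are swallowed along with it.
try:
    counts = ray.get(refs)
except RayError as e:
    print("batched get lost all results:", e)

# New behavior: fall back to per-task gets, keep the good results, remember
# the first error, and only surface it after the good work has been handled.
first_error = None
counts = {}
for i, ref in enumerate(refs):
    try:
        counts[i] = ray.get(ref)
    except RayError as e:
        first_error = first_error if first_error is not None else e

print("recovered results:", counts)  # {0: 0, 1: 10, 3: 30}
if first_error is not None:
    print("first error, re-raised later by the real optimizer:", first_error)

The diff below applies the same idea inside the optimizer's sampling step, keeping the per-task counts in a dict so failed tasks can be skipped.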
@@ -3,15 +3,16 @@
 https://arxiv.org/abs/1803.00933"""
 import collections
 import logging
+import numpy as np
 import os
 import random
+import time
+import threading
 
-import numpy as np
 from six.moves import queue
-import threading
-import time
 
 import ray
+from ray.exceptions import RayError
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
@@ -27,6 +28,8 @@ SAMPLE_QUEUE_DEPTH = 2
 REPLAY_QUEUE_DEPTH = 4
 LEARNER_QUEUE_MAX_SIZE = 16
 
+logger = logging.getLogger(__name__)
+
 
 class AsyncReplayOptimizer(PolicyOptimizer):
     """Main event loop of the Ape-X optimizer (async sampling with replay).
@@ -206,19 +209,42 @@ class AsyncReplayOptimizer(PolicyOptimizer):
         with self.timers["sample_processing"]:
             completed = list(self.sample_tasks.completed())
-            counts = ray_get_and_free([c[1][1] for c in completed])
-            for i, (ev, (sample_batch, count)) in enumerate(completed):
-                sample_timesteps += counts[i]
+            # First try a batched ray.get().
+            ray_error = None
+            try:
+                counts = {
+                    i: v
+                    for i, v in enumerate(
+                        ray_get_and_free([c[1][1] for c in completed]))
+                }
+            # If there are failed workers, try to recover the still good ones
+            # (via non-batched ray.get()) and store the first error (to raise
+            # later).
+            except RayError:
+                counts = {}
+                for i, c in enumerate(completed):
+                    try:
+                        counts[i] = ray_get_and_free(c[1][1])
+                    except RayError as e:
+                        logger.exception(
+                            "Error in completed task: {}".format(e))
+                        ray_error = ray_error if ray_error is not None else e
+
+            for i, (ev, (sample_batch, count)) in enumerate(completed):
+                # Skip failed tasks.
+                if i not in counts:
+                    continue
+
+                sample_timesteps += counts[i]
                 # Send the data to the replay buffer
                 random.choice(
                     self.replay_actors).add_batch.remote(sample_batch)
 
-                # Update weights if needed
+                # Update weights if needed.
                 self.steps_since_update[ev] += counts[i]
                 if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                     # Note that it's important to pull new weights once
-                    # updated to avoid excessive correlation between actors
+                    # updated to avoid excessive correlation between actors.
                     if weights is None or self.learner.weights_updated:
                         self.learner.weights_updated = False
                         with self.timers["put_weights"]:
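One detail worth noting in the hunk above: counts changes from a list (the result of the single batched call) into a dict keyed by the index of the completed task. With a plain list, a failed task would leave a hole and shift every later index, so the processing loop could no longer line results up with the entries of completed; with the dict, failed indices are simply absent and the loop skips them. A tiny illustration with made-up values, not RLlib code:

# Why the fallback builds `counts` as a dict (made-up values, illustrative only).
completed = ["task_a", "task_b", "task_c"]  # stand-ins for completed tasks
counts = {0: 50, 2: 50}                     # task 1 failed its per-task get

for i, task in enumerate(completed):
    # Skip failed tasks, exactly as the patched loop does.
    if i not in counts:
        continue
    print(task, "->", counts[i])  # task_a -> 50, task_c -> 50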
@@ -228,9 +254,14 @@ class AsyncReplayOptimizer(PolicyOptimizer):
                     self.num_weight_syncs += 1
                     self.steps_since_update[ev] = 0
 
-                # Kick off another sample request
+                # Kick off another sample request.
                 self.sample_tasks.add(ev, ev.sample_with_count.remote())
 
+            # Now that all still good tasks have been kicked off again,
+            # we can throw the error.
+            if ray_error:
+                raise ray_error
+
         with self.timers["replay_processing"]:
             for ra, replay in self.replay_tasks.completed():
                 self.replay_tasks.add(ra, ra.replay.remote())
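The ordering in this last hunk is the point of the comment it adds: the error captured during the fallback is re-raised only after every healthy worker has been handed its next sample_with_count task, so surfacing the failure does not also leave the surviving workers idle. A rough illustration of that "reschedule first, raise later" ordering, with made-up names rather than RLlib code (pending stands in for self.sample_tasks):

# Illustration of "reschedule healthy workers first, raise the error last".
pending = []

def drain(workers):
    first_error = None
    for w in workers:
        if w.endswith("_failed"):
            # Remember the first failure, but keep the loop going.
            first_error = first_error or RuntimeError("lost worker: " + w)
            continue
        pending.append(w)  # healthy worker gets its next sample task
    if first_error:
        raise first_error  # surfaced only after everyone else is rescheduled

try:
    drain(["w0", "w1_failed", "w2", "w3"])
except RuntimeError as e:
    print(pending)  # ['w0', 'w2', 'w3'] -- no healthy worker was left idle
    print(e)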