# ray/rllib/optimizers/aso_aggregator.py

"""Helper class for AsyncSamplesOptimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import random

import ray
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free


class Aggregator(object):
    """An aggregator collects and processes samples from workers.

    This class is used to abstract away the strategy for sample collection.
    For example, you may want to use a tree of actors to collect samples. The
    use of multiple actors can be necessary to offload expensive work such
    as concatenating and decompressing sample batches.

    Attributes:
        local_worker: local RolloutWorker copy
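
    Examples:
        >>> # Illustrative driving loop only; in RLlib this loop is run by
        >>> # the optimizer (AsyncSamplesOptimizer). `agg` is a concrete
        >>> # Aggregator and `process` is a hypothetical batch consumer.
        >>> for train_batch in agg.iter_train_batches():
        ...     process(train_batch)
        ...     if agg.should_broadcast():
        ...         agg.broadcast_new_weights()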
"""

    def iter_train_batches(self):
        """Returns a generator over batches ready to learn on.

        Iterating through this generator will also send out weight updates to
        remote workers as needed.

        This call may block until results are available.
        """
        raise NotImplementedError

    def broadcast_new_weights(self):
        """Broadcast a new set of weights from the local worker."""
        raise NotImplementedError

    def should_broadcast(self):
        """Returns whether broadcast() should be called to update weights."""
        raise NotImplementedError

    def stats(self):
        """Returns runtime statistics for debugging."""
        raise NotImplementedError

    def reset(self, remote_workers):
        """Called to change the set of remote workers being used."""
        raise NotImplementedError


class AggregationWorkerBase(object):
    """Aggregators should extend from this class."""

    def __init__(self, initial_weights_obj_id, remote_workers,
                 max_sample_requests_in_flight_per_worker, replay_proportion,
                 replay_buffer_num_slots, train_batch_size, sample_batch_size):
        """Initialize an aggregator.

        Arguments:
            initial_weights_obj_id (ObjectID): initial worker weights
            remote_workers (list): set of remote workers assigned to this agg
            max_sample_requests_in_flight_per_worker (int): max queue size
                per worker
            replay_proportion (float): ratio of replay to sampled outputs
            replay_buffer_num_slots (int): max number of sample batches to
                store in the replay buffer
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
        """
        self.broadcasted_weights = initial_weights_obj_id
        self.remote_workers = remote_workers
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size

        if replay_proportion:
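            # Sanity check: the buffer must hold strictly more than one train
            # batch worth of samples. E.g. with sample_batch_size=50 and
            # train_batch_size=500, replay_buffer_num_slots must be >= 11.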
            if (replay_buffer_num_slots * sample_batch_size <=
                    train_batch_size):
                raise ValueError(
                    "Replay buffer size is too small to produce a train "
                    "batch, please increase replay_buffer_num_slots.",
                    replay_buffer_num_slots, sample_batch_size,
                    train_batch_size)

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        for ev in self.remote_workers:
            ev.set_weights.remote(self.broadcasted_weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []
        self.replay_index = 0

        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999):
        """Iterate over train batches.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning too
                much data at once.
        """
        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch(
                    blocking_wait=True, max_yield=max_yield)):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                if len(self.batch_buffer) == 1:
                    # make a defensive copy to avoid sharing plasma memory
                    # across multiple threads
                    train_batch = self.batch_buffer[0].copy()
                else:
                    train_batch = self.batch_buffer[0].concat_samples(
                        self.batch_buffer)
                yield train_batch
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            # Put in replay buffer if enabled
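            # (a fixed-size ring: once full, the oldest slot is overwritten)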
            if self.replay_buffer_num_slots > 0:
                if len(self.replay_batches) < self.replay_buffer_num_slots:
                    self.replay_batches.append(sample_batch)
                else:
                    self.replay_batches[self.replay_index] = sample_batch
                    self.replay_index += 1
                    self.replay_index %= self.replay_buffer_num_slots

            ev.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

    @override(Aggregator)
    def stats(self):
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
        }

    @override(Aggregator)
    def reset(self, remote_workers):
        self.sample_tasks.reset_workers(remote_workers)

    def _augment_with_replay(self, sample_futures):
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray_get_and_free(sample_batch)
            yield ev, sample_batch

            if can_replay():
                f = self.replay_proportion
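                # This replays `replay_proportion` batches per sampled batch
                # in expectation, e.g. f=2.5 always yields two replay batches
                # plus a third with probability 0.5.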
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch


class SimpleAggregator(AggregationWorkerBase, Aggregator):
    """Simple single-threaded implementation of an Aggregator."""

    def __init__(self,
                 workers,
                 max_sample_requests_in_flight_per_worker=2,
                 replay_proportion=0.0,
                 replay_buffer_num_slots=0,
                 train_batch_size=500,
                 sample_batch_size=50,
                 broadcast_interval=5):
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.broadcast_interval = broadcast_interval
        self.broadcast_new_weights()
        AggregationWorkerBase.__init__(
            self, self.broadcasted_weights, self.workers.remote_workers(),
            max_sample_requests_in_flight_per_worker, replay_proportion,
            replay_buffer_num_slots, train_batch_size, sample_batch_size)

    @override(Aggregator)
    def broadcast_new_weights(self):
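        # ray.put stores the weights once in the object store, so all remote
        # workers fetch the same copy via the returned object ID.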
        self.broadcasted_weights = ray.put(self.local_worker.get_weights())
        self.num_sent_since_broadcast = 0

    @override(Aggregator)
    def should_broadcast(self):
        return self.num_sent_since_broadcast >= self.broadcast_interval
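

# Example (illustrative sketch): construct a SimpleAggregator from an
# existing WorkerSet and drain ready train batches. `workers` and
# `handle_batch` are hypothetical stand-ins for the caller's worker set
# and batch consumer.
#
#     agg = SimpleAggregator(workers, broadcast_interval=5)
#     for train_batch in agg.iter_train_batches():
#         handle_batch(train_batch)
#         if agg.should_broadcast():
#             agg.broadcast_new_weights()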