mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
178 lines
6.8 KiB
Python
178 lines
6.8 KiB
Python
"""Helper class for AsyncSamplesOptimizer."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
import ray
|
|
from ray.rllib.utils.actors import TaskPool, create_colocated
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
|
|
AggregationWorkerBase
|
|
from ray.rllib.utils.memory import ray_get_and_free
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TreeAggregator(Aggregator):
    """A hierarchical experiences aggregator.

    The given set of remote workers is divided into subsets and assigned to
    one of several aggregation workers. These aggregation workers collate
    experiences into batches of size `train_batch_size` and we collect them
    in this class when `iter_train_batches` is called.
    """

    def __init__(self,
                 workers,
                 num_aggregation_workers,
                 max_sample_requests_in_flight_per_worker=2,
                 replay_proportion=0.0,
                 replay_buffer_num_slots=0,
                 train_batch_size=500,
                 sample_batch_size=50,
                 broadcast_interval=5):
        """Initialize a tree aggregator.

        Arguments:
            workers (WorkerSet): set of all workers
            num_aggregation_workers (int): number of intermediate actors to
                use for data aggregation
            max_sample_requests_in_flight_per_worker (int): max queue size per
                worker
            replay_proportion (float): ratio of replay to sampled outputs
            replay_buffer_num_slots (int): max number of sample batches to
                store in the replay buffer
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            broadcast_interval (int): max number of workers to send the
                same set of weights to
        """
        self.workers = workers
        self.num_aggregation_workers = num_aggregation_workers
        self.max_sample_requests_in_flight_per_worker = \
            max_sample_requests_in_flight_per_worker
        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size
        self.broadcast_interval = broadcast_interval
        # Put the current local-worker weights into the object store once so
        # all aggregation workers can share the same object id.
        self.broadcasted_weights = ray.put(
            workers.local_worker().get_weights())
        # Counters reported via stats() / used by should_broadcast().
        self.num_batches_processed = 0
        self.num_broadcasts = 0
        self.num_sent_since_broadcast = 0
        # Set True only after init() wires up the aggregation workers.
        self.initialized = False

    def init(self, aggregators):
        """Deferred init so that we can pass in previously created workers."""

        assert len(aggregators) == self.num_aggregation_workers, aggregators
        # Each aggregation worker needs at least one evaluation worker
        # assigned to it, so the aggregator count cannot exceed the number
        # of remote evaluation workers.
        if len(self.workers.remote_workers()) < self.num_aggregation_workers:
            raise ValueError(
                "The number of aggregation workers should not exceed the "
                "number of total evaluation workers ({} vs {})".format(
                    self.num_aggregation_workers,
                    len(self.workers.remote_workers())))

        # Round-robin assignment of evaluation workers to aggregators.
        assigned_workers = collections.defaultdict(list)
        for i, ev in enumerate(self.workers.remote_workers()):
            assigned_workers[i % self.num_aggregation_workers].append(ev)

        self.aggregators = aggregators
        for i, agg in enumerate(self.aggregators):
            # Remote deferred init: hand each aggregator its worker subset
            # and the shared configuration.
            agg.init.remote(self.broadcasted_weights, assigned_workers[i],
                            self.max_sample_requests_in_flight_per_worker,
                            self.replay_proportion,
                            self.replay_buffer_num_slots,
                            self.train_batch_size, self.sample_batch_size)

        # Kick off one in-flight get_train_batches task per aggregator;
        # iter_train_batches() re-queues a new task as each one completes.
        self.agg_tasks = TaskPool()
        for agg in self.aggregators:
            agg.set_weights.remote(self.broadcasted_weights)
            self.agg_tasks.add(agg, agg.get_train_batches.remote())

        self.initialized = True

    @override(Aggregator)
    def iter_train_batches(self):
        assert self.initialized, "Must call init() before using this class."
        for agg, batches in self.agg_tasks.completed_prefetch():
            # ray_get_and_free fetches the result and releases the object
            # store memory for the batch list.
            for b in ray_get_and_free(batches):
                self.num_sent_since_broadcast += 1
                yield b
            # Refresh this aggregator's weights and re-queue its next
            # collection task so the pipeline stays full.
            agg.set_weights.remote(self.broadcasted_weights)
            self.agg_tasks.add(agg, agg.get_train_batches.remote())
            self.num_batches_processed += 1

    @override(Aggregator)
    def broadcast_new_weights(self):
        # Re-put the latest weights so subsequent set_weights.remote calls
        # in iter_train_batches() pick up the new object id.
        self.broadcasted_weights = ray.put(
            self.workers.local_worker().get_weights())
        self.num_sent_since_broadcast = 0
        self.num_broadcasts += 1

    @override(Aggregator)
    def should_broadcast(self):
        # Broadcast once enough batches were sent out under the old weights.
        return self.num_sent_since_broadcast >= self.broadcast_interval

    @override(Aggregator)
    def stats(self):
        return {
            "num_broadcasts": self.num_broadcasts,
            "num_batches_processed": self.num_batches_processed,
        }

    @override(Aggregator)
    def reset(self, remote_workers):
        # Re-sharding workers across aggregators is not supported.
        raise NotImplementedError("changing number of remote workers")

    @staticmethod
    def precreate_aggregators(n):
        # Create n AggregationWorker actors colocated on the same node,
        # to be passed into init() later.
        return create_colocated(AggregationWorker, [], n)
|
|
|
|
|
|
@ray.remote(num_cpus=1)
class AggregationWorker(AggregationWorkerBase):
    """Remote actor that aggregates experiences from a subset of workers.

    Created (possibly pre-created and colocated) by TreeAggregator, which
    later calls init() to assign this actor its sub-workers.
    """

    def __init__(self):
        # Real setup happens in init(); this allows precreating actors
        # before the worker assignment is known.
        self.initialized = False

    def init(self, initial_weights_obj_id, remote_workers,
             max_sample_requests_in_flight_per_worker, replay_proportion,
             replay_buffer_num_slots, train_batch_size, sample_batch_size):
        """Deferred init that assigns sub-workers to this aggregator."""

        logger.info("Assigned workers {} to aggregation worker {}".format(
            remote_workers, self))
        assert remote_workers
        # Explicit base-class init (rather than in __init__) because the
        # constructor args are only available at this point.
        AggregationWorkerBase.__init__(
            self, initial_weights_obj_id, remote_workers,
            max_sample_requests_in_flight_per_worker, replay_proportion,
            replay_buffer_num_slots, train_batch_size, sample_batch_size)
        self.initialized = True

    def set_weights(self, weights):
        # Store the broadcast weights object id for the sub-workers.
        self.broadcasted_weights = weights

    def get_train_batches(self):
        """Collect up to 5 train batches, blocking until at least one is ready."""
        assert self.initialized, "Must call init() before using this class."
        start = time.time()
        result = []
        # max_yield=5 bounds how many batches we return per call so the
        # caller's pipeline stays responsive.
        for batch in self.iter_train_batches(max_yield=5):
            result.append(batch)
        # Poll until at least one batch is available; this remote call is
        # expected to block rather than return empty.
        while not result:
            time.sleep(0.01)
            for batch in self.iter_train_batches(max_yield=5):
                result.append(batch)
        logger.debug("Returning {} train batches, {}s".format(
            len(result),
            time.time() - start))
        return result

    def get_host(self):
        # Hostname of the node this actor runs on (used for colocation checks).
        return os.uname()[1]
|