import random

import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat


class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 train_batch_size=32):
        """Initialize a batch replay optimizer.

        Arguments:
            workers (WorkerSet): set of all workers
            learning_starts (int): start learning after this number of
                timesteps have been collected
            buffer_size (int): max timesteps to keep in the replay buffer
            train_batch_size (int): number of timesteps to train on at once
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        # Broadcast the local worker's latest weights so that remote workers
        # sample with up-to-date policies.
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batches = ray.get(
                    [e.sample.remote() for e in self.workers.remote_workers()])
            else:
                batches = [self.workers.local_worker().sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({
                        DEFAULT_POLICY_ID: batch
                    }, batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                if batch.count > self.max_buffer_size:
                    raise ValueError(
                        "The size of a single sample batch exceeds the replay "
                        "buffer size ({} > {})".format(batch.count,
                                                       self.max_buffer_size))
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                # Evict the oldest batches once the buffer exceeds its
                # configured maximum size.
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            return self._optimize()
        else:
            return {}

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        # Replay whole batches (rather than individual timesteps) until at
        # least `train_batch_size` timesteps have been gathered, then train
        # the local worker on the concatenated batch.
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
        return info_dict
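
# A minimal usage sketch, assuming a WorkerSet has already been built by the
# trainer (its construction depends on the env creator, policy class, and
# trainer config, and is elided here); only the optimizer calls below mirror
# the API defined above.
#
#     workers = ...  # a WorkerSet provided by the trainer
#     optimizer = SyncBatchReplayOptimizer(
#         workers,
#         learning_starts=1000,
#         buffer_size=10000,
#         train_batch_size=32)
#     for _ in range(num_iterations):
#         fetches = optimizer.step()  # sample, buffer, and (once past
#                                     # `learning_starts`) replay-train
#         print(optimizer.stats())    # timing / throughput / learner stats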