ray/rllib/contrib/alpha_zero/optimizer/sync_batches_replay_optimizer.py

import random

from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.sync_batch_replay_optimizer import \
    SyncBatchReplayOptimizer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override


class SyncBatchesReplayOptimizer(SyncBatchReplayOptimizer):
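    """Variant of SyncBatchReplayOptimizer that runs multiple SGD passes.

    Each call to ``_optimize`` draws ``num_gradient_descents`` batches of at
    least ``train_batch_size`` timesteps from the replay buffer and performs
    one ``learn_on_batch`` update on the local worker per batch.
    """
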
    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 train_batch_size=32,
                 num_gradient_descents=10):
        super(SyncBatchesReplayOptimizer, self).__init__(
            workers, learning_starts, buffer_size, train_batch_size)
        self.num_sgds = num_gradient_descents

    @override(SyncBatchReplayOptimizer)
    def _optimize(self):
        for _ in range(self.num_sgds):
            # Draw random batches from the replay buffer until at least
            # train_batch_size timesteps have been collected.
            samples = [random.choice(self.replay_buffer)]
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(random.choice(self.replay_buffer))
            samples = SampleBatch.concat_samples(samples)
            # Run one SGD update on the local worker and record per-policy
            # learner stats.
            with self.grad_timer:
                info_dict = self.workers.local_worker().learn_on_batch(
                    samples)
                for policy_id, info in info_dict.items():
                    self.learner_stats[policy_id] = get_learner_stats(info)
                self.grad_timer.push_units_processed(samples.count)
            self.num_steps_trained += samples.count
        # Stats from the last SGD pass are returned to the caller.
        return info_dict
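

# Hypothetical usage sketch (not part of the original module): under the old
# policy-optimizer API, this class would typically be created by a trainer's
# make_policy_optimizer hook and driven via ``step()``. The ``workers``
# argument is an RLlib WorkerSet built elsewhere; the config keys shown here
# are illustrative placeholders rather than confirmed AlphaZero settings.
#
#     def make_policy_optimizer(workers, config):
#         return SyncBatchesReplayOptimizer(
#             workers,
#             learning_starts=config["learning_starts"],
#             buffer_size=config["buffer_size"],
#             train_batch_size=config["train_batch_size"],
#             num_gradient_descents=config["num_sgds"],
#         )
#
#     # optimizer.step() collects new sample batches into the replay buffer
#     # and, once ``learning_starts`` timesteps are buffered, invokes
#     # ``_optimize()`` as defined above.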