import random

from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.sync_batch_replay_optimizer import \
    SyncBatchReplayOptimizer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override

class SyncBatchesReplayOptimizer(SyncBatchReplayOptimizer):
    """Variant of SyncBatchReplayOptimizer that runs multiple SGD passes.

    Where the parent class performs a single learn_on_batch() call per
    optimizer step, this subclass resamples the replay buffer and trains
    `num_gradient_descents` times.
    """

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 train_batch_size=32,
                 num_gradient_descents=10):
        super(SyncBatchesReplayOptimizer, self).__init__(
            workers, learning_starts, buffer_size, train_batch_size)
        self.num_sgds = num_gradient_descents

    @override(SyncBatchReplayOptimizer)
    def _optimize(self):
        for _ in range(self.num_sgds):
            # Draw random episodes from the replay buffer until the
            # requested train batch size is reached.
            samples = [random.choice(self.replay_buffer)]
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(random.choice(self.replay_buffer))
            samples = SampleBatch.concat_samples(samples)
            with self.grad_timer:
                info_dict = self.workers.local_worker().learn_on_batch(samples)
                for policy_id, info in info_dict.items():
                    self.learner_stats[policy_id] = get_learner_stats(info)
                self.grad_timer.push_units_processed(samples.count)
            self.num_steps_trained += samples.count
        # Stats from the last SGD pass are returned.
        return info_dict
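
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# Under the legacy (pre-1.0) RLlib optimizer API, a policy optimizer wraps a
# WorkerSet and is driven by repeated step() calls. `my_env_creator`,
# `MyPolicy`, and `config` below are hypothetical placeholders, and the exact
# WorkerSet constructor signature may differ across Ray versions:
#
#   from ray.rllib.evaluation.worker_set import WorkerSet
#
#   workers = WorkerSet(
#       env_creator=my_env_creator,   # hypothetical env factory
#       policy=MyPolicy,              # hypothetical policy class
#       trainer_config=config,
#       num_workers=2)
#   optimizer = SyncBatchesReplayOptimizer(
#       workers, learning_starts=1000, buffer_size=10000,
#       train_batch_size=32, num_gradient_descents=10)
#   while optimizer.num_steps_trained < 100000:
#       optimizer.step()  # collect samples; replay + SGD once warmed up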