mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
Turn replay into a circular queue. (#4667)
This commit is contained in:
parent
9d481cc2e6
commit
39a09fa457
1 changed files with 7 additions and 3 deletions
|
@ -83,6 +83,7 @@ class AggregationWorkerBase(object):
|
||||||
self.replay_proportion = replay_proportion
|
self.replay_proportion = replay_proportion
|
||||||
self.replay_buffer_num_slots = replay_buffer_num_slots
|
self.replay_buffer_num_slots = replay_buffer_num_slots
|
||||||
self.replay_batches = []
|
self.replay_batches = []
|
||||||
|
self.replay_index = 0
|
||||||
self.num_sent_since_broadcast = 0
|
self.num_sent_since_broadcast = 0
|
||||||
self.num_weight_syncs = 0
|
self.num_weight_syncs = 0
|
||||||
self.num_replayed = 0
|
self.num_replayed = 0
|
||||||
|
@ -115,9 +116,12 @@ class AggregationWorkerBase(object):
|
||||||
|
|
||||||
# Put in replay buffer if enabled
|
# Put in replay buffer if enabled
|
||||||
if self.replay_buffer_num_slots > 0:
|
if self.replay_buffer_num_slots > 0:
|
||||||
self.replay_batches.append(sample_batch)
|
if len(self.replay_batches) < self.replay_buffer_num_slots:
|
||||||
if len(self.replay_batches) > self.replay_buffer_num_slots:
|
self.replay_batches.append(sample_batch)
|
||||||
self.replay_batches.pop(0)
|
else:
|
||||||
|
self.replay_batches[self.replay_index] = sample_batch
|
||||||
|
self.replay_index += 1
|
||||||
|
self.replay_index %= self.replay_buffer_num_slots
|
||||||
|
|
||||||
ev.set_weights.remote(self.broadcasted_weights)
|
ev.set_weights.remote(self.broadcasted_weights)
|
||||||
self.num_weight_syncs += 1
|
self.num_weight_syncs += 1
|
||||||
|
|
Loading…
Add table
Reference in a new issue