2019-03-31 12:25:52 -07:00
|
|
|
"""Helper class for AsyncSamplesOptimizer."""
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
|
|
class MinibatchBuffer(object):
    """Ring buffer that replays recent data batches for minibatch SGD.

    Every item pulled from the input queue occupies one slot of the ring and
    is handed out once per trip around the ring until its pass budget is
    used up, after which the slot is cleared and refilled on its next visit.
    This is for use with AsyncSamplesOptimizer.
    """

    def __init__(self, inqueue, size, timeout, num_passes, init_num_passes=1):
        """Set up an empty ring buffer.

        Arguments:
            inqueue: Source of data items; it must provide
                ``get(timeout=...)`` (for example a ``queue.Queue``).
            size: Number of slots in the ring, i.e. the max number of data
                items held at once.
            timeout: Timeout passed to ``inqueue.get`` whenever a slot is
                refilled.
            num_passes: Ceiling on the pass budget given to newly loaded
                items.
            init_num_passes: Pass budget given to the first loaded item. The
                budget grows by one per loaded item while it is below
                ``num_passes``; it is never raised past that ceiling, and is
                left as-is if it starts above it.
        """
        # Input source and how long to block when pulling from it.
        self.inqueue = inqueue
        self.timeout = timeout
        self.size = size

        # Pass-budget bookkeeping: max_ttl is the ceiling, cur_max_ttl is the
        # budget granted to the next item loaded from the queue.
        self.max_ttl = num_passes
        self.cur_max_ttl = init_num_passes

        # Per-slot payloads and remaining passes. A TTL <= 0 means the slot
        # has to be refilled before it can be served again.
        self.buffers = [None] * size
        self.ttl = [0] * size

        # Slot that the next get() call will serve.
        self.idx = 0

    def _refill(self, slot):
        """Load a fresh item from the queue into ``slot`` with the current budget."""
        # If inqueue.get raises (e.g. on timeout), no state has been touched yet.
        self.buffers[slot] = self.inqueue.get(timeout=self.timeout)
        self.ttl[slot] = self.cur_max_ttl
        # Ramp the budget up by one per loaded item until it reaches the ceiling.
        if self.cur_max_ttl < self.max_ttl:
            self.cur_max_ttl += 1

    def get(self):
        """Serve the item in the current slot and advance around the ring.

        An exhausted slot (TTL <= 0) is first refilled from ``inqueue``; any
        exception raised by ``inqueue.get`` propagates to the caller.

        Returns:
            buf: Data item saved from inqueue.
            released: True if this call used the item's last pass, in which
                case the slot has been cleared.
        """
        slot = self.idx
        if self.ttl[slot] <= 0:
            self._refill(slot)

        item = self.buffers[slot]
        remaining = self.ttl[slot] - 1
        self.ttl[slot] = remaining

        released = remaining <= 0
        if released:
            # Drop the reference so the finished batch can be freed.
            self.buffers[slot] = None

        self.idx = (slot + 1) % len(self.buffers)
        return item, released
|