[rllib] Configure learner queue timeout (#5270)
* configure learner queue timeout
* lint
* use config
* fix method args order, add unit test
* fix wrong param name
parent 6f682db99d
commit 827618254a
7 changed files with 50 additions and 19 deletions
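In short, the commit turns the previously hard-coded 300-second learner queue timeout into a trainer config key, learner_queue_timeout, threaded through the IMPALA/APPO configs into the learner thread and its minibatch buffer. A minimal usage sketch (assuming the tune entry point and a standard Gym environment, neither of which appears in this diff):

import ray
from ray import tune

ray.init()

# Hypothetical override of the new key: allow the learner thread to wait up to
# 10 minutes for a train batch before giving up (the default added here is 300 s).
tune.run(
    "IMPALA",
    stop={"training_iteration": 1},
    config={
        "env": "CartPole-v0",
        "num_workers": 1,
        "learner_queue_timeout": 600,
    })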
@@ -54,6 +54,10 @@ DEFAULT_CONFIG = with_common_config({
     "replay_buffer_num_slots": 0,
     # max queue size for train batches feeding into the learner
     "learner_queue_size": 16,
+    # wait for train batches to be available in minibatch buffer queue
+    # this many seconds. This may need to be increased e.g. when training
+    # with a slow environment
+    "learner_queue_timeout": 300,
     # level of queuing for sampling.
     "max_sample_requests_in_flight_per_worker": 2,
     # max number of workers to broadcast one set of weights to
@@ -126,6 +130,8 @@ def make_aggregators_and_optimizer(workers, config):
         num_sgd_iter=config["num_sgd_iter"],
         minibatch_buffer_size=config["minibatch_buffer_size"],
         num_aggregation_workers=config["num_aggregation_workers"],
+        learner_queue_size=config["learner_queue_size"],
+        learner_queue_timeout=config["learner_queue_timeout"],
         **config["optimizer"])

     if aggregators:
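The optimizer is built purely from these config keys. A trimmed-down, hypothetical version of that forwarding pattern (the real make_aggregators_and_optimizer passes many more arguments, and the import path is assumed to match the one used by the tests of this era):

from ray.rllib.optimizers import AsyncSamplesOptimizer

def make_optimizer(workers, config):
    # Hypothetical helper: forward only the queue-related keys touched by this diff.
    return AsyncSamplesOptimizer(
        workers,
        learner_queue_size=config["learner_queue_size"],
        learner_queue_timeout=config["learner_queue_timeout"],
        **config["optimizer"])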
@@ -35,6 +35,7 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
     "replay_proportion": 0.0,
     "replay_buffer_num_slots": 100,
     "learner_queue_size": 16,
+    "learner_queue_timeout": 300,
     "max_sample_requests_in_flight_per_worker": 2,
     "broadcast_interval": 1,
     "grad_clip": 40.0,
@@ -26,14 +26,17 @@ class LearnerThread(threading.Thread):
     """

     def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
-                 learner_queue_size):
+                 learner_queue_size, learner_queue_timeout):
         threading.Thread.__init__(self)
         self.learner_queue_size = WindowStat("size", 50)
         self.local_worker = local_worker
         self.inqueue = queue.Queue(maxsize=learner_queue_size)
         self.outqueue = queue.Queue()
         self.minibatch_buffer = MinibatchBuffer(
-            self.inqueue, minibatch_buffer_size, num_sgd_iter)
+            inqueue=self.inqueue,
+            size=minibatch_buffer_size,
+            timeout=learner_queue_timeout,
+            num_passes=num_sgd_iter)
         self.queue_timer = TimerStat()
         self.grad_timer = TimerStat()
         self.load_timer = TimerStat()
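The timeout ultimately governs how long the learner thread blocks waiting on its input queue; the mechanism is the standard library's Queue.get(timeout=...), which raises queue.Empty when nothing arrives in time. A standalone sketch of that pattern (not the RLlib class itself):

import queue
import threading

def consumer(inqueue, timeout):
    """Block up to `timeout` seconds per batch; stop when none arrives."""
    while True:
        try:
            batch = inqueue.get(timeout=timeout)
        except queue.Empty:
            # Mirrors what happens when learner_queue_timeout expires:
            # the consumer gives up instead of blocking forever.
            print("no batch within %ss, giving up" % timeout)
            return
        print("processing", batch)

q = queue.Queue(maxsize=16)  # analogous to learner_queue_size
q.put("batch-0")
t = threading.Thread(target=consumer, args=(q, 1.0))
t.start()
t.join()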
@@ -11,16 +11,18 @@ class MinibatchBuffer(object):
     This is for use with AsyncSamplesOptimizer.
     """

-    def __init__(self, inqueue, size, num_passes):
+    def __init__(self, inqueue, size, timeout, num_passes):
         """Initialize a minibatch buffer.

         Arguments:
            inqueue: Queue to populate the internal ring buffer from.
            size: Max number of data items to buffer.
+           timeout: Queue timeout
            num_passes: Max num times each data item should be emitted.
         """
         self.inqueue = inqueue
         self.size = size
+        self.timeout = timeout
         self.max_ttl = num_passes
         self.cur_max_ttl = 1  # ramp up slowly to better mix the input data
         self.buffers = [None] * size
@@ -35,7 +37,7 @@ class MinibatchBuffer(object):
             released: True if the item is now removed from the ring buffer.
         """
         if self.ttl[self.idx] <= 0:
-            self.buffers[self.idx] = self.inqueue.get(timeout=300.0)
+            self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
             self.ttl[self.idx] = self.cur_max_ttl
             if self.cur_max_ttl < self.max_ttl:
                 self.cur_max_ttl += 1
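For intuition, here is a heavily simplified, hypothetical ring buffer showing the refill-with-timeout behavior that MinibatchBuffer now exposes (no TTL ramp-up and no released flag, unlike the real class):

import queue

class TinyRingBuffer(object):
    """Minimal illustration only: refill a slot from a queue with a timeout."""

    def __init__(self, inqueue, size, timeout, num_passes):
        self.inqueue = inqueue
        self.timeout = timeout
        self.max_ttl = num_passes
        self.buffers = [None] * size
        self.ttl = [0] * size
        self.idx = 0

    def get(self):
        if self.ttl[self.idx] <= 0:
            # Raises queue.Empty if nothing arrives within self.timeout.
            self.buffers[self.idx] = self.inqueue.get(timeout=self.timeout)
            self.ttl[self.idx] = self.max_ttl
        item = self.buffers[self.idx]
        self.ttl[self.idx] -= 1          # each item is emitted num_passes times
        self.idx = (self.idx + 1) % len(self.buffers)
        return item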
@@ -39,10 +39,12 @@ class TFMultiGPULearner(LearnerThread):
                  minibatch_buffer_size=1,
                  num_sgd_iter=1,
                  learner_queue_size=16,
+                 learner_queue_timeout=300,
                  num_data_load_threads=16,
                  _fake_gpus=False):
         LearnerThread.__init__(self, local_worker, minibatch_buffer_size,
-                               num_sgd_iter, learner_queue_size)
+                               num_sgd_iter, learner_queue_size,
+                               learner_queue_timeout)
         self.lr = lr
         self.train_batch_size = train_batch_size
         if not num_gpus:
@@ -99,7 +101,8 @@ class TFMultiGPULearner(LearnerThread):
             self.loader_thread.start()

         self.minibatch_buffer = MinibatchBuffer(
-            self.ready_optimizers, minibatch_buffer_size, num_sgd_iter)
+            self.ready_optimizers, minibatch_buffer_size,
+            learner_queue_timeout, num_sgd_iter)

     @override(LearnerThread)
     def step(self):
@@ -42,6 +42,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
                  num_sgd_iter=1,
                  minibatch_buffer_size=1,
                  learner_queue_size=16,
+                 learner_queue_timeout=300,
                  num_aggregation_workers=0,
                  _fake_gpus=False):
         PolicyOptimizer.__init__(self, workers)
@@ -69,11 +70,15 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
                 minibatch_buffer_size=minibatch_buffer_size,
                 num_sgd_iter=num_sgd_iter,
                 learner_queue_size=learner_queue_size,
+                learner_queue_timeout=learner_queue_timeout,
                 _fake_gpus=_fake_gpus)
         else:
-            self.learner = LearnerThread(self.workers.local_worker(),
-                                         minibatch_buffer_size, num_sgd_iter,
-                                         learner_queue_size)
+            self.learner = LearnerThread(
+                self.workers.local_worker(),
+                minibatch_buffer_size=minibatch_buffer_size,
+                num_sgd_iter=num_sgd_iter,
+                learner_queue_size=learner_queue_size,
+                learner_queue_timeout=learner_queue_timeout)
         self.learner.start()

         # Stats
@@ -117,26 +117,26 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
         ray.init(num_cpus=8)

     def testSimple(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         optimizer = AsyncSamplesOptimizer(workers)
         self._wait_for(optimizer, 1000, 1000)

     def testMultiGPU(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         optimizer = AsyncSamplesOptimizer(workers, num_gpus=1, _fake_gpus=True)
         self._wait_for(optimizer, 1000, 1000)

     def testMultiGPUParallelLoad(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         optimizer = AsyncSamplesOptimizer(
             workers, num_gpus=1, num_data_loader_buffers=1, _fake_gpus=True)
         self._wait_for(optimizer, 1000, 1000)

     def testMultiplePasses(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         optimizer = AsyncSamplesOptimizer(
             workers,
@@ -149,7 +149,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
         self.assertGreater(optimizer.stats()["num_steps_trained"], 8000)

     def testReplay(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         optimizer = AsyncSamplesOptimizer(
             workers,
@@ -166,7 +166,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
         self.assertLess(stats["num_steps_trained"], stats["num_steps_sampled"])

     def testReplayAndMultiplePasses(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         optimizer = AsyncSamplesOptimizer(
             workers,
@@ -187,7 +187,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
         self.assertLess(train_ratio, 0.4)

     def testMultiTierAggregationBadConf(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         aggregators = TreeAggregator.precreate_aggregators(4)
         optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=4)
@@ -195,7 +195,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
             lambda: optimizer.aggregator.init(aggregators))

     def testMultiTierAggregation(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         aggregators = TreeAggregator.precreate_aggregators(1)
         optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=1)
@@ -203,7 +203,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
         self._wait_for(optimizer, 1000, 1000)

     def testRejectBadConfigs(self):
-        local, remotes = self._make_evs()
+        local, remotes = self._make_envs()
         workers = WorkerSet._from_existing(local, remotes)
         self.assertRaises(
             ValueError, lambda: AsyncSamplesOptimizer(
@@ -231,7 +231,18 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
             _fake_gpus=True)
         self._wait_for(optimizer, 1000, 1000)

-    def _make_evs(self):
+    def testLearnerQueueTimeout(self):
+        local, remotes = self._make_envs()
+        workers = WorkerSet._from_existing(local, remotes)
+        optimizer = AsyncSamplesOptimizer(
+            workers,
+            sample_batch_size=1000,
+            train_batch_size=1000,
+            learner_queue_timeout=1)
+        self.assertRaises(AssertionError,
+                          lambda: self._wait_for(optimizer, 1000, 1000))
+
+    def _make_envs(self):
         def make_sess():
             return tf.Session(config=tf.ConfigProto(device_count={"CPU": 2}))

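The new testLearnerQueueTimeout drives the timeout down to one second, so the learner thread's blocking get on an input queue that receives no batch in time raises queue.Empty and the thread exits; the next optimizer step then fails, presumably via the optimizer's internal check that the learner thread is still alive (outside this diff), which is what assertRaises(AssertionError, ...) catches. The standard-library behavior at the core of this can be seen in isolation:

import queue

q = queue.Queue()
try:
    q.get(timeout=1)  # no producer: behaves like a starved learner queue
except queue.Empty:
    print("timed out after 1 second, as the new unit test expects")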