ray/rllib/optimizers/aso_learner.py

"""Helper class for AsyncSamplesOptimizer."""

import threading

from six.moves import queue

from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat


class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    This is for use with AsyncSamplesOptimizer.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """
    def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size, learner_queue_timeout):
        """Initialize the learner thread.

        Arguments:
            local_worker (RolloutWorker): process local rollout worker holding
                policies this thread will call learn_on_batch() on.
            minibatch_buffer_size (int): max number of train batches to store
                in the minibatching buffer.
            num_sgd_iter (int): number of passes to learn on per train batch.
            learner_queue_size (int): max size of the queue of inbound train
                batches to this thread.
            learner_queue_timeout (int): raise an exception if the queue has
                been empty for this long, in seconds.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        # Inbound queue of train batches from the main thread; bounded to
        # limit how many batches can pile up waiting to be learned on.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        # Outbound queue of per-batch step counts back to the main thread.
        self.outqueue = queue.Queue()
        # Buffer that replays each inbound batch num_sgd_iter times.
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter)
        # Timers for queue waits, gradient updates, and (in subclasses that
        # pre-load data onto GPUs) data loading.
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stats = {}
        self.stopped = False
        self.num_steps = 0

    def run(self):
        # Process batches until the main thread flips self.stopped.
        while not self.stopped:
            self.step()

    def step(self):
        # Wait for the next (possibly replayed) train batch.
        with self.queue_timer:
            batch, _ = self.minibatch_buffer.get()

        # Run one SGD pass on the local policies and time it.
        with self.grad_timer:
            fetches = self.local_worker.learn_on_batch(batch)
            # Flag that local weights changed, so the main thread knows to
            # broadcast fresh weights to the rollout workers.
            self.weights_updated = True
            self.stats = get_learner_stats(fetches)

        self.num_steps += 1
        # Report the number of env steps trained back to the main thread.
        self.outqueue.put(batch.count)
        # Record inbound queue depth for throughput/backlog stats.
        self.learner_queue_size.push(self.inqueue.qsize())
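

# Minimal usage sketch of the queue protocol described in the class docstring:
# the main thread puts SampleBatch-like objects on inqueue and reads trained
# step counts off outqueue. `_StubBatch` and `_StubWorker` are hypothetical
# stand-ins exposing only the two attributes LearnerThread actually touches
# (batch.count and worker.learn_on_batch()); in practice AsyncSamplesOptimizer
# constructs and feeds this thread with a real RolloutWorker and SampleBatches.
if __name__ == "__main__":

    class _StubBatch:
        # SampleBatch-like object: only .count is read by LearnerThread.
        count = 32

    class _StubWorker:
        # RolloutWorker-like object: only learn_on_batch() is called.
        def learn_on_batch(self, batch):
            return {"learner_stats": {"loss": 0.0}}

    learner = LearnerThread(
        local_worker=_StubWorker(),
        minibatch_buffer_size=1,
        num_sgd_iter=1,
        learner_queue_size=16,
        learner_queue_timeout=60)
    learner.start()
    learner.inqueue.put(_StubBatch())  # main thread enqueues a train batch
    print(learner.outqueue.get())  # -> 32, env steps trained on that batch
    learner.stopped = True  # daemon thread, so process exit is not blocked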