import queue
import threading

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.timer import TimerStat

LEARNER_QUEUE_MAX_SIZE = 16

tf1, tf, tfv = try_import_tf()


class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient computations (session runs) off
    the main thread improves overall throughput. A minimal driver-loop sketch
    is included at the bottom of this file.
    """
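
    # Queue protocol used by `step()` below:
    #   - `inqueue` items are `(ra, replay)` tuples: a handle to the replay
    #     source that produced the batch, plus the sampled (multi-agent) batch.
    #   - `outqueue` items are `(ra, prio_dict, replay.count)` tuples, where
    #     `prio_dict` maps each policy ID to `(batch_indexes, td_error)` so the
    #     main thread can send priority updates back to the replay buffer.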

    def __init__(self, local_worker):
        threading.Thread.__init__(self)
        # Sliding-window stat over the inqueue size (for metrics reporting).
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        # Bounded queue of incoming replay batches; producers block when full.
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        # Unbounded queue of (ra, prio_dict, count) results for the main thread.
        self.outqueue = queue.Queue()
        # Timers for time spent waiting on the inqueue, computing gradients,
        # and running the overall step.
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        # Daemonize so this thread never blocks process exit.
        self.daemon = True
        # Set to True whenever the local model gets updated in `step()`.
        self.weights_updated = False
        self.stopped = False
        self.learner_info = {}

    def run(self):
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    # Use LearnerInfoBuilder as a unified way to build the
                    # final results dict from the `learn_on_batch` call below.
                    # This makes sure results dicts always have the same
                    # structure no matter the setup (multi-GPU, multi-agent,
                    # minibatch SGD, tf vs torch).
                    learner_info_builder = LearnerInfoBuilder(num_devices=1)
                    multi_agent_results = self.local_worker.learn_on_batch(replay)
                    for pid, results in multi_agent_results.items():
                        learner_info_builder.add_learn_on_batch_results(results, pid)
                        td_error = results["td_error"]
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the batch indices: if they become part
                        # of a tensor, they may get manipulated and cause
                        # errors when sent to the buffer for processing.
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error,
                        )
                    self.learner_info = learner_info_builder.finalize()
                    self.grad_timer.push_units_processed(replay.count)
                # Hand results (incl. new priorities) back to the main thread.
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
            self.overall_timer.push_units_processed(replay and replay.count or 0)
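

# ---------------------------------------------------------------------------
# Minimal driver-loop sketch (illustration only, not part of the module above).
# It assumes `local_worker` behaves like an RLlib RolloutWorker and that
# `sample_batches` is some iterable yielding `(ra, replay)` pairs matching the
# inqueue protocol documented in the class; both names are hypothetical.
# ---------------------------------------------------------------------------
def _example_driver_loop(local_worker, sample_batches):
    """Sketch of how a driver thread could feed and drain a LearnerThread."""
    learner = LearnerThread(local_worker)
    learner.start()
    for ra, replay in sample_batches:
        # put() blocks once LEARNER_QUEUE_MAX_SIZE batches are pending, which
        # naturally throttles the producer side.
        learner.inqueue.put((ra, replay))
        # Drain any finished results without blocking.
        while not learner.outqueue.empty():
            ra_done, prio_dict, count = learner.outqueue.get()
            # A real driver would use `prio_dict` here to update replay
            # priorities and, if `learner.weights_updated` is set, broadcast
            # fresh weights to the rollout workers.
    # Ask the learner to exit after its current step. Since the thread is a
    # daemon, a still-pending `inqueue.get()` will not block process shutdown.
    learner.stopped = True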