2020-05-21 10:16:18 -07:00
|
|
|
import copy
|
2021-06-08 16:27:02 +02:00
|
|
|
from six.moves import queue
|
2021-02-08 12:05:16 +01:00
|
|
|
import threading
|
2021-07-20 14:58:13 -04:00
|
|
|
from typing import Dict, Optional
|
2020-05-21 10:16:18 -07:00
|
|
|
|
2022-06-23 21:30:01 +02:00
|
|
|
from ray.util.timer import _Timer
|
2021-06-01 15:40:28 +01:00
|
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
2022-05-17 13:43:49 +02:00
|
|
|
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
|
2020-07-11 22:06:35 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
2022-01-29 18:41:57 -08:00
|
|
|
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
|
2021-11-23 23:01:05 +01:00
|
|
|
from ray.rllib.utils.metrics.window_stat import WindowStat
|
2021-06-01 15:40:28 +01:00
|
|
|
from ray.util.iter import _NextValueNotReady
|
2020-05-21 10:16:18 -07:00
|
|
|
|
# Resolve TensorFlow handles once at module import: the tf1-compat module,
# the base tf module, and the detected version. NOTE(review): presumably
# try_import_tf() degrades gracefully (returns None handles) when TF is not
# installed — confirm against ray.rllib.utils.framework.
tf1, tf, tfv = try_import_tf()
|
2020-05-21 10:16:18 -07:00
|
|
|
|
|
|
|
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(
        self,
        local_worker: RolloutWorker,
        minibatch_buffer_size: int,
        num_sgd_iter: int,
        learner_queue_size: int,
        learner_queue_timeout: int,
    ):
        """Initialize the learner thread.

        Args:
            local_worker: process local rollout worker holding
                policies this thread will call learn_on_batch() on
            minibatch_buffer_size: max number of train batches to store
                in the minibatching buffer
            num_sgd_iter: number of passes to learn on per train batch
            learner_queue_size: max size of queue of inbound
                train batches to this thread
            learner_queue_timeout: raise an exception if the queue has
                been empty for this long in seconds
        """
        threading.Thread.__init__(self)
        # Rolling window over recent inqueue sizes, reported in metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        # Inbound train batches from the main thread; bounded so producers
        # block (backpressure) instead of growing memory without limit.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        # Outbound (env-steps, agent-steps, learner-info) tuples to the
        # main thread; unbounded.
        self.outqueue = queue.Queue()
        # Buffer that re-serves each inbound batch `num_sgd_iter` times.
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter,
        )
        # Timers for the metrics reported via `add_learner_metrics()`.
        self.queue_timer = _Timer()
        self.grad_timer = _Timer()
        self.load_timer = _Timer()
        self.load_wait_timer = _Timer()
        # Daemon thread: don't keep the process alive on shutdown.
        self.daemon = True
        # Set True after each successful update so the main thread knows to
        # broadcast fresh weights.
        self.weights_updated = False
        # Results dict of the most recent `learn_on_batch()` call(s).
        self.learner_info = {}
        self.stopped = False
        self.num_steps = 0

    def run(self) -> None:
        """Run `step()` in a loop until `self.stopped` is set."""
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self) -> Optional[_NextValueNotReady]:
        """Perform one learning update on the next available train batch.

        Returns:
            `_NextValueNotReady` if no batch arrived within the queue
            timeout, else None (implicitly) after a completed update.
        """
        with self.queue_timer:
            try:
                batch, _ = self.minibatch_buffer.get()
            except queue.Empty:
                # No batch available yet; signal the caller to retry.
                return _NextValueNotReady()

        with self.grad_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(num_devices=1)
            multi_agent_results = self.local_worker.learn_on_batch(batch)
            for pid, results in multi_agent_results.items():
                learner_info_builder.add_learn_on_batch_results(results, pid)
            self.learner_info = learner_info_builder.finalize()
            self.weights_updated = True

        self.num_steps += 1

        # Put tuple: env-steps, agent-steps, and learner info into the queue.
        self.outqueue.put((batch.count, batch.agent_steps(), self.learner_info))
        self.learner_queue_size.push(self.inqueue.qsize())

    def add_learner_metrics(self, result: Dict, overwrite_learner_info=True) -> Dict:
        """Add internal metrics to a result dict.

        Args:
            result: The results dict to mutate (its "info" sub-dict is
                updated in place).
            overwrite_learner_info: If True, also write a deep copy of the
                latest learner info under the LEARNER_INFO key.

        Returns:
            The same (mutated) `result` dict, for chaining.
        """

        def timer_to_ms(timer):
            # Mean timer value in milliseconds, rounded for readability.
            return round(1000 * timer.mean, 3)

        # Build the metrics dict once; previously this entire structure was
        # duplicated across both branches of `overwrite_learner_info`.
        info = {
            "learner_queue": self.learner_queue_size.stats(),
            "timing_breakdown": {
                "learner_grad_time_ms": timer_to_ms(self.grad_timer),
                "learner_load_time_ms": timer_to_ms(self.load_timer),
                "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
                "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
            },
        }
        if overwrite_learner_info:
            # Deep-copy so later in-thread updates don't mutate the snapshot
            # handed to the caller.
            info[LEARNER_INFO] = copy.deepcopy(self.learner_info)
        result["info"].update(info)
        return result