ray/rllib/execution/multi_gpu_learner_thread.py

141 lines
5.4 KiB
Python

import logging
import threading
from six.moves import queue
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.timer import TimerStat
from ray.rllib.evaluation.rollout_worker import RolloutWorker
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
class MultiGPULearnerThread(LearnerThread):
"""Learner that can use multiple GPUs and parallel loading.
This class is used for async sampling algorithms.
"""
def __init__(
self,
local_worker: RolloutWorker,
num_gpus: int = 1,
lr=None, # deprecated.
train_batch_size: int = 500,
num_multi_gpu_tower_stacks: int = 1,
minibatch_buffer_size: int = 1,
num_sgd_iter: int = 1,
learner_queue_size: int = 16,
learner_queue_timeout: int = 300,
num_data_load_threads: int = 16,
_fake_gpus: bool = False):
"""Initializes a MultiGPULearnerThread instance.
Args:
local_worker (RolloutWorker): Local RolloutWorker holding
policies this thread will call load_data() and optimizer() on.
num_gpus (int): Number of GPUs to use for data-parallel SGD.
train_batch_size (int): Size of batches (minibatches if
`num_sgd_iter` > 1) to learn on.
num_multi_gpu_tower_stacks (int): Number of buffers to parallelly
load data into on one device. Each buffer is of size of
`train_batch_size` and hence increases GPU memory usage
accordingly.
minibatch_buffer_size (int): Max number of train batches to store
in the minibatch buffer.
num_sgd_iter (int): Number of passes to learn on per train batch
(minibatch if `num_sgd_iter` > 1).
learner_queue_size (int): Max size of queue of inbound
train batches to this thread.
num_data_load_threads (int): Number of threads to use to load
data into GPU memory in parallel.
"""
LearnerThread.__init__(self, local_worker, minibatch_buffer_size,
num_sgd_iter, learner_queue_size,
learner_queue_timeout)
self.train_batch_size = train_batch_size
# TODO: (sven) Allow multi-GPU to work for multi-agent as well.
self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
logger.info("MultiGPULearnerThread devices {}".format(
self.policy.devices))
assert self.train_batch_size % len(self.policy.devices) == 0
assert self.train_batch_size >= len(self.policy.devices),\
"batch too small"
if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
raise NotImplementedError("Multi-gpu mode for multi-agent")
self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks))
self.idle_tower_stacks = queue.Queue()
self.ready_tower_stacks = queue.Queue()
for idx in self.tower_stack_indices:
self.idle_tower_stacks.put(idx)
for i in range(num_data_load_threads):
self.loader_thread = _MultiGPULoaderThread(
self, share_stats=(i == 0))
self.loader_thread.start()
self.minibatch_buffer = MinibatchBuffer(
self.ready_tower_stacks, minibatch_buffer_size,
learner_queue_timeout, num_sgd_iter)
@override(LearnerThread)
def step(self) -> None:
assert self.loader_thread.is_alive()
with self.load_wait_timer:
buffer_idx, released = self.minibatch_buffer.get()
with self.grad_timer:
fetches = self.policy.learn_on_loaded_batch(
offset=0, buffer_index=buffer_idx)
self.weights_updated = True
self.stats = {DEFAULT_POLICY_ID: get_learner_stats(fetches)}
if released:
self.idle_tower_stacks.put(buffer_idx)
self.outqueue.put(
(self.policy.get_num_samples_loaded_into_buffer(buffer_idx),
self.stats))
self.learner_queue_size.push(self.inqueue.qsize())
class _MultiGPULoaderThread(threading.Thread):
def __init__(self, multi_gpu_learner_thread: MultiGPULearnerThread,
share_stats: bool):
threading.Thread.__init__(self)
self.multi_gpu_learner_thread = multi_gpu_learner_thread
self.daemon = True
if share_stats:
self.queue_timer = multi_gpu_learner_thread.queue_timer
self.load_timer = multi_gpu_learner_thread.load_timer
else:
self.queue_timer = TimerStat()
self.load_timer = TimerStat()
def run(self) -> None:
while True:
self._step()
def _step(self) -> None:
s = self.multi_gpu_learner_thread
policy = s.policy
with self.queue_timer:
batch = s.inqueue.get()
buffer_idx = s.idle_tower_stacks.get()
with self.load_timer:
policy.load_batch_into_buffer(batch=batch, buffer_index=buffer_idx)
s.ready_tower_stacks.put(buffer_idx)