From cd12ea7e0930a08d1f1877a2c4eb6c43b19e459f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 7 Jul 2017 15:20:02 -0700 Subject: [PATCH] [rllib] Pull out the GPU-parallel optimizer from policy gradients to common (#711) * refactor * docs * cleanup * clean up more * Update parallel.py * add imports from future --- python/ray/rllib/parallel.py | 249 ++++++++++++++++++++++ python/ray/rllib/policy_gradient/agent.py | 186 +++------------- python/ray/rllib/policy_gradient/utils.py | 52 ----- 3 files changed, 281 insertions(+), 206 deletions(-) create mode 100644 python/ray/rllib/parallel.py diff --git a/python/ray/rllib/parallel.py b/python/ray/rllib/parallel.py new file mode 100644 index 000000000..9cc7d9f32 --- /dev/null +++ b/python/ray/rllib/parallel.py @@ -0,0 +1,249 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import os + +from tensorflow.python.client import timeline +import tensorflow as tf + + +class LocalSyncParallelOptimizer(object): + """Optimizer that runs in parallel across multiple local devices. + + LocalSyncParallelOptimizer automatically splits up and loads training data + onto specified local devices (e.g. GPUs) with `load_data()`. During a call to + `optimize()`, the devices compute gradients over slices of the data in + parallel. The gradients are then averaged and applied to the shared weights. + + The data loaded is pinned in device memory until the next call to + `load_data`, so you can make multiple passes (possibly in randomized order) + over the same data once loaded. + + This is similar to tf.train.SyncReplicasOptimizer, but works within a single + TensorFlow graph, i.e. implements in-graph replicated training: + + https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer + + Args: + optimizer: delegate TensorFlow optimizer object. + devices: list of the names of TensorFlow devices to parallelize over. + input_placeholders: list of inputs for the loss function. Tensors of + these shapes will be passed to build_loss() in order + to define the per-device loss ops. + per_device_batch_size: number of tuples to optimize over at a time per + device. In each call to `optimize()`, + `len(devices) * per_device_batch_size` tuples of + data will be processed. + build_loss: function that takes the specified inputs and returns an + object with a 'loss' property that is a scalar Tensor. For + example, ray.rllib.policy_gradient.ProximalPolicyLoss. + logdir: directory to place debugging output in. 
+ """ + + def __init__( + self, + optimizer, + devices, + input_placeholders, + per_device_batch_size, + build_loss, + logdir): + self.optimizer = optimizer + self.devices = devices + self.batch_size = per_device_batch_size * len(devices) + self.per_device_batch_size = per_device_batch_size + self.input_placeholders = input_placeholders + self.build_loss = build_loss + self.logdir = logdir + + # First initialize the shared loss network + with tf.variable_scope("tower"): + self._shared_loss = build_loss(*input_placeholders) + + # Then setup the per-device loss graphs that use the shared weights + self._batch_index = tf.placeholder(tf.int32) + data_splits = zip( + *[tf.split(ph, len(devices)) for ph in input_placeholders]) + self._towers = [] + for device, device_placeholders in zip(self.devices, data_splits): + self._towers.append(self._setup_device(device, device_placeholders)) + + avg = average_gradients([t.grads for t in self._towers]) + self._train_op = self.optimizer.apply_gradients(avg) + + def load_data(self, sess, inputs, full_trace=False): + """Bulk loads the specified inputs into device memory. + + The shape of the inputs must conform to the shapes of the input + placeholders this optimizer was constructed with. + + The data is split equally across all the devices. If the data is not + evenly divisible by the batch size, excess data will be discarded. + + Args: + sess: TensorFlow session. + inputs: list of Tensors matching the input placeholders specified at + construction time of this optimizer. + full_trace: whether to profile data loading. + + Returns: + The number of tuples loaded per device. + """ + + feed_dict = {} + assert len(self.input_placeholders) == len(inputs) + for ph, arr in zip(self.input_placeholders, inputs): + truncated_arr = make_divisible_by(arr, self.batch_size) + feed_dict[ph] = truncated_arr + truncated_len = len(truncated_arr) + + if full_trace: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + else: + run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) + run_metadata = tf.RunMetadata() + + sess.run( + [t.init_op for t in self._towers], + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata) + if full_trace: + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + trace_file = open(os.path.join(self.logdir, "timeline-load.json"), "w") + trace_file.write(trace.generate_chrome_trace_format()) + + tuples_per_device = truncated_len / len(self.devices) + assert tuples_per_device % self.per_device_batch_size == 0 + return tuples_per_device + + def optimize( + self, sess, batch_index, + extra_ops=[], extra_feed_dict={}, file_writer=None): + """Run a single step of SGD. + + Runs a SGD step over a slice of the preloaded batch with size given by + self.per_device_batch_size and offset given by the batch_index argument. + + Updates shared model weights based on the averaged per-device gradients. + + Args: + sess: TensorFlow session. + batch_index: offset into the preloaded data. This value must be + between `0` and `tuples_per_device`. The amount of data + to process is always fixed to `per_device_batch_size`. + extra_ops: extra ops to run with this step (e.g. for metrics). + extra_feed_dict: extra args to feed into this session run. + file_writer: if specified, tf metrics will be written out using this. + + Returns: + the outputs of extra_ops evaluated over the batch. 
+ """ + + if file_writer: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + else: + run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) + run_metadata = tf.RunMetadata() + + feed_dict = {self._batch_index: batch_index} + feed_dict.update(extra_feed_dict) + outs = sess.run( + [self._train_op] + extra_ops, + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata) + + if file_writer: + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), "w") + trace_file.write(trace.generate_chrome_trace_format()) + file_writer.add_run_metadata( + run_metadata, "sgd_train_{}".format(batch_index)) + + return outs[1:] + + def get_common_loss(self): + return self._shared_loss + + def get_device_losses(self): + return [t.loss_object for t in self._towers] + + def _setup_device(self, device, device_input_placeholders): + with tf.device(device): + with tf.variable_scope("tower", reuse=True): + device_input_batches = [] + device_input_slices = [] + for ph in device_input_placeholders: + current_batch = tf.Variable( + ph, trainable=False, validate_shape=False, collections=[]) + device_input_batches.append(current_batch) + current_slice = tf.slice( + current_batch, + [self._batch_index] + [0] * len(ph.shape[1:]), + [self.per_device_batch_size] + [-1] * len(ph.shape[1:])) + current_slice.set_shape(ph.shape) + device_input_slices.append(current_slice) + device_loss_obj = self.build_loss(*device_input_slices) + device_grads = self.optimizer.compute_gradients( + device_loss_obj.loss, colocate_gradients_with_ops=True) + return Tower( + tf.group(*[batch.initializer for batch in device_input_batches]), + device_grads, + device_loss_obj) + + +# Each tower is a copy of the loss graph pinned to a specific device. +Tower = namedtuple("Tower", ["init_op", "grads", "loss_object"]) + + +def make_divisible_by(array, n): + return array[0:array.shape[0] - array.shape[0] % n] + + +def average_gradients(tower_grads): + """Averages gradients across towers. + + Calculate the average gradient for each shared variable across all towers. + Note that this function provides a synchronization point across all towers. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over individual gradients. The inner list is over the gradient + calculation for each tower. + + Returns: + List of pairs of (gradient, variable) where the gradient has been averaged + across all towers. + + TODO(ekl): We could use NCCL if this becomes a bottleneck. + """ + + average_grads = [] + for grad_and_vars in zip(*tower_grads): + + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + grads = [] + for g, _ in grad_and_vars: + if g is not None: + # Add 0 dimension to the gradients to represent the tower. + expanded_g = tf.expand_dims(g, 0) + + # Append on a 'tower' dimension which we will average over below. + grads.append(expanded_g) + + # Average over the 'tower' dimension. + grad = tf.concat(axis=0, values=grads) + grad = tf.reduce_mean(grad, 0) + + # Keep in mind that the Variables are redundant because they are shared + # across towers. So .. we will just return the first tower's pointer to + # the Variable. 
+ v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) + + return average_grads diff --git a/python/ray/rllib/policy_gradient/agent.py b/python/ray/rllib/policy_gradient/agent.py index 9c9cd2110..9aa44e250 100644 --- a/python/ray/rllib/policy_gradient/agent.py +++ b/python/ray/rllib/policy_gradient/agent.py @@ -2,24 +2,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import namedtuple - import gym.spaces import tensorflow as tf import os -from tensorflow.python.client import timeline from tensorflow.python import debug as tf_debug import ray +from ray.rllib.parallel import LocalSyncParallelOptimizer from ray.rllib.policy_gradient.distributions import Categorical, DiagGaussian from ray.rllib.policy_gradient.env import BatchedEnv from ray.rllib.policy_gradient.loss import ProximalPolicyLoss from ray.rllib.policy_gradient.filter import MeanStdFilter from ray.rllib.policy_gradient.rollout import rollouts, add_advantage_values -from ray.rllib.policy_gradient.utils import ( - make_divisible_by, average_gradients) # TODO(pcm): Make sure that both observation_filter and reward_filter # are correctly handled, i.e. (a) the values are accumulated accross @@ -27,9 +23,6 @@ from ray.rllib.policy_gradient.utils import ( # correctly and no default arguments are used, and (c) they are saved # as part of the checkpoint so training can resume properly. -# Each tower is a copy of the policy graph pinned to a specific device. -Tower = namedtuple("Tower", ["init_op", "grads", "policy"]) - class Agent(object): """ @@ -38,14 +31,6 @@ class Agent(object): Initializes the tensorflow graphs for both training and evaluation. One common policy graph is initialized on '/cpu:0' and holds all the shared network weights. When run as a remote agent, only this graph is used. - - When the agent is initialized locally with multiple GPU devices, copies of - the policy graph are also placed on each GPU. These per-GPU graphs share the - common policy network weights but take device-local input tensors. - - The idea here is that training data can be bulk-loaded onto these - device-local variables. Synchronous SGD can then be run in parallel over - this GPU-local data. 
""" def __init__(self, name, batchsize, preprocessor, config, logdir, is_remote): @@ -96,15 +81,6 @@ class Agent(object): "currently not supported") self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim)) - data_splits = zip( - tf.split(self.observations, len(devices)), - tf.split(self.advantages, len(devices)), - tf.split(self.actions, len(devices)), - tf.split(self.prev_logits, len(devices))) - - # Parallel SGD ops - self.towers = [] - self.batch_index = tf.placeholder(tf.int32) assert config["sgd_batchsize"] % len(devices) == 0, \ "Batch size must be evenly divisible by devices" if is_remote: @@ -113,26 +89,33 @@ class Agent(object): else: self.batch_size = config["sgd_batchsize"] self.per_device_batch_size = int(self.batch_size / len(devices)) - self.optimizer = tf.train.AdamOptimizer(self.config["sgd_stepsize"]) - self.setup_common_policy( - self.observations, self.advantages, self.actions, self.prev_logits) - for device, (obs, adv, acts, plog) in zip(devices, data_splits): - self.towers.append( - self.setup_per_device_policy(device, obs, adv, acts, plog)) - avg = average_gradients([t.grads for t in self.towers]) - self.train_op = self.optimizer.apply_gradients(avg) + def build_loss(obs, advs, acts, plog): + return ProximalPolicyLoss( + self.env.observation_space, self.env.action_space, + obs, advs, acts, plog, self.logit_dim, + self.kl_coeff, self.distribution_class, self.config, self.sess) + + self.par_opt = LocalSyncParallelOptimizer( + tf.train.AdamOptimizer(self.config["sgd_stepsize"]), + self.devices, + [self.observations, self.advantages, self.actions, self.prev_logits], + self.per_device_batch_size, + build_loss, + self.logdir) # Metric ops with tf.name_scope("test_outputs"): + policies = self.par_opt.get_device_losses() self.mean_loss = tf.reduce_mean( - tf.stack(values=[t.policy.loss for t in self.towers]), 0) + tf.stack(values=[policy.loss for policy in policies]), 0) self.mean_kl = tf.reduce_mean( - tf.stack(values=[t.policy.mean_kl for t in self.towers]), 0) + tf.stack(values=[policy.mean_kl for policy in policies]), 0) self.mean_entropy = tf.reduce_mean( - tf.stack(values=[t.policy.mean_entropy for t in self.towers]), 0) + tf.stack(values=[policy.mean_entropy for policy in policies]), 0) # References to the model weights + self.common_policy = self.par_opt.get_common_loss() self.variables = ray.experimental.TensorFlowVariables( self.common_policy.loss, self.sess) @@ -140,127 +123,22 @@ class Agent(object): self.reward_filter = MeanStdFilter((), clip=5.0) self.sess.run(tf.global_variables_initializer()) - def setup_common_policy(self, observations, advantages, actions, prev_log): - with tf.variable_scope("tower"): - self.common_policy = ProximalPolicyLoss( - self.env.observation_space, self.env.action_space, - observations, advantages, actions, prev_log, self.logit_dim, - self.kl_coeff, self.distribution_class, self.config, self.sess) - - def setup_per_device_policy( - self, device, observations, advantages, actions, prev_log): - with tf.device(device): - with tf.variable_scope("tower", reuse=True): - all_obs = tf.Variable( - observations, trainable=False, validate_shape=False, - collections=[]) - all_adv = tf.Variable( - advantages, trainable=False, validate_shape=False, collections=[]) - all_acts = tf.Variable( - actions, trainable=False, validate_shape=False, collections=[]) - all_plog = tf.Variable( - prev_log, trainable=False, validate_shape=False, collections=[]) - obs_slice = tf.slice( - all_obs, - [self.batch_index] + [0] * 
len(self.preprocessor.shape), - [self.per_device_batch_size] + [-1] * len(self.preprocessor.shape)) - obs_slice.set_shape(observations.shape) - adv_slice = tf.slice( - all_adv, [self.batch_index], [self.per_device_batch_size]) - acts_slice = tf.slice( - all_acts, - [self.batch_index] + [0] * len(self.action_shape), - [self.per_device_batch_size] + [-1] * len(self.action_shape)) - plog_slice = tf.slice( - all_plog, [self.batch_index, 0], [self.per_device_batch_size, -1]) - policy = ProximalPolicyLoss( - self.env.observation_space, self.env.action_space, - obs_slice, adv_slice, acts_slice, plog_slice, self.logit_dim, - self.kl_coeff, self.distribution_class, self.config, self.sess) - grads = self.optimizer.compute_gradients( - policy.loss, colocate_gradients_with_ops=True) - - return Tower( - tf.group( - *[all_obs.initializer, - all_adv.initializer, - all_acts.initializer, - all_plog.initializer]), - grads, - policy) - def load_data(self, trajectories, full_trace): - """ - Bulk loads the specified trajectories into device memory. - - The data is split equally across all the devices. - - Returns: - The number of tuples loaded per device. - """ - - truncated_obs = make_divisible_by( - trajectories["observations"], self.batch_size) - if full_trace: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - else: - run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) - run_metadata = tf.RunMetadata() - self.sess.run( - [t.init_op for t in self.towers], - feed_dict={ - self.observations: truncated_obs, - self.advantages: make_divisible_by( - trajectories["advantages"], self.batch_size), - self.actions: make_divisible_by( - trajectories["actions"].squeeze(), self.batch_size), - self.prev_logits: make_divisible_by( - trajectories["logprobs"], self.batch_size), - }, - options=run_options, - run_metadata=run_metadata) - if full_trace: - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-load.json"), "w") - trace_file.write(trace.generate_chrome_trace_format()) - - tuples_per_device = len(truncated_obs) / len(self.devices) - assert tuples_per_device % self.per_device_batch_size == 0 - return tuples_per_device + return self.par_opt.load_data( + self.sess, + [trajectories["observations"], + trajectories["advantages"], + trajectories["actions"].squeeze(), + trajectories["logprobs"]], + full_trace=full_trace) def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, file_writer): - """ - Run a single step of SGD. - - Runs a SGD step over the batch with index batch_index as created by - load_rollouts_data(), updating local weights. - - Returns: - (mean_loss, mean_kl, mean_entropy) evaluated over the batch. 
- """ - - if full_trace: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - else: - run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) - run_metadata = tf.RunMetadata() - - _, loss, kl, entropy = self.sess.run( - [self.train_op, self.mean_loss, self.mean_kl, self.mean_entropy], - feed_dict={ - self.batch_index: batch_index, - self.kl_coeff: kl_coeff}, - options=run_options, - run_metadata=run_metadata) - - if full_trace: - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), "w") - trace_file.write(trace.generate_chrome_trace_format()) - file_writer.add_run_metadata( - run_metadata, "sgd_train_{}".format(batch_index)) - - return loss, kl, entropy + return self.par_opt.optimize( + self.sess, + batch_index, + extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy], + extra_feed_dict={self.kl_coeff: kl_coeff}, + file_writer=file_writer if full_trace else None) def get_weights(self): return self.variables.get_weights() diff --git a/python/ray/rllib/policy_gradient/utils.py b/python/ray/rllib/policy_gradient/utils.py index 8890fa6dd..0762b930c 100644 --- a/python/ray/rllib/policy_gradient/utils.py +++ b/python/ray/rllib/policy_gradient/utils.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf def flatten(weights, start=0, stop=2): @@ -35,54 +34,3 @@ def shuffle(trajectory): for key, val in trajectory.items(): trajectory[key] = val[permutation] return trajectory - - -def make_divisible_by(array, n): - return array[0:array.shape[0] - array.shape[0] % n] - - -def average_gradients(tower_grads): - """ - Average gradients across towers. - - Calculate the average gradient for each shared variable across all towers. - Note that this function provides a synchronization point across all towers. - - Args: - tower_grads: List of lists of (gradient, variable) tuples. The outer list - is over individual gradients. The inner list is over the gradient - calculation for each tower. - - Returns: - List of pairs of (gradient, variable) where the gradient has been averaged - across all towers. - - TODO(ekl): We could use NCCL if this becomes a bottleneck. - """ - - average_grads = [] - for grad_and_vars in zip(*tower_grads): - - # Note that each grad_and_vars looks like the following: - # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) - grads = [] - for g, _ in grad_and_vars: - if g is not None: - # Add 0 dimension to the gradients to represent the tower. - expanded_g = tf.expand_dims(g, 0) - - # Append on a 'tower' dimension which we will average over below. - grads.append(expanded_g) - - # Average over the 'tower' dimension. - grad = tf.concat(axis=0, values=grads) - grad = tf.reduce_mean(grad, 0) - - # Keep in mind that the Variables are redundant because they are shared - # across towers. So .. we will just return the first tower's pointer to - # the Variable. - v = grad_and_vars[0][1] - grad_and_var = (grad, v) - average_grads.append(grad_and_var) - - return average_grads