mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[rllib] Pull out the GPU-parallel optimizer from policy gradients to common (#711)
* refactor
* docs
* cleanup
* clean up more
* Update parallel.py
* add imports from future
This commit is contained in:
parent 5b3d0c00f2
commit cd12ea7e09
3 changed files with 281 additions and 206 deletions
python/ray/rllib/parallel.py  (new file, 249 lines)

@@ -0,0 +1,249 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import os

from tensorflow.python.client import timeline
import tensorflow as tf


class LocalSyncParallelOptimizer(object):
  """Optimizer that runs in parallel across multiple local devices.

  LocalSyncParallelOptimizer automatically splits up and loads training data
  onto specified local devices (e.g. GPUs) with `load_data()`. During a call to
  `optimize()`, the devices compute gradients over slices of the data in
  parallel. The gradients are then averaged and applied to the shared weights.

  The data loaded is pinned in device memory until the next call to
  `load_data`, so you can make multiple passes (possibly in randomized order)
  over the same data once loaded.

  This is similar to tf.train.SyncReplicasOptimizer, but works within a single
  TensorFlow graph, i.e. implements in-graph replicated training:

    https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer

  Args:
    optimizer: delegate TensorFlow optimizer object.
    devices: list of the names of TensorFlow devices to parallelize over.
    input_placeholders: list of inputs for the loss function. Tensors of
      these shapes will be passed to build_loss() in order to define the
      per-device loss ops.
    per_device_batch_size: number of tuples to optimize over at a time per
      device. In each call to `optimize()`,
      `len(devices) * per_device_batch_size` tuples of data will be
      processed.
    build_loss: function that takes the specified inputs and returns an
      object with a 'loss' property that is a scalar Tensor. For example,
      ray.rllib.policy_gradient.ProximalPolicyLoss.
    logdir: directory to place debugging output in.
  """

  def __init__(
      self,
      optimizer,
      devices,
      input_placeholders,
      per_device_batch_size,
      build_loss,
      logdir):
    self.optimizer = optimizer
    self.devices = devices
    self.batch_size = per_device_batch_size * len(devices)
    self.per_device_batch_size = per_device_batch_size
    self.input_placeholders = input_placeholders
    self.build_loss = build_loss
    self.logdir = logdir

    # First initialize the shared loss network.
    with tf.variable_scope("tower"):
      self._shared_loss = build_loss(*input_placeholders)

    # Then set up the per-device loss graphs that use the shared weights.
    self._batch_index = tf.placeholder(tf.int32)
    data_splits = zip(
        *[tf.split(ph, len(devices)) for ph in input_placeholders])
    self._towers = []
    for device, device_placeholders in zip(self.devices, data_splits):
      self._towers.append(self._setup_device(device, device_placeholders))

    avg = average_gradients([t.grads for t in self._towers])
    self._train_op = self.optimizer.apply_gradients(avg)

  def load_data(self, sess, inputs, full_trace=False):
    """Bulk loads the specified inputs into device memory.

    The shape of the inputs must conform to the shapes of the input
    placeholders this optimizer was constructed with.

    The data is split equally across all the devices. If the data is not
    evenly divisible by the batch size, excess data will be discarded.

    Args:
      sess: TensorFlow session.
      inputs: list of Tensors matching the input placeholders specified at
        construction time of this optimizer.
      full_trace: whether to profile data loading.

    Returns:
      The number of tuples loaded per device.
    """

    feed_dict = {}
    assert len(self.input_placeholders) == len(inputs)
    for ph, arr in zip(self.input_placeholders, inputs):
      truncated_arr = make_divisible_by(arr, self.batch_size)
      feed_dict[ph] = truncated_arr
      truncated_len = len(truncated_arr)

    if full_trace:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
    run_metadata = tf.RunMetadata()

    sess.run(
        [t.init_op for t in self._towers],
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=run_metadata)
    if full_trace:
      trace = timeline.Timeline(step_stats=run_metadata.step_stats)
      trace_file = open(os.path.join(self.logdir, "timeline-load.json"), "w")
      trace_file.write(trace.generate_chrome_trace_format())

    tuples_per_device = truncated_len / len(self.devices)
    assert tuples_per_device % self.per_device_batch_size == 0
    return tuples_per_device

  def optimize(
      self, sess, batch_index,
      extra_ops=[], extra_feed_dict={}, file_writer=None):
    """Run a single step of SGD.

    Runs an SGD step over a slice of the preloaded batch with size given by
    self.per_device_batch_size and offset given by the batch_index argument.

    Updates shared model weights based on the averaged per-device gradients.

    Args:
      sess: TensorFlow session.
      batch_index: offset into the preloaded data. This value must be
        between `0` and `tuples_per_device`. The amount of data to process
        is always fixed to `per_device_batch_size`.
      extra_ops: extra ops to run with this step (e.g. for metrics).
      extra_feed_dict: extra args to feed into this session run.
      file_writer: if specified, tf metrics will be written out using this.

    Returns:
      The outputs of extra_ops evaluated over the batch.
    """

    if file_writer:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
    run_metadata = tf.RunMetadata()

    feed_dict = {self._batch_index: batch_index}
    feed_dict.update(extra_feed_dict)
    outs = sess.run(
        [self._train_op] + extra_ops,
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=run_metadata)

    if file_writer:
      trace = timeline.Timeline(step_stats=run_metadata.step_stats)
      trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), "w")
      trace_file.write(trace.generate_chrome_trace_format())
      file_writer.add_run_metadata(
          run_metadata, "sgd_train_{}".format(batch_index))

    return outs[1:]

  def get_common_loss(self):
    return self._shared_loss

  def get_device_losses(self):
    return [t.loss_object for t in self._towers]

  def _setup_device(self, device, device_input_placeholders):
    with tf.device(device):
      with tf.variable_scope("tower", reuse=True):
        device_input_batches = []
        device_input_slices = []
        for ph in device_input_placeholders:
          current_batch = tf.Variable(
              ph, trainable=False, validate_shape=False, collections=[])
          device_input_batches.append(current_batch)
          current_slice = tf.slice(
              current_batch,
              [self._batch_index] + [0] * len(ph.shape[1:]),
              [self.per_device_batch_size] + [-1] * len(ph.shape[1:]))
          current_slice.set_shape(ph.shape)
          device_input_slices.append(current_slice)
        device_loss_obj = self.build_loss(*device_input_slices)
        device_grads = self.optimizer.compute_gradients(
            device_loss_obj.loss, colocate_gradients_with_ops=True)
        return Tower(
            tf.group(*[batch.initializer for batch in device_input_batches]),
            device_grads,
            device_loss_obj)


# Each tower is a copy of the loss graph pinned to a specific device.
Tower = namedtuple("Tower", ["init_op", "grads", "loss_object"])


def make_divisible_by(array, n):
  return array[0:array.shape[0] - array.shape[0] % n]


def average_gradients(tower_grads):
  """Averages gradients across towers.

  Calculate the average gradient for each shared variable across all towers.
  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over individual gradients. The inner list is over the gradient
      calculation for each tower.

  Returns:
    List of pairs of (gradient, variable) where the gradient has been averaged
    across all towers.

  TODO(ekl): We could use NCCL if this becomes a bottleneck.
  """

  average_grads = []
  for grad_and_vars in zip(*tower_grads):

    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []
    for g, _ in grad_and_vars:
      if g is not None:
        # Add 0 dimension to the gradients to represent the tower.
        expanded_g = tf.expand_dims(g, 0)

        # Append on a 'tower' dimension which we will average over below.
        grads.append(expanded_g)

    # Average over the 'tower' dimension.
    grad = tf.concat(axis=0, values=grads)
    grad = tf.reduce_mean(grad, 0)

    # Keep in mind that the Variables are redundant because they are shared
    # across towers. So we will just return the first tower's pointer to
    # the Variable.
    v = grad_and_vars[0][1]
    grad_and_var = (grad, v)
    average_grads.append(grad_and_var)

  return average_grads
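
For orientation, a minimal sketch of how a caller might drive the new class (not part of the commit): the toy SquaredLoss object, placeholder shapes, device list, and logdir below are illustrative assumptions; only the LocalSyncParallelOptimizer calls come from the file above.

import numpy as np
import tensorflow as tf

from ray.rllib.parallel import LocalSyncParallelOptimizer


class SquaredLoss(object):
  """Toy loss object exposing the `loss` attribute the optimizer expects."""
  def __init__(self, x, y):
    # tf.get_variable so the weight is shared across towers via scope reuse.
    w = tf.get_variable("w", [1, 1])
    self.loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))


x_ph = tf.placeholder(tf.float32, shape=(None, 1))
y_ph = tf.placeholder(tf.float32, shape=(None, 1))

par_opt = LocalSyncParallelOptimizer(
    tf.train.AdamOptimizer(1e-3),
    ["/cpu:0"],                  # e.g. ["/gpu:0", "/gpu:1"] on a multi-GPU machine
    [x_ph, y_ph],                # passed to build_loss in this order
    per_device_batch_size=32,
    build_loss=lambda x, y: SquaredLoss(x, y),
    logdir="/tmp")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Pin one batch of data in device memory, then sweep SGD over slices of it.
x = np.random.rand(1000, 1).astype(np.float32)
tuples_per_device = int(par_opt.load_data(sess, [x, 3 * x]))
device_losses = [obj.loss for obj in par_opt.get_device_losses()]
for offset in range(0, tuples_per_device, 32):
  losses = par_opt.optimize(sess, offset, extra_ops=device_losses)
print("final per-device loss:", losses)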
python/ray/rllib/policy_gradient/agent.py

@@ -2,24 +2,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple

import gym.spaces
import tensorflow as tf
import os

from tensorflow.python.client import timeline
from tensorflow.python import debug as tf_debug

import ray

from ray.rllib.parallel import LocalSyncParallelOptimizer
from ray.rllib.policy_gradient.distributions import Categorical, DiagGaussian
from ray.rllib.policy_gradient.env import BatchedEnv
from ray.rllib.policy_gradient.loss import ProximalPolicyLoss
from ray.rllib.policy_gradient.filter import MeanStdFilter
from ray.rllib.policy_gradient.rollout import rollouts, add_advantage_values
from ray.rllib.policy_gradient.utils import (
    make_divisible_by, average_gradients)

# TODO(pcm): Make sure that both observation_filter and reward_filter
# are correctly handled, i.e. (a) the values are accumulated accross

@@ -27,9 +23,6 @@ from ray.rllib.policy_gradient.utils import (
# correctly and no default arguments are used, and (c) they are saved
# as part of the checkpoint so training can resume properly.

# Each tower is a copy of the policy graph pinned to a specific device.
Tower = namedtuple("Tower", ["init_op", "grads", "policy"])


class Agent(object):
  """

@@ -38,14 +31,6 @@ class Agent(object):
  Initializes the tensorflow graphs for both training and evaluation.
  One common policy graph is initialized on '/cpu:0' and holds all the shared
  network weights. When run as a remote agent, only this graph is used.

  When the agent is initialized locally with multiple GPU devices, copies of
  the policy graph are also placed on each GPU. These per-GPU graphs share the
  common policy network weights but take device-local input tensors.

  The idea here is that training data can be bulk-loaded onto these
  device-local variables. Synchronous SGD can then be run in parallel over
  this GPU-local data.
  """

  def __init__(self, name, batchsize, preprocessor, config, logdir, is_remote):

@@ -96,15 +81,6 @@ class Agent(object):
          "currently not supported")
    self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim))

    data_splits = zip(
        tf.split(self.observations, len(devices)),
        tf.split(self.advantages, len(devices)),
        tf.split(self.actions, len(devices)),
        tf.split(self.prev_logits, len(devices)))

    # Parallel SGD ops
    self.towers = []
    self.batch_index = tf.placeholder(tf.int32)
    assert config["sgd_batchsize"] % len(devices) == 0, \
        "Batch size must be evenly divisible by devices"
    if is_remote:

@@ -113,26 +89,33 @@ class Agent(object):
    else:
      self.batch_size = config["sgd_batchsize"]
      self.per_device_batch_size = int(self.batch_size / len(devices))
    self.optimizer = tf.train.AdamOptimizer(self.config["sgd_stepsize"])
    self.setup_common_policy(
        self.observations, self.advantages, self.actions, self.prev_logits)
    for device, (obs, adv, acts, plog) in zip(devices, data_splits):
      self.towers.append(
          self.setup_per_device_policy(device, obs, adv, acts, plog))

    avg = average_gradients([t.grads for t in self.towers])
    self.train_op = self.optimizer.apply_gradients(avg)
    def build_loss(obs, advs, acts, plog):
      return ProximalPolicyLoss(
          self.env.observation_space, self.env.action_space,
          obs, advs, acts, plog, self.logit_dim,
          self.kl_coeff, self.distribution_class, self.config, self.sess)

    self.par_opt = LocalSyncParallelOptimizer(
        tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
        self.devices,
        [self.observations, self.advantages, self.actions, self.prev_logits],
        self.per_device_batch_size,
        build_loss,
        self.logdir)

    # Metric ops
    with tf.name_scope("test_outputs"):
      policies = self.par_opt.get_device_losses()
      self.mean_loss = tf.reduce_mean(
          tf.stack(values=[t.policy.loss for t in self.towers]), 0)
          tf.stack(values=[policy.loss for policy in policies]), 0)
      self.mean_kl = tf.reduce_mean(
          tf.stack(values=[t.policy.mean_kl for t in self.towers]), 0)
          tf.stack(values=[policy.mean_kl for policy in policies]), 0)
      self.mean_entropy = tf.reduce_mean(
          tf.stack(values=[t.policy.mean_entropy for t in self.towers]), 0)
          tf.stack(values=[policy.mean_entropy for policy in policies]), 0)

    # References to the model weights
    self.common_policy = self.par_opt.get_common_loss()
    self.variables = ray.experimental.TensorFlowVariables(
        self.common_policy.loss,
        self.sess)

@@ -140,127 +123,22 @@ class Agent(object):
    self.reward_filter = MeanStdFilter((), clip=5.0)
    self.sess.run(tf.global_variables_initializer())

  def setup_common_policy(self, observations, advantages, actions, prev_log):
    with tf.variable_scope("tower"):
      self.common_policy = ProximalPolicyLoss(
          self.env.observation_space, self.env.action_space,
          observations, advantages, actions, prev_log, self.logit_dim,
          self.kl_coeff, self.distribution_class, self.config, self.sess)

  def setup_per_device_policy(
      self, device, observations, advantages, actions, prev_log):
    with tf.device(device):
      with tf.variable_scope("tower", reuse=True):
        all_obs = tf.Variable(
            observations, trainable=False, validate_shape=False,
            collections=[])
        all_adv = tf.Variable(
            advantages, trainable=False, validate_shape=False, collections=[])
        all_acts = tf.Variable(
            actions, trainable=False, validate_shape=False, collections=[])
        all_plog = tf.Variable(
            prev_log, trainable=False, validate_shape=False, collections=[])
        obs_slice = tf.slice(
            all_obs,
            [self.batch_index] + [0] * len(self.preprocessor.shape),
            [self.per_device_batch_size] + [-1] * len(self.preprocessor.shape))
        obs_slice.set_shape(observations.shape)
        adv_slice = tf.slice(
            all_adv, [self.batch_index], [self.per_device_batch_size])
        acts_slice = tf.slice(
            all_acts,
            [self.batch_index] + [0] * len(self.action_shape),
            [self.per_device_batch_size] + [-1] * len(self.action_shape))
        plog_slice = tf.slice(
            all_plog, [self.batch_index, 0], [self.per_device_batch_size, -1])
        policy = ProximalPolicyLoss(
            self.env.observation_space, self.env.action_space,
            obs_slice, adv_slice, acts_slice, plog_slice, self.logit_dim,
            self.kl_coeff, self.distribution_class, self.config, self.sess)
        grads = self.optimizer.compute_gradients(
            policy.loss, colocate_gradients_with_ops=True)

        return Tower(
            tf.group(
                *[all_obs.initializer,
                  all_adv.initializer,
                  all_acts.initializer,
                  all_plog.initializer]),
            grads,
            policy)

  def load_data(self, trajectories, full_trace):
    """
    Bulk loads the specified trajectories into device memory.

    The data is split equally across all the devices.

    Returns:
      The number of tuples loaded per device.
    """

    truncated_obs = make_divisible_by(
        trajectories["observations"], self.batch_size)
    if full_trace:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
    run_metadata = tf.RunMetadata()
    self.sess.run(
        [t.init_op for t in self.towers],
        feed_dict={
            self.observations: truncated_obs,
            self.advantages: make_divisible_by(
                trajectories["advantages"], self.batch_size),
            self.actions: make_divisible_by(
                trajectories["actions"].squeeze(), self.batch_size),
            self.prev_logits: make_divisible_by(
                trajectories["logprobs"], self.batch_size),
        },
        options=run_options,
        run_metadata=run_metadata)
    if full_trace:
      trace = timeline.Timeline(step_stats=run_metadata.step_stats)
      trace_file = open(os.path.join(self.logdir, "timeline-load.json"), "w")
      trace_file.write(trace.generate_chrome_trace_format())

    tuples_per_device = len(truncated_obs) / len(self.devices)
    assert tuples_per_device % self.per_device_batch_size == 0
    return tuples_per_device
    return self.par_opt.load_data(
        self.sess,
        [trajectories["observations"],
         trajectories["advantages"],
         trajectories["actions"].squeeze(),
         trajectories["logprobs"]],
        full_trace=full_trace)

  def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, file_writer):
    """
    Run a single step of SGD.

    Runs a SGD step over the batch with index batch_index as created by
    load_rollouts_data(), updating local weights.

    Returns:
      (mean_loss, mean_kl, mean_entropy) evaluated over the batch.
    """

    if full_trace:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
    run_metadata = tf.RunMetadata()

    _, loss, kl, entropy = self.sess.run(
        [self.train_op, self.mean_loss, self.mean_kl, self.mean_entropy],
        feed_dict={
            self.batch_index: batch_index,
            self.kl_coeff: kl_coeff},
        options=run_options,
        run_metadata=run_metadata)

    if full_trace:
      trace = timeline.Timeline(step_stats=run_metadata.step_stats)
      trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), "w")
      trace_file.write(trace.generate_chrome_trace_format())
      file_writer.add_run_metadata(
          run_metadata, "sgd_train_{}".format(batch_index))

    return loss, kl, entropy
    return self.par_opt.optimize(
        self.sess,
        batch_index,
        extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy],
        extra_feed_dict={self.kl_coeff: kl_coeff},
        file_writer=file_writer if full_trace else None)

  def get_weights(self):
    return self.variables.get_weights()
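
After the refactor, the calling pattern on the Agent side stays the same: bulk-load one set of rollouts, then run many SGD steps over fixed-size slices of it. A hedged sketch of such a driver follows (the function name, the numpy shuffling, and the num_sgd_iter default are illustrative; only load_data, run_sgd_minibatch, and per_device_batch_size come from the code above).

import numpy as np


def sgd_over_loaded_rollouts(agent, trajectories, kl_coeff, num_sgd_iter=10):
  """Illustrative driver (not from this commit): one bulk load, several passes."""
  tuples_per_device = int(agent.load_data(trajectories, full_trace=False))
  batches_per_device = tuples_per_device // agent.per_device_batch_size
  for _ in range(num_sgd_iter):
    # Visit the preloaded slices in a random order on each pass.
    for i in np.random.permutation(batches_per_device):
      batch_index = int(i) * agent.per_device_batch_size
      loss, kl, entropy = agent.run_sgd_minibatch(
          batch_index, kl_coeff, full_trace=False, file_writer=None)
  return loss, kl, entropy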
python/ray/rllib/policy_gradient/utils.py

@@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def flatten(weights, start=0, stop=2):

@@ -35,54 +34,3 @@ def shuffle(trajectory):
  for key, val in trajectory.items():
    trajectory[key] = val[permutation]
  return trajectory


def make_divisible_by(array, n):
  return array[0:array.shape[0] - array.shape[0] % n]


def average_gradients(tower_grads):
  """
  Average gradients across towers.

  Calculate the average gradient for each shared variable across all towers.
  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over individual gradients. The inner list is over the gradient
      calculation for each tower.

  Returns:
    List of pairs of (gradient, variable) where the gradient has been averaged
    across all towers.

  TODO(ekl): We could use NCCL if this becomes a bottleneck.
  """

  average_grads = []
  for grad_and_vars in zip(*tower_grads):

    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []
    for g, _ in grad_and_vars:
      if g is not None:
        # Add 0 dimension to the gradients to represent the tower.
        expanded_g = tf.expand_dims(g, 0)

        # Append on a 'tower' dimension which we will average over below.
        grads.append(expanded_g)

    # Average over the 'tower' dimension.
    grad = tf.concat(axis=0, values=grads)
    grad = tf.reduce_mean(grad, 0)

    # Keep in mind that the Variables are redundant because they are shared
    # across towers. So we will just return the first tower's pointer to
    # the Variable.
    v = grad_and_vars[0][1]
    grad_and_var = (grad, v)
    average_grads.append(grad_and_var)

  return average_grads
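
As a quick check of the batch arithmetic enforced by make_divisible_by (which this commit moves from utils.py into parallel.py), a small worked example; the sizes are illustrative, not from the commit.

import numpy as np


def make_divisible_by(array, n):
  # Same truncation as above: drop the remainder that cannot fill a full batch.
  return array[0:array.shape[0] - array.shape[0] % n]


rollout = np.zeros((1030, 4))                           # e.g. 1030 sampled tuples
devices = 2                                             # e.g. two GPUs
per_device_batch_size = 128
batch_size = per_device_batch_size * devices            # 256 tuples consumed per optimize() step

truncated = make_divisible_by(rollout, batch_size)      # 1030 -> 1024 tuples kept
tuples_per_device = len(truncated) // devices           # 512 tuples pinned per device
assert tuples_per_device % per_device_batch_size == 0   # 512 / 128 = 4 SGD slices per pass
print(len(truncated), tuples_per_device)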