[Experimental] Add experimental distributed SGD API (#2858)

* check in sgd api

* idx

* foreach_worker foreach_model

* add feed_dict

* update

* yapf

* typo

* lint

* plasma op change

* fix plasma op

* still not working

* fix

* fix

* comments

* yapf

* silly flake8

* small test
Eric Liang 2018-09-19 21:12:37 -07:00 committed by Philipp Moritz
parent b23fd5de13
commit 3267676994
13 changed files with 2383 additions and 0 deletions


@@ -0,0 +1,629 @@
# This file is adapted from https://github.com/tensorflow/benchmarks
# /blob/master/scripts/tf_cnn_benchmarks/allreduce.py
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for allreduce."""
from __future__ import print_function
import collections as pycoll
import logging
import numpy as np
import re
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib import nccl
from tensorflow.contrib.all_reduce.python import all_reduce
logger = logging.getLogger(__name__)
AllReduceSpecTuple = pycoll.namedtuple('AllReduceSpecTuple',
'alg shards limit')
def parse_general_int(s):
"""Parse integer with power-of-2 suffix eg. 32k."""
mo = re.match(r'(\d+)([KkMGT]?)$', s)
if mo:
i, suffix = mo.group(1, 2)
v = int(i)
if suffix:
if suffix == 'K' or suffix == 'k':
v *= 1024
elif suffix == 'M':
v *= (1024 * 1024)
elif suffix == 'G':
v *= (1024 * 1024 * 1024)
elif suffix == 'T':
v *= (1024 * 1024 * 1024 * 1024)
else:
raise ValueError('invalid integer string %s' % s)
return v
else:
v = int(s)
return v
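# Examples of the accepted forms: parse_general_int('32k') == 32 * 1024 and
# parse_general_int('1M') == 1024 * 1024, while a plain decimal string such as
# '-1' falls through to int() unchanged.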
def parse_all_reduce_spec(all_reduce_spec):
"""Parse all_reduce_spec.
Args:
all_reduce_spec: a string specifying a combination of all-reduce
algorithms to apply for gradient reduction.
Returns:
a list of AllReduceSpecTuple.
Raises:
ValueError: all_reduce_spec is not well-formed.
An all_reduce_spec has BNF form:
int ::= positive whole number
g_int ::= int[KkMGT]?
alg_spec ::= alg | alg#int
range_spec ::= alg_spec | alg_spec/alg_spec
spec ::= range_spec | range_spec:g_int:range_spec
Not all syntactically correct specifications are supported.
Examples of supported all_reduce_spec strings, with semantics explained:
'xring' == apply ring all-reduce to all tensors
'xring#2' == apply ring all-reduce to all tensors, using two simultaneous
transfer rings, each operating on 1/2 of each tensor.
'nccl' == apply NCCL all-reduce to all tensors (only works within
a single worker process where all devices are GPUs)
'nccl/xring' == apply NCCL all-reduce to all tensors within each worker
to produce at least one full-reduced (locally) value,
then apply ring all-reduce to one such value from each
worker, then apply NCCL broadcast to propagate those globally
reduced values back to every device within each worker.
'pscpu' == Shuffle reduce using worker CPUs as the gather devices: each
distributed tensor is reduced by copying all instances to
one of the worker CPUs, computing the reduction there, then
copying back to each participating device. Tensor reductions
are assigned to specific CPUs round-robin.
'psgpu#4' == Arrange all GPUs across all workers into groups of 4.
Each distributed tensor is shuffle reduced against one
such group of 4 GPUs, selected round-robin. That is, each
tensor is split across 4 shards for the reduction.
'pscpu:2k:pscpu#2:64k:xring' == Apply single-shard pscpu to
tensors of size <= 2048 elements, apply 2-shard pscpu to
tensors up to size 64k elements, apply xring to larger tensors.
'pscpu/pscpu#2' == Use shuffle gather to locally reduce each tensor on
the worker's CPU, then use 2-shard shuffle to reduce those
locally reduced tensors across workers (on the worker CPUs), then
scatter the globally reduced values locally from each worker CPU.
"""
range_parts = all_reduce_spec.split(':') + ['-1']
if len(range_parts) % 2:
raise ValueError(
'all_reduce_spec not well formed: %s' % all_reduce_spec)
limit = 0
spec = []
alg = None
shards = 1
for i, range_part in enumerate(range_parts):
if i % 2 == 1:
try:
limit = parse_general_int(range_part)
spec.append(
AllReduceSpecTuple(alg=alg, shards=shards, limit=limit))
except ValueError:
raise ValueError(
'all_reduce_spec (%s) contains non-integer range %s' %
(all_reduce_spec, range_part))
else:
alg = range_part
alg_parts = range_part.split('#')
alg = alg_parts[0]
if len(alg_parts) > 1:
try:
shards = int(alg_parts[1])
except ValueError:
raise ValueError(
'all_reduce_spec (%s) contains non-integer '
'shards %s' % (all_reduce_spec, alg_parts[1]))
else:
shards = 1
if alg not in [
'nccl', 'nccl/xring', 'nccl/rechd', 'nccl/pscpu', 'xring',
'pscpu', 'psgpu', 'pscpu/pscpu'
]:
raise ValueError('all_reduce_spec (%s) contains invalid alg %s'
% (all_reduce_spec, alg))
return spec
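# Worked example of the grammar above (illustrative):
#     parse_all_reduce_spec('pscpu:2k:xring')
# returns [AllReduceSpecTuple(alg='pscpu', shards=1, limit=2048),
#          AllReduceSpecTuple(alg='xring', shards=1, limit=-1)],
# i.e. single-shard pscpu for tensors of at most 2k elements and ring
# all-reduce for everything larger.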
def build_all_reduce_device_prefixes(job_name, num_tasks):
"""Build list of device prefix names for all_reduce.
Args:
job_name: 'worker', 'ps' or 'localhost'.
num_tasks: number of jobs across which device names should be generated.
Returns:
A list of device name prefix strings. Each element spells out the full
host name without adding the device.
e.g. '/job:worker/task:0'
"""
if job_name != 'localhost':
return ['/job:%s/task:%d' % (job_name, d) for d in range(0, num_tasks)]
else:
assert num_tasks == 1
return ['/job:%s' % job_name]
def group_device_names(devices, group_size):
"""Group device names into groups of group_size.
Args:
devices: list of strings naming devices.
group_size: int >= 1
Returns:
list of lists of devices, where each inner list is group_size long,
and each device appears at least once in an inner list. If
len(devices) % group_size == 0 then each device will appear
exactly once.
Raises:
ValueError: group_size > len(devices)
"""
num_devices = len(devices)
if group_size > num_devices:
raise ValueError(
'only %d devices, but group_size=%d' % (num_devices, group_size))
num_groups = (
num_devices // group_size + (1 if
(num_devices % group_size != 0) else 0))
groups = [[] for i in range(num_groups)]
for i in range(0, num_groups * group_size):
groups[i % num_groups].append(devices[i % num_devices])
return groups
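# Illustration: grouping three devices into groups of two wraps around so that
# every device appears at least once, e.g.
#     group_device_names(['/gpu:0', '/gpu:1', '/gpu:2'], 2)
# returns [['/gpu:0', '/gpu:2'], ['/gpu:1', '/gpu:0']].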
def split_grads_by_size(threshold_size, device_grads):
"""Break gradients into two sets according to tensor size.
Args:
threshold_size: int size cutoff for small vs large tensor.
device_grads: List of lists of (gradient, variable) tuples. The outer
list is over devices. The inner list is over individual gradients.
Returns:
small_grads: Subset of device_grads where shape is <= threshold_size
elements.
large_grads: Subset of device_grads where shape is > threshold_size
elements.
"""
small_grads = []
large_grads = []
for dl in device_grads:
small_dl = []
large_dl = []
for (g, v) in dl:
tensor_size = g.get_shape().num_elements()
if tensor_size <= threshold_size:
small_dl.append([g, v])
else:
large_dl.append([g, v])
if small_dl:
small_grads.append(small_dl)
if large_dl:
large_grads.append(large_dl)
return small_grads, large_grads
def build_reduce_sum(scaled_grads):
stacked = tf.parallel_stack(values=scaled_grads)
reduced = tf.reduce_sum(stacked, 0)
return [reduced] * len(scaled_grads)
def build_trivial_sum(scaled_grads):
return scaled_grads
def aggregate_single_gradient(grad_and_vars, use_mean, check_inf_nan):
"""Calculate the average gradient for a shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
(gradient, variable) pair within the outer list represents the gradient
of the variable calculated for a single tower, and the number of pairs
equals the number of towers.
use_mean: if True, mean is taken, else sum of gradients is taken.
check_inf_nan: check grads for nans and infs.
Returns:
The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
gradient has been averaged across all towers. The variable is chosen from
the first tower. has_nan_or_inf indicates whether any gradient contains a NaN or Inf.
"""
grads = [g for g, _ in grad_and_vars]
grad = tf.add_n(grads)
if use_mean and len(grads) > 1:
grad = tf.multiply(grad, 1.0 / len(grads))
v = grad_and_vars[0][1]
if check_inf_nan:
has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.is_finite(grads)))
return (grad, v), has_nan_or_inf
else:
return (grad, v), None
def aggregate_gradients_using_copy_with_device_selection(
tower_grads, avail_devices, use_mean=True, check_inf_nan=False):
"""Aggregate gradients, controlling device for the aggregation.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list
is over towers. The inner list is over individual gradients.
avail_devices: List of device names over which the aggregation ops are
placed round-robin.
use_mean: if True, mean is taken, else sum of gradients is taken.
check_inf_nan: If true, check grads for nans and infs.
Returns:
The list of (average_gradient, variable) pairs, where each gradient has
been averaged across all towers and the variable is taken from the first
tower.
"""
agg_grads = []
has_nan_or_inf_list = []
for i, single_grads in enumerate(zip(*tower_grads)):
with tf.device(avail_devices[i % len(avail_devices)]):
grad_and_var, has_nan_or_inf = aggregate_single_gradient(
single_grads, use_mean, check_inf_nan)
agg_grads.append(grad_and_var)
has_nan_or_inf_list.append(has_nan_or_inf)
return agg_grads
def sum_grad_and_var_all_reduce(grad_and_vars,
num_workers,
alg,
gpu_indices,
aux_devices=None,
num_shards=1):
"""Apply all-reduce algorithm over specified gradient tensors."""
with tf.name_scope('allreduce'):
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
scaled_grads = [g for g, _ in grad_and_vars]
if alg == 'nccl':
summed_grads = nccl.all_sum(scaled_grads)
elif alg == 'simple':
summed_grads = build_reduce_sum(scaled_grads)
elif alg == 'trivial':
summed_grads = build_trivial_sum(scaled_grads)
elif alg == 'xring':
summed_grads = all_reduce.build_ring_all_reduce(
scaled_grads, num_workers, num_shards, gpu_indices, tf.add)
elif alg == 'nccl/xring':
summed_grads = all_reduce.build_nccl_then_ring(
scaled_grads, num_shards, tf.add)
elif alg == 'nccl/rechd':
summed_grads = all_reduce.build_nccl_then_recursive_hd(
scaled_grads, tf.add)
elif alg == 'nccl/pscpu':
summed_grads = all_reduce.build_nccl_then_shuffle(
scaled_grads, aux_devices, tf.add, tf.add_n)
elif alg == 'pscpu/pscpu':
summed_grads = all_reduce.build_shuffle_then_shuffle(
scaled_grads,
aux_devices,
# TODO(tucker): devise a way of better specifying the device
# for the second level.
[aux_devices[0]],
tf.add_n)
elif alg in ['pscpu', 'psgpu']:
summed_grads = all_reduce.build_shuffle_all_reduce(
scaled_grads, aux_devices, tf.add_n)
else:
raise ValueError('unsupported all_reduce alg: %s' % alg)
result = []
for (_, v), g in zip(grad_and_vars, summed_grads):
result.append([g, v])
return result
def contains_any(haystack, needles):
"""Tests if any needle is a substring of haystack.
Args:
haystack: a string
needles: list of strings
Returns:
True if any element of needles is a substring of haystack,
False otherwise.
"""
for n in needles:
if n in haystack:
return True
return False
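# For instance, contains_any('nccl/pscpu', ['pscpu', 'psgpu']) is True while
# contains_any('xring', ['pscpu', 'psgpu']) is False; sum_gradients_all_reduce
# below uses this to detect shuffle-based algorithms.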
def sum_gradients_all_reduce(dev_prefixes,
tower_grads,
num_workers,
alg,
num_shards,
gpu_indices,
agg_small_grads_max_bytes=0):
"""Apply all-reduce algorithm over specified gradient tensors.
Args:
dev_prefixes: list of prefix strings to use to generate PS device names.
tower_grads: the gradients to reduce.
num_workers: number of worker processes across entire job.
alg: the all-reduce algorithm to apply.
num_shards: alg-specific sharding factor.
gpu_indices: indices of local GPUs in order usable for ring-reduce.
agg_small_grads_max_bytes: largest tensor eligible for aggregation,
in number of bytes.
Returns:
list of reduced tensors, and the packing metadata (or None) produced by
pack_small_tensors.
"""
alg_contains_shuffle = contains_any(alg, ['pscpu', 'psgpu'])
is_hierarchical = '/' in alg
if 'pscpu' in alg:
aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
elif 'psgpu' in alg:
aux_devices = [
prefix + '/gpu:%d' % i for i in range(len(gpu_indices))
for prefix in dev_prefixes
]
else:
aux_devices = ['/job:localhost/cpu:0']
aux_device_groups = group_device_names(
aux_devices, num_shards if alg_contains_shuffle else 1)
group_index = 0
if agg_small_grads_max_bytes > 0:
tower_grads, packing = pack_small_tensors(
tower_grads, max_bytes=agg_small_grads_max_bytes)
else:
packing = None
new_tower_grads = []
if alg == 'better':
raw_devices = ['/gpu:%i' % (i) for i in gpu_indices]
agg_grads = aggregate_gradients_using_copy_with_device_selection(
tower_grads, raw_devices)
for arr in tower_grads:
new_tower_grads.append(
[(g, v) for (_, v), (g, _) in zip(arr, agg_grads)])
else:
reduced_gv_list = []
for grad_and_vars in zip(*tower_grads):
reduced_gv_list.append(
sum_grad_and_var_all_reduce(
grad_and_vars, num_workers, alg, gpu_indices, aux_devices
if is_hierarchical else aux_device_groups[group_index],
num_shards))
group_index = (group_index + 1) % len(aux_device_groups)
new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
return new_tower_grads, packing
def print_stats(sizes):
def sizeof_fmt(num, suffix='B'):
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
if abs(num) < 1024.0:
return "%3.1f%s%s" % (num, unit, suffix)
num /= 1024.0
return "%.1f%s%s" % (num, 'Yi', suffix)
stats = {
"avg": np.mean(sizes),
"median": np.median(sizes),
"total size": np.sum(sizes)
}
logger.info("Stats " + ", ".join(
["%s: %s" % (k, sizeof_fmt(v)) for k, v in stats.items()]))
other_stats = {"len": len(sizes)}
logger.info(", ".join(["%s: %f" % (k, v) for k, v in other_stats.items()]))
def extract_ranges(index_list, range_size_limit=32):
"""Extract consecutive ranges and singles from index_list.
Args:
index_list: List of monotone increasing non-negative integers.
range_size_limit: Largest size range to return. If a larger
consecutive range exists it will be returned as multiple
ranges.
Returns:
ranges, singles where ranges is a list of [first, last] pairs of
consecutive elements in index_list, and singles is all of the
other elements, in original order.
"""
if not index_list:
return [], []
first = index_list[0]
last = first
ranges = []
singles = []
for i in index_list[1:]:
if i == last + 1 and (last - first) <= range_size_limit:
last = i
else:
if last > first:
ranges.append([first, last])
else:
singles.append(first)
first = i
last = i
if last > first:
ranges.append([first, last])
else:
singles.append(first)
return ranges, singles
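# Illustration of the splitting rule: for an index list with one gap,
#     extract_ranges([0, 1, 2, 5, 7, 8])
# returns ranges == [[0, 2], [7, 8]] and singles == [5].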
GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes')
def pack_range(key, packing, grad_vars, rng):
"""Form the concatenation of a specified range of gradient tensors.
Args:
key: Value under which to store meta-data in packing that will be used
later to restore the grad_var list structure.
packing: Dict holding data describing packed ranges of small tensors.
grad_vars: List of (grad, var) pairs for one tower.
rng: A pair of integers giving the first, last indices of a consecutive
range of tensors to be packed.
Returns:
A tensor that is the concatenation of all the specified small tensors.
"""
to_pack = grad_vars[rng[0]:rng[1] + 1]
members = []
variables = []
restore_shapes = []
with tf.name_scope('pack'):
for g, v in to_pack:
variables.append(v)
restore_shapes.append(g.shape)
with tf.device(g.device):
members.append(tf.reshape(g, [-1]))
packing[key] = GradPackTuple(
indices=range(rng[0], rng[1] + 1),
vars=variables,
shapes=restore_shapes)
with tf.device(members[0].device):
return tf.concat(members, 0)
def unpack_grad_tuple(gv, gpt):
"""Unpack a previously packed collection of gradient tensors.
Args:
gv: A (grad, var) pair to be unpacked.
gpt: A GradPackTuple describing the packing operation that produced gv.
Returns:
A list of (grad, var) pairs corresponding to the values that were
originally packed into gv, maybe following subsequent operations like
reduction.
"""
elt_widths = [x.num_elements() for x in gpt.shapes]
with tf.device(gv[0][0].device):
with tf.name_scope('unpack'):
splits = tf.split(gv[0], elt_widths)
unpacked_gv = []
for idx, s in enumerate(splits):
unpacked_gv.append((tf.reshape(s, gpt.shapes[idx]),
gpt.vars[idx]))
return unpacked_gv
def pack_small_tensors(tower_grads, max_bytes=0):
"""Concatenate gradients together more intelligently.
Does binpacking
Args:
tower_grads: List of lists of (gradient, variable) tuples.
max_bytes: Int giving max number of bytes in a tensor that
may be considered small.
"""
assert max_bytes >= 0
orig_grads = [g for g, _ in tower_grads[0]]
# Check to make sure sizes are accurate; not entirely important
assert all(g.dtype == tf.float32 for g in orig_grads)
sizes = [4 * g.shape.num_elements() for g in orig_grads]
print_stats(sizes)
small_ranges = []
large_indices = []
new_sizes = []
def end_interval(indices, small_ranges, large_indices):
if len(indices) > 1:
small_ranges.insert(0, [indices[0], indices[-1]])
else:
large_indices.insert(0, indices[0])
cur_range = []
cur_size = 0
for i, s in reversed(list(enumerate(sizes))):
if cur_size > max_bytes:
end_interval(cur_range, small_ranges, large_indices)
new_sizes.insert(0, cur_size)
cur_range = []
cur_size = 0
cur_range.insert(0, i)
cur_size += s
end_interval(cur_range, small_ranges, large_indices)
new_sizes.insert(0, cur_size)
print_stats(new_sizes)
num_gv = len(orig_grads)
packing = {}
if len(small_ranges):
new_tower_grads = []
for dev_idx, gv_list in enumerate(tower_grads):
assert len(gv_list) == num_gv
new_gv_list = []
for r in small_ranges:
key = '%d:%d' % (dev_idx, len(new_gv_list))
new_gv_list.append((pack_range(key, packing, gv_list, r),
'packing_var_placeholder'))
for i in large_indices:
new_gv_list.append(gv_list[i])
new_tower_grads.append(new_gv_list)
return new_tower_grads, packing
else:
return tower_grads, None
def unpack_small_tensors(tower_grads, packing):
"""Undo the structure alterations to tower_grads done by pack_small_tensors.
Args:
tower_grads: List of List of (grad, var) tuples.
packing: A dict generated by pack_small_tensors describing the changes
it made to tower_grads.
Returns:
new_tower_grads: identical to tower_grads except that concatenations
of small tensors have been split apart and returned to their original
positions, paired with their original variables.
"""
if not packing:
return tower_grads
new_tower_grads = []
num_devices = len(tower_grads)
num_packed = len(packing.keys()) // num_devices
for dev_idx, gv_list in enumerate(tower_grads):
new_gv_list = gv_list[num_packed:]
for i in xrange(0, num_packed):
k = '%d:%d' % (dev_idx, i)
gpt = packing[k]
gv = unpack_grad_tuple(gv_list[i], gpt)
for gi, idx in enumerate(gpt.indices):
assert idx == gpt.indices[gi]
new_gv_list.insert(idx, gv[gi])
new_tower_grads.append(new_gv_list)
return new_tower_grads


@@ -0,0 +1,446 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import random
import time
import numpy as np
import pyarrow.plasma as plasma
import tensorflow as tf
import ray
from ray.experimental.sgd.util import Timeline, fetch, run_timeline
from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \
unpack_small_tensors
logger = logging.getLogger(__name__)
class SGDWorker(object):
def __init__(self,
worker_index,
model_creator,
all_reduce_alg="simple",
num_devices=1,
use_cpus=False,
max_bytes=60000000,
plasma_op=False):
self.worker_index = worker_index
assert num_devices > 0
# TODO(ekl) support custom session
tf_session_args = {
"device_count": {
"CPU": num_devices
},
"log_device_placement": False,
"gpu_options": tf.GPUOptions(force_gpu_compatible=True),
"inter_op_parallelism_threads": 128,
}
config_proto = tf.ConfigProto(**tf_session_args)
self.sess = tf.Session(config=config_proto)
self.models = []
grad_ops = []
if use_cpus:
device_tmpl = "/cpu:%d"
else:
device_tmpl = "/gpu:%d"
for device_idx in range(num_devices):
device = device_tmpl % device_idx
with tf.device(device):
with tf.variable_scope("device_%d" % device_idx):
model = model_creator(worker_index, device_idx)
self.models.append(model)
model.grads = [
t
for t in model.optimizer.compute_gradients(model.loss)
if t[0] is not None
]
grad_ops.append(model.grads)
if num_devices == 1:
assert not max_bytes, "Not supported with 1 GPU"
self.packed_grads_and_vars = grad_ops
else:
if max_bytes:
self.packed_grads_and_vars, packing_vals = (
sum_gradients_all_reduce(
"",
grad_ops,
1,
all_reduce_alg,
1,
list(range(num_devices)),
agg_small_grads_max_bytes=max_bytes))
else:
self.packed_grads_and_vars, _ = (sum_gradients_all_reduce(
"",
grad_ops,
1,
all_reduce_alg,
1,
list(range(num_devices)),
agg_small_grads_max_bytes=0))
self.per_device_grads = [
list(zip(*dev_gv))[0] for dev_gv in self.packed_grads_and_vars
]
assert (len(self.per_device_grads) == num_devices)
self.num_grads = num_grads = len(self.packed_grads_and_vars[0])
if max_bytes:
logger.info("Packed grads => {} tensors".format(num_grads))
# Ops for reading grads with the right control deps
nccl_noops = []
for j in range(num_grads)[::-1]:
deps = nccl_noops + [
dev_grad[j] for dev_grad in self.per_device_grads
]
with tf.control_dependencies(deps):
nccl_noops = [tf.no_op()]
# You must fetch this otherwise the NCCL allreduce will hang
self.nccl_control_out = tf.group(*nccl_noops)
round_robin_devices = False
if plasma_op:
store_socket = (
ray.worker.global_worker.plasma_client.store_socket_name)
manager_socket = (
ray.worker.global_worker.plasma_client.manager_socket_name)
if not plasma.tf_plasma_op:
plasma.build_plasma_tensorflow_op()
# For fetching grads -> plasma
self.plasma_in_grads = []
self.plasma_in_grads_oids = [
tf.placeholder(shape=[], dtype=tf.string, name="in_grad_oids")
for _ in range(num_grads)
]
ix = 0
for j in range(num_grads):
grad = self.per_device_grads[ix][j]
if round_robin_devices:
ix += 1 # round robin assignment
ix %= num_devices
with tf.device(self.models[ix].loss.device):
plasma_grad = plasma.tf_plasma_op.tensor_to_plasma(
[grad],
self.plasma_in_grads_oids[j],
plasma_store_socket_name=store_socket,
plasma_manager_socket_name=manager_socket)
self.plasma_in_grads.append(plasma_grad)
# For applying grads <- plasma
unpacked_gv = []
self.plasma_out_grads_oids = [
tf.placeholder(
shape=[], dtype=tf.string, name="grad_out_oids")
for _ in range(num_grads)
]
packed_plasma_grads = []
ix = 0
for j in range(num_grads):
with tf.device(self.plasma_in_grads[j].device):
with tf.control_dependencies([self.plasma_in_grads[j]]):
grad_ph = plasma.tf_plasma_op.plasma_to_tensor(
self.plasma_out_grads_oids[j],
dtype=tf.float32,
plasma_store_socket_name=store_socket,
plasma_manager_socket_name=manager_socket)
grad_ph = tf.reshape(grad_ph,
self.packed_grads_and_vars[0][j][0].shape)
logger.debug("Packed tensor {}".format(grad_ph))
packed_plasma_grads.append(grad_ph)
for i in range(num_devices):
per_device = []
for j, (g, v) in enumerate(self.packed_grads_and_vars[i]):
grad_ph = packed_plasma_grads[j]
per_device.append((grad_ph, v))
unpacked_gv.append(per_device)
if max_bytes:
unpacked_gv = unpack_small_tensors(unpacked_gv, packing_vals)
elif max_bytes:
unpacked_gv = unpack_small_tensors(self.packed_grads_and_vars,
packing_vals)
else:
unpacked_gv = self.packed_grads_and_vars
# Same shape as packed_grads_and_vars
assert len(unpacked_gv) == num_devices
assert len(unpacked_gv[0][0]) == 2
apply_ops = []
to_apply = unpacked_gv[0]
for ix, m in enumerate(self.models):
apply_ops.append(
m.optimizer.apply_gradients(
[(g, v)
for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])]))
self.apply_op = tf.group(*apply_ops)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
self.sess.run(init_op)
def foreach_model(self, fn):
return [fn(m) for m in self.models]
def foreach_worker(self, fn):
return fn(self)
def compute_gradients(self):
start = time.time()
feed_dict = {}
# Aggregate feed dicts for each model on this worker.
for model in self.models:
feed_dict.update(model.get_feed_dict())
# We only need to fetch the first per_device_grad, since they are
# averaged across all devices by allreduce.
fetches = self.sess.run(
[
self.models[0].loss, self.per_device_grads[0],
self.nccl_control_out
],
feed_dict=feed_dict)
logger.debug(
"compute grad interior time {}".format(time.time() - start))
return fetches
def apply_gradients(self, avg_grads):
start = time.time()
result = {
g: avg_grads[i]
for (i, g) in enumerate(self.per_device_grads[0])
}
self.sess.run(self.apply_op, feed_dict=result)
logger.debug("apply grad interior time {}".format(time.time() - start))
def ps_compute_apply(self,
out_grad_shard_oids,
agg_grad_shard_oids,
tl_name="ps_compute_apply",
write_timeline=False):
feed_dict = {
ph: oid
for (ph,
oid) in zip(self.plasma_in_grads_oids, out_grad_shard_oids)
}
feed_dict.update({
ph: oid
for (ph,
oid) in zip(self.plasma_out_grads_oids, agg_grad_shard_oids)
})
fetch(agg_grad_shard_oids)
run_timeline(
self.sess,
[self.plasma_in_grads, self.apply_op, self.nccl_control_out],
feed_dict=feed_dict,
write_timeline=write_timeline)
def num_grad_shards(self):
return self.num_grads
def shard_shapes(self):
main_gv = self.packed_grads_and_vars[0]
return [g.shape for g, _ in main_gv]
def ip(self):
return ray.services.get_node_ip_address()
class ParameterServer(object):
def __init__(self, num_workers, tid):
self.num_sgd_workers = num_workers
self.acc_counter = 0
self.timeline = Timeline(tid)
self.timeline.patch_ray()
def set_tid(self, tid):
self.timeline.tid = tid
def get_time(self):
return time.time() + self.timeline.offset
def set_time(self, ref_time):
self.timeline.offset = ref_time - time.time()
def initialize(self, shard_shape):
self.accumulated = np.zeros(shard_shape, dtype=np.float32)
def mark(self):
self.timeline.event("mark")
def prefetch(self, oids):
self.timeline.reset()
self.timeline.start("prefetch")
fetch(oids)
self.timeline.end("prefetch")
def add_spinwait(self, grad_shard_ids):
self.timeline.start("add_spinwait")
plasma_ids = [ray.pyarrow.plasma.ObjectID(x) for x in grad_shard_ids]
while plasma_ids:
for p in plasma_ids:
if ray.worker.global_worker.plasma_client.contains(p):
self.timeline.start("get_buffers")
grads = ray.worker.global_worker.plasma_client.get(p)
self.accumulated += grads
self.acc_counter += 1
self.timeline.end("get_buffers")
plasma_ids.remove(p)
break
self.timeline.end("add_spinwait")
def add(self, grad_shard_id):
self.timeline.start("add")
self.timeline.start("get_buffers")
oid = ray.pyarrow.plasma.ObjectID(grad_shard_id)
grads = ray.worker.global_worker.plasma_client.get(oid)
self.timeline.end("get_buffers")
self.accumulated += grads
self.acc_counter += 1
self.timeline.end("add")
def get(self, object_id):
self.timeline.start("get")
client = ray.worker.global_worker.plasma_client
assert self.acc_counter == self.num_sgd_workers, self.acc_counter
oid = ray.pyarrow.plasma.ObjectID(object_id)
client.put(self.accumulated.flatten(), object_id=oid)
self.accumulated = np.zeros_like(self.accumulated)
self.acc_counter = 0
self.timeline.end("get")
def get_timeline(self):
return self.timeline
def ip(self):
return ray.services.get_node_ip_address()
def pin(self, cpu_id):
try:
import psutil
p = psutil.Process()
p.cpu_affinity([cpu_id])
logger.info("Setting CPU Affinity to: {}".format(cpu_id))
except Exception as e:
logger.error(e)
def average_gradients(grads):
out = []
for grad_list in zip(*grads):
out.append(np.mean(grad_list, axis=0))
return out
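# Illustration: with two workers each producing one gradient tensor,
#     average_gradients([[np.array([1.0, 2.0])], [np.array([3.0, 4.0])]])
# returns [array([2.0, 3.0])], i.e. the element-wise mean across workers.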
def do_sgd_step(actors):
start = time.time()
fetches = ray.get([a.compute_gradients.remote() for a in actors])
losses = [f[0] for f in fetches]
grads = [f[1] for f in fetches]
logger.debug("compute all grads time {}".format(time.time() - start))
start = time.time()
if len(actors) == 1:
assert len(grads) == 1
avg_grad = grads[0]
else:
avg_grad = average_gradients(grads)
logger.debug("grad reduce time {}".format(time.time() - start))
start = time.time()
ray.get([a.apply_gradients.remote(avg_grad) for a in actors])
logger.debug("apply all grads time {}".format(time.time() - start))
return np.mean(losses)
def distributed_sgd_step(actors, ps_list, write_timeline):
# Preallocate object ids that actors will write gradient shards to
grad_shard_oids_list = [[np.random.bytes(20) for _ in ps_list]
for _ in actors]
logger.info("generated grad oids")
# Preallocate object ids that param servers will write new weights to
accum_shard_ids = [np.random.bytes(20) for _ in ps_list]
logger.info("generated accum oids")
# Kick off the fused compute grad / update weights tf run for each actor
for actor, grad_shard_oids in zip(actors, grad_shard_oids_list):
actor.ps_compute_apply.remote(
grad_shard_oids, accum_shard_ids, write_timeline=write_timeline)
logger.info("Launched all ps_compute_applys on all actors")
# Issue prefetch ops
for j, (ps, weight_shard_oid) in list(
enumerate(zip(ps_list, accum_shard_ids)))[::-1]:
to_fetch = []
for grad_shard_oids in grad_shard_oids_list:
to_fetch.append(grad_shard_oids[j])
random.shuffle(to_fetch)
ps.prefetch.remote(to_fetch)
logger.info("Launched all prefetch ops")
# Aggregate the gradients produced by the actors. These operations
# run concurrently with the actor methods above.
ps_gets = []
for j, (ps, weight_shard_oid) in list(
enumerate(zip(ps_list, accum_shard_ids)))[::-1]:
ps.add_spinwait.remote([gs[j] for gs in grad_shard_oids_list])
ps_gets.append(ps.get.remote(weight_shard_oid))
logger.info("Launched all aggregate ops")
if write_timeline:
timelines = [ps.get_timeline.remote() for ps in ps_list]
logger.info("launched timeline gets")
timelines = ray.get(timelines)
t0 = timelines[0]
for t in timelines[1:]:
t0.merge(t)
t0.chrome_trace_format("ps_timeline.json")
else:
# Wait for at least the ps gets to finish
ray.get(ps_gets)
class DistributedSGD(object):
def __init__(self,
model_creator,
num_workers,
devices_per_worker,
use_cpus=False,
use_plasma_op=False):
self.model_creator = model_creator
if use_cpus:
requests = {"num_cpus": devices_per_worker}
else:
requests = {"num_gpus": devices_per_worker}
RemoteSGDWorker = ray.remote(**requests)(SGDWorker)
self.workers = []
for worker_index in range(num_workers):
logger.info("Creating worker {}".format(worker_index))
self.workers.append(
RemoteSGDWorker.remote(
worker_index,
model_creator,
num_devices=devices_per_worker,
plasma_op=use_plasma_op,
use_cpus=use_cpus))
assert not use_plasma_op, \
"TODO: when use_plasma_op is true, we must run in PS mode"
def foreach_worker(self, fn):
results = ray.get([w.foreach_worker.remote(fn) for w in self.workers])
return results
def foreach_model(self, fn):
results = ray.get([w.foreach_model.remote(fn) for w in self.workers])
out = []
for r in results:
out.extend(r)
return out
def step(self):
return do_sgd_step(self.workers)
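# Usage sketch (illustrative): once a DistributedSGD instance `sgd` has been
# constructed as in the short test script included in this change, values can
# be gathered from every model replica across all workers, e.g.
#     losses = sgd.foreach_model(lambda m: m.loss)
# returns one entry per model replica, while sgd.step() runs one synchronous
# SGD step and returns the mean loss.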


@@ -0,0 +1,32 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ray
from ray.experimental.sgd.tfbench.test_model import TFBenchModel
from ray.experimental.sgd.sgd import DistributedSGD
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-iters", default=100, type=int, help="Number of iterations to run")
if __name__ == "__main__":
ray.init()
args, _ = parser.parse_known_args()
model_creator = (
lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True))
sgd = DistributedSGD(
model_creator,
num_workers=2,
devices_per_worker=2,
use_cpus=True,
use_plasma_op=False)
for _ in range(args.num_iters):
loss = sgd.step()
print("Current loss", loss)


@@ -0,0 +1 @@
Files in this directory are adapted from https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks.


@@ -0,0 +1,507 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""CNN builder."""
from __future__ import print_function
from collections import defaultdict
import contextlib
import numpy as np
import tensorflow as tf
from tensorflow.python.layers import convolutional as conv_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import pooling as pooling_layers
from tensorflow.python.training import moving_averages
class ConvNetBuilder(object):
"""Builder of cnn net."""
def __init__(self,
input_op,
input_nchan,
phase_train,
use_tf_layers,
data_format='NCHW',
dtype=tf.float32,
variable_dtype=tf.float32):
self.top_layer = input_op
self.top_size = input_nchan
self.phase_train = phase_train
self.use_tf_layers = use_tf_layers
self.data_format = data_format
self.dtype = dtype
self.variable_dtype = variable_dtype
self.counts = defaultdict(lambda: 0)
self.use_batch_norm = False
self.batch_norm_config = {} # 'decay': 0.997, 'scale': True}
self.channel_pos = ('channels_last'
if data_format == 'NHWC' else 'channels_first')
self.aux_top_layer = None
self.aux_top_size = 0
def get_custom_getter(self):
"""Returns a custom getter that this class's methods must be called
All methods of this class must be called under a variable scope that was
passed this custom getter. Example:
```python
network = ConvNetBuilder(...)
with tf.variable_scope('cg', custom_getter=network.get_custom_getter()):
network.conv(...)
# Call more methods of network here
```
Currently, this custom getter only does anything if self.use_tf_layers is
True. In that case, it causes variables to be stored as dtype
self.variable_dtype, then cast to the requested dtype, instead of directly
storing the variable as the requested dtype.
"""
def inner_custom_getter(getter, *args, **kwargs):
if not self.use_tf_layers:
return getter(*args, **kwargs)
requested_dtype = kwargs['dtype']
if not (requested_dtype == tf.float32
and self.variable_dtype == tf.float16):
kwargs['dtype'] = self.variable_dtype
var = getter(*args, **kwargs)
if var.dtype.base_dtype != requested_dtype:
var = tf.cast(var, requested_dtype)
return var
return inner_custom_getter
@contextlib.contextmanager
def switch_to_aux_top_layer(self):
"""Context that construct cnn in the auxiliary arm."""
if self.aux_top_layer is None:
raise RuntimeError('Empty auxiliary top layer in the network.')
saved_top_layer = self.top_layer
saved_top_size = self.top_size
self.top_layer = self.aux_top_layer
self.top_size = self.aux_top_size
yield
self.aux_top_layer = self.top_layer
self.aux_top_size = self.top_size
self.top_layer = saved_top_layer
self.top_size = saved_top_size
def get_variable(self, name, shape, dtype, cast_dtype, *args, **kwargs):
var = tf.get_variable(name, shape, dtype, *args, **kwargs)
return tf.cast(var, cast_dtype)
def _conv2d_impl(self, input_layer, num_channels_in, filters, kernel_size,
strides, padding, kernel_initializer):
if self.use_tf_layers:
return conv_layers.conv2d(
input_layer,
filters,
kernel_size,
strides,
padding,
self.channel_pos,
kernel_initializer=kernel_initializer,
use_bias=False)
else:
weights_shape = [
kernel_size[0], kernel_size[1], num_channels_in, filters
]
weights = self.get_variable(
'conv2d/kernel',
weights_shape,
self.variable_dtype,
self.dtype,
initializer=kernel_initializer)
if self.data_format == 'NHWC':
strides = [1] + strides + [1]
else:
strides = [1, 1] + strides
return tf.nn.conv2d(
input_layer,
weights,
strides,
padding,
data_format=self.data_format)
def conv(self,
num_out_channels,
k_height,
k_width,
d_height=1,
d_width=1,
mode='SAME',
input_layer=None,
num_channels_in=None,
use_batch_norm=None,
stddev=None,
activation='relu',
bias=0.0):
"""Construct a conv2d layer on top of cnn."""
if input_layer is None:
input_layer = self.top_layer
if num_channels_in is None:
num_channels_in = self.top_size
kernel_initializer = None
if stddev is not None:
kernel_initializer = tf.truncated_normal_initializer(stddev=stddev)
name = 'conv' + str(self.counts['conv'])
self.counts['conv'] += 1
with tf.variable_scope(name):
strides = [1, d_height, d_width, 1]
if self.data_format == 'NCHW':
strides = [strides[0], strides[3], strides[1], strides[2]]
if mode != 'SAME_RESNET':
conv = self._conv2d_impl(
input_layer,
num_channels_in,
num_out_channels,
kernel_size=[k_height, k_width],
strides=[d_height, d_width],
padding=mode,
kernel_initializer=kernel_initializer)
else: # Special padding mode for ResNet models
if d_height == 1 and d_width == 1:
conv = self._conv2d_impl(
input_layer,
num_channels_in,
num_out_channels,
kernel_size=[k_height, k_width],
strides=[d_height, d_width],
padding='SAME',
kernel_initializer=kernel_initializer)
else:
rate = 1 # Unused (for 'a trous' convolutions)
kernel_height_effective = k_height + (k_height - 1) * (
rate - 1)
pad_h_beg = (kernel_height_effective - 1) // 2
pad_h_end = kernel_height_effective - 1 - pad_h_beg
kernel_width_effective = k_width + (k_width - 1) * (
rate - 1)
pad_w_beg = (kernel_width_effective - 1) // 2
pad_w_end = kernel_width_effective - 1 - pad_w_beg
padding = [[0, 0], [pad_h_beg, pad_h_end],
[pad_w_beg, pad_w_end], [0, 0]]
if self.data_format == 'NCHW':
padding = [
padding[0], padding[3], padding[1], padding[2]
]
input_layer = tf.pad(input_layer, padding)
conv = self._conv2d_impl(
input_layer,
num_channels_in,
num_out_channels,
kernel_size=[k_height, k_width],
strides=[d_height, d_width],
padding='VALID',
kernel_initializer=kernel_initializer)
if use_batch_norm is None:
use_batch_norm = self.use_batch_norm
if not use_batch_norm:
if bias is not None:
biases = self.get_variable(
'biases', [num_out_channels],
self.variable_dtype,
self.dtype,
initializer=tf.constant_initializer(bias))
biased = tf.reshape(
tf.nn.bias_add(
conv, biases, data_format=self.data_format),
conv.get_shape())
else:
biased = conv
else:
self.top_layer = conv
self.top_size = num_out_channels
biased = self.batch_norm(**self.batch_norm_config)
if activation == 'relu':
conv1 = tf.nn.relu(biased)
elif activation == 'linear' or activation is None:
conv1 = biased
elif activation == 'tanh':
conv1 = tf.nn.tanh(biased)
else:
raise KeyError('Invalid activation type \'%s\'' % activation)
self.top_layer = conv1
self.top_size = num_out_channels
return conv1
def _pool(self, pool_name, pool_function, k_height, k_width, d_height,
d_width, mode, input_layer, num_channels_in):
"""Construct a pooling layer."""
if input_layer is None:
input_layer = self.top_layer
else:
self.top_size = num_channels_in
name = pool_name + str(self.counts[pool_name])
self.counts[pool_name] += 1
if self.use_tf_layers:
pool = pool_function(
input_layer, [k_height, k_width], [d_height, d_width],
padding=mode,
data_format=self.channel_pos,
name=name)
else:
if self.data_format == 'NHWC':
ksize = [1, k_height, k_width, 1]
strides = [1, d_height, d_width, 1]
else:
ksize = [1, 1, k_height, k_width]
strides = [1, 1, d_height, d_width]
pool = tf.nn.max_pool(
input_layer,
ksize,
strides,
padding=mode,
data_format=self.data_format,
name=name)
self.top_layer = pool
return pool
def mpool(self,
k_height,
k_width,
d_height=2,
d_width=2,
mode='VALID',
input_layer=None,
num_channels_in=None):
"""Construct a max pooling layer."""
return self._pool('mpool', pooling_layers.max_pooling2d, k_height,
k_width, d_height, d_width, mode, input_layer,
num_channels_in)
def apool(self,
k_height,
k_width,
d_height=2,
d_width=2,
mode='VALID',
input_layer=None,
num_channels_in=None):
"""Construct an average pooling layer."""
return self._pool('apool', pooling_layers.average_pooling2d, k_height,
k_width, d_height, d_width, mode, input_layer,
num_channels_in)
def reshape(self, shape, input_layer=None):
if input_layer is None:
input_layer = self.top_layer
self.top_layer = tf.reshape(input_layer, shape)
self.top_size = shape[-1] # HACK This may not always work
return self.top_layer
def affine(self,
num_out_channels,
input_layer=None,
num_channels_in=None,
bias=0.0,
stddev=None,
activation='relu'):
if input_layer is None:
input_layer = self.top_layer
if num_channels_in is None:
num_channels_in = self.top_size
name = 'affine' + str(self.counts['affine'])
self.counts['affine'] += 1
with tf.variable_scope(name):
init_factor = 2. if activation == 'relu' else 1.
stddev = stddev or np.sqrt(init_factor / num_channels_in)
kernel = self.get_variable(
'weights', [num_channels_in, num_out_channels],
self.variable_dtype,
self.dtype,
initializer=tf.truncated_normal_initializer(stddev=stddev))
biases = self.get_variable(
'biases', [num_out_channels],
self.variable_dtype,
self.dtype,
initializer=tf.constant_initializer(bias))
logits = tf.nn.xw_plus_b(input_layer, kernel, biases)
if activation == 'relu':
affine1 = tf.nn.relu(logits, name=name)
elif activation == 'linear' or activation is None:
affine1 = logits
else:
raise KeyError('Invalid activation type \'%s\'' % activation)
self.top_layer = affine1
self.top_size = num_out_channels
return affine1
def inception_module(self, name, cols, input_layer=None, in_size=None):
if input_layer is None:
input_layer = self.top_layer
if in_size is None:
in_size = self.top_size
name += str(self.counts[name])
self.counts[name] += 1
with tf.variable_scope(name):
col_layers = []
col_layer_sizes = []
for c, col in enumerate(cols):
col_layers.append([])
col_layer_sizes.append([])
for lx, layer in enumerate(col):
ltype, args = layer[0], layer[1:]
kwargs = {
'input_layer': input_layer,
'num_channels_in': in_size
} if lx == 0 else {}
if ltype == 'conv':
self.conv(*args, **kwargs)
elif ltype == 'mpool':
self.mpool(*args, **kwargs)
elif ltype == 'apool':
self.apool(*args, **kwargs)
elif ltype == 'share':
self.top_layer = col_layers[c - 1][lx]
self.top_size = col_layer_sizes[c - 1][lx]
else:
raise KeyError(
'Invalid layer type for inception module: \'%s\'' %
ltype)
col_layers[c].append(self.top_layer)
col_layer_sizes[c].append(self.top_size)
catdim = 3 if self.data_format == 'NHWC' else 1
self.top_layer = tf.concat([layers[-1] for layers in col_layers],
catdim)
self.top_size = sum(sizes[-1] for sizes in col_layer_sizes)
return self.top_layer
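# Sketch of the expected `cols` structure (values are illustrative only): each
# column is a list of layer specs whose first element names a builder method
# and whose remaining elements are its positional arguments, e.g.
#     cols = [[('conv', 64, 1, 1)],
#             [('conv', 96, 1, 1), ('conv', 128, 3, 3)]]
# builds a 1x1 branch and a 1x1 -> 3x3 branch whose outputs are concatenated
# along the channel dimension.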
def spatial_mean(self, keep_dims=False):
name = 'spatial_mean' + str(self.counts['spatial_mean'])
self.counts['spatial_mean'] += 1
axes = [1, 2] if self.data_format == 'NHWC' else [2, 3]
self.top_layer = tf.reduce_mean(
self.top_layer, axes, keep_dims=keep_dims, name=name)
return self.top_layer
def dropout(self, keep_prob=0.5, input_layer=None):
if input_layer is None:
input_layer = self.top_layer
else:
self.top_size = None
name = 'dropout' + str(self.counts['dropout'])
with tf.variable_scope(name):
if not self.phase_train:
keep_prob = 1.0
if self.use_tf_layers:
dropout = core_layers.dropout(input_layer, 1. - keep_prob)
else:
dropout = tf.nn.dropout(input_layer, keep_prob)
self.top_layer = dropout
return dropout
def _batch_norm_without_layers(self, input_layer, decay, use_scale,
epsilon):
"""Batch normalization on `input_layer` without tf.layers."""
shape = input_layer.shape
num_channels = shape[3] if self.data_format == 'NHWC' else shape[1]
beta = self.get_variable(
'beta', [num_channels],
tf.float32,
tf.float32,
initializer=tf.zeros_initializer())
if use_scale:
gamma = self.get_variable(
'gamma', [num_channels],
tf.float32,
tf.float32,
initializer=tf.ones_initializer())
else:
gamma = tf.constant(1.0, tf.float32, [num_channels])
moving_mean = tf.get_variable(
'moving_mean', [num_channels],
tf.float32,
initializer=tf.zeros_initializer(),
trainable=False)
moving_variance = tf.get_variable(
'moving_variance', [num_channels],
tf.float32,
initializer=tf.ones_initializer(),
trainable=False)
if self.phase_train:
bn, batch_mean, batch_variance = tf.nn.fused_batch_norm(
input_layer,
gamma,
beta,
epsilon=epsilon,
data_format=self.data_format,
is_training=True)
mean_update = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay=decay, zero_debias=False)
variance_update = moving_averages.assign_moving_average(
moving_variance,
batch_variance,
decay=decay,
zero_debias=False)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mean_update)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, variance_update)
else:
bn, _, _ = tf.nn.fused_batch_norm(
input_layer,
gamma,
beta,
mean=moving_mean,
variance=moving_variance,
epsilon=epsilon,
data_format=self.data_format,
is_training=False)
return bn
def batch_norm(self,
input_layer=None,
decay=0.999,
scale=False,
epsilon=0.001):
"""Adds a Batch Normalization layer."""
if input_layer is None:
input_layer = self.top_layer
else:
self.top_size = None
name = 'batchnorm' + str(self.counts['batchnorm'])
self.counts['batchnorm'] += 1
with tf.variable_scope(name) as scope:
if self.use_tf_layers:
bn = tf.contrib.layers.batch_norm(
input_layer,
decay=decay,
scale=scale,
epsilon=epsilon,
is_training=self.phase_train,
fused=True,
data_format=self.data_format,
scope=scope)
else:
bn = self._batch_norm_without_layers(input_layer, decay, scale,
epsilon)
self.top_layer = bn
self.top_size = bn.shape[
3] if self.data_format == 'NHWC' else bn.shape[1]
self.top_size = int(self.top_size)
return bn
def lrn(self, depth_radius, bias, alpha, beta):
"""Adds a local response normalization layer."""
name = 'lrn' + str(self.counts['lrn'])
self.counts['lrn'] += 1
self.top_layer = tf.nn.lrn(
self.top_layer, depth_radius, bias, alpha, beta, name=name)
return self.top_layer


@@ -0,0 +1,116 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Base model configuration for CNN benchmarks."""
import tensorflow as tf
from . import convnet_builder
class Model(object):
"""Base model configuration for CNN benchmarks."""
def __init__(self,
model,
image_size,
batch_size,
learning_rate,
layer_counts=None,
fp16_loss_scale=128):
self.model = model
self.image_size = image_size
self.batch_size = batch_size
self.default_batch_size = batch_size
self.learning_rate = learning_rate
self.layer_counts = layer_counts
self.fp16_loss_scale = fp16_loss_scale
def get_model(self):
return self.model
def get_image_size(self):
return self.image_size
def get_batch_size(self):
return self.batch_size
def set_batch_size(self, batch_size):
self.batch_size = batch_size
def get_default_batch_size(self):
return self.default_batch_size
def get_layer_counts(self):
return self.layer_counts
def get_fp16_loss_scale(self):
return self.fp16_loss_scale
def get_learning_rate(self, global_step, batch_size):
del global_step
del batch_size
return self.learning_rate
def add_inference(self, unused_cnn):
raise ValueError('Must be implemented in derived classes')
def skip_final_affine_layer(self):
"""Returns if the caller of this class should skip the final affine
Normally, this class adds a final affine layer to the model after calling
self.add_inference(), to generate the logits. If a subclass override this
method to return True, the caller should not add the final affine layer.
This is useful for tests.
"""
return False
def build_network(self,
images,
phase_train=True,
nclass=1001,
image_depth=3,
data_type=tf.float32,
data_format='NCHW',
use_tf_layers=True,
fp16_vars=False):
"""Returns logits and aux_logits from images."""
if data_format == 'NCHW':
images = tf.transpose(images, [0, 3, 1, 2])
var_type = tf.float32
if data_type == tf.float16 and fp16_vars:
var_type = tf.float16
network = convnet_builder.ConvNetBuilder(
images, image_depth, phase_train, use_tf_layers, data_format,
data_type, var_type)
with tf.variable_scope(
'cg', custom_getter=network.get_custom_getter()):
self.add_inference(network)
# Add the final fully-connected class layer
logits = (network.affine(nclass, activation='linear')
if not self.skip_final_affine_layer() else
network.top_layer)
aux_logits = None
if network.aux_top_layer is not None:
with network.switch_to_aux_top_layer():
aux_logits = network.affine(
nclass, activation='linear', stddev=0.001)
if data_type == tf.float16:
# TODO(reedwm): Determine if we should do this cast here.
logits = tf.cast(logits, tf.float32)
if aux_logits is not None:
aux_logits = tf.cast(aux_logits, tf.float32)
return logits, aux_logits
loss_function = None


@@ -0,0 +1,57 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model configurations for CNN benchmarks.
"""
from . import resnet_model
_model_name_to_imagenet_model = {
'resnet50': resnet_model.create_resnet50_model,
'resnet50_v2': resnet_model.create_resnet50_v2_model,
'resnet101': resnet_model.create_resnet101_model,
'resnet101_v2': resnet_model.create_resnet101_v2_model,
'resnet152': resnet_model.create_resnet152_model,
'resnet152_v2': resnet_model.create_resnet152_v2_model,
}
_model_name_to_cifar_model = {}
def _get_model_map(dataset_name):
if 'cifar10' == dataset_name:
return _model_name_to_cifar_model
elif dataset_name in ('imagenet', 'synthetic'):
return _model_name_to_imagenet_model
else:
raise ValueError('Invalid dataset name: %s' % dataset_name)
def get_model_config(model_name, dataset):
"""Map model name to model network configuration."""
model_map = _get_model_map(dataset.name)
if model_name not in model_map:
raise ValueError('Invalid model name \'%s\' for dataset \'%s\'' %
(model_name, dataset.name))
else:
return model_map[model_name]()
def register_model(model_name, dataset_name, model_func):
"""Register a new model that can be obtained with `get_model_config`."""
model_map = _get_model_map(dataset_name)
if model_name in model_map:
raise ValueError('Model "%s" is already registered for dataset "%s"' %
(model_name, dataset_name))
model_map[model_name] = model_func
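# Usage sketch: get_model_config expects a dataset object exposing a `name`
# attribute, so for a dataset named 'imagenet',
#     get_model_config('resnet50', dataset)
# calls resnet_model.create_resnet50_model(). A new model can be added with
#     register_model('my_model', 'imagenet', my_model_factory)
# where 'my_model' and my_model_factory are hypothetical placeholders.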


@@ -0,0 +1,422 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resnet model configuration.
References:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition
arXiv:1512.03385 (2015)
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Identity Mappings in Deep Residual Networks
arXiv:1603.05027 (2016)
Liang-Chieh Chen, George Papandreou, Iasonas Kokkinos, Kevin Murphy,
Alan L. Yuille
DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
Atrous Convolution, and Fully Connected CRFs
arXiv:1606.00915 (2016)
"""
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from . import model as model_lib
def bottleneck_block_v1(cnn, depth, depth_bottleneck, stride):
"""Bottleneck block with identity short-cut for ResNet v1.
Args:
cnn: the network to append bottleneck blocks.
depth: the number of output filters for this bottleneck block.
depth_bottleneck: the number of bottleneck filters for this block.
stride: Stride used in the first layer of the bottleneck block.
"""
input_layer = cnn.top_layer
in_size = cnn.top_size
name_key = 'resnet_v1'
name = name_key + str(cnn.counts[name_key])
cnn.counts[name_key] += 1
with tf.variable_scope(name):
if depth == in_size:
if stride == 1:
shortcut = input_layer
else:
shortcut = cnn.apool(
1,
1,
stride,
stride,
input_layer=input_layer,
num_channels_in=in_size)
else:
shortcut = cnn.conv(
depth,
1,
1,
stride,
stride,
activation=None,
use_batch_norm=True,
input_layer=input_layer,
num_channels_in=in_size,
bias=None)
cnn.conv(
depth_bottleneck,
1,
1,
stride,
stride,
input_layer=input_layer,
num_channels_in=in_size,
use_batch_norm=True,
bias=None)
cnn.conv(
depth_bottleneck,
3,
3,
1,
1,
mode='SAME_RESNET',
use_batch_norm=True,
bias=None)
res = cnn.conv(
depth, 1, 1, 1, 1, activation=None, use_batch_norm=True, bias=None)
output = tf.nn.relu(shortcut + res)
cnn.top_layer = output
cnn.top_size = depth
def bottleneck_block_v2(cnn, depth, depth_bottleneck, stride):
"""Bottleneck block with identity short-cut for ResNet v2.
The main difference from v1 is that a batch norm and relu are done at the
start of the block, instead of the end. This initial batch norm and relu is
collectively called a pre-activation.
Args:
cnn: the network to append bottleneck blocks.
depth: the number of output filters for this bottleneck block.
depth_bottleneck: the number of bottleneck filters for this block.
stride: Stride used in the first layer of the bottleneck block.
"""
input_layer = cnn.top_layer
in_size = cnn.top_size
name_key = 'resnet_v2'
name = name_key + str(cnn.counts[name_key])
cnn.counts[name_key] += 1
preact = cnn.batch_norm()
preact = tf.nn.relu(preact)
with tf.variable_scope(name):
if depth == in_size:
if stride == 1:
shortcut = input_layer
else:
shortcut = cnn.apool(
1,
1,
stride,
stride,
input_layer=input_layer,
num_channels_in=in_size)
else:
shortcut = cnn.conv(
depth,
1,
1,
stride,
stride,
activation=None,
use_batch_norm=False,
input_layer=preact,
num_channels_in=in_size,
bias=None)
cnn.conv(
depth_bottleneck,
1,
1,
stride,
stride,
input_layer=preact,
num_channels_in=in_size,
use_batch_norm=True,
bias=None)
cnn.conv(
depth_bottleneck,
3,
3,
1,
1,
mode='SAME_RESNET',
use_batch_norm=True,
bias=None)
res = cnn.conv(
depth,
1,
1,
1,
1,
activation=None,
use_batch_norm=False,
bias=None)
output = shortcut + res
cnn.top_layer = output
cnn.top_size = depth
def bottleneck_block(cnn, depth, depth_bottleneck, stride, pre_activation):
"""Bottleneck block with identity short-cut.
Args:
cnn: the network to append bottleneck blocks.
depth: the number of output filters for this bottleneck block.
depth_bottleneck: the number of bottleneck filters for this block.
stride: Stride used in the first layer of the bottleneck block.
pre_activation: use pre_activation structure used in v2 or not.
"""
if pre_activation:
bottleneck_block_v2(cnn, depth, depth_bottleneck, stride)
else:
bottleneck_block_v1(cnn, depth, depth_bottleneck, stride)
def residual_block(cnn, depth, stride, pre_activation):
"""Residual block with identity short-cut.
Args:
cnn: the network to append residual blocks.
depth: the number of output filters for this residual block.
stride: Stride used in the first layer of the residual block.
pre_activation: use pre_activation structure or not.
"""
input_layer = cnn.top_layer
in_size = cnn.top_size
if in_size != depth:
# Plan A of shortcut.
shortcut = cnn.apool(
1,
1,
stride,
stride,
input_layer=input_layer,
num_channels_in=in_size)
padding = (depth - in_size) // 2
if cnn.channel_pos == 'channels_last':
shortcut = tf.pad(shortcut,
[[0, 0], [0, 0], [0, 0], [padding, padding]])
else:
shortcut = tf.pad(shortcut,
[[0, 0], [padding, padding], [0, 0], [0, 0]])
else:
shortcut = input_layer
if pre_activation:
res = cnn.batch_norm(input_layer)
res = tf.nn.relu(res)
else:
res = input_layer
cnn.conv(
depth,
3,
3,
stride,
stride,
input_layer=res,
num_channels_in=in_size,
use_batch_norm=True,
bias=None)
if pre_activation:
res = cnn.conv(
depth,
3,
3,
1,
1,
activation=None,
use_batch_norm=False,
bias=None)
output = shortcut + res
else:
res = cnn.conv(
depth, 3, 3, 1, 1, activation=None, use_batch_norm=True, bias=None)
output = tf.nn.relu(shortcut + res)
cnn.top_layer = output
cnn.top_size = depth
class ResnetModel(model_lib.Model):
"""Resnet cnn network configuration."""
def __init__(self, model, layer_counts):
default_batch_sizes = {
'resnet50': 64,
'resnet101': 32,
'resnet152': 32,
'resnet50_v2': 64,
'resnet101_v2': 32,
'resnet152_v2': 32,
}
batch_size = default_batch_sizes.get(model, 32)
super(ResnetModel, self).__init__(model, 224, batch_size, 0.005,
layer_counts)
self.pre_activation = 'v2' in model
def add_inference(self, cnn):
if self.layer_counts is None:
raise ValueError(
'Layer counts not specified for %s' % self.get_model())
cnn.use_batch_norm = True
cnn.batch_norm_config = {
'decay': 0.997,
'epsilon': 1e-5,
'scale': True
}
cnn.conv(64, 7, 7, 2, 2, mode='SAME_RESNET', use_batch_norm=True)
cnn.mpool(3, 3, 2, 2, mode='SAME')
for _ in xrange(self.layer_counts[0]):
bottleneck_block(cnn, 256, 64, 1, self.pre_activation)
for i in xrange(self.layer_counts[1]):
stride = 2 if i == 0 else 1
bottleneck_block(cnn, 512, 128, stride, self.pre_activation)
for i in xrange(self.layer_counts[2]):
stride = 2 if i == 0 else 1
bottleneck_block(cnn, 1024, 256, stride, self.pre_activation)
for i in xrange(self.layer_counts[3]):
stride = 2 if i == 0 else 1
bottleneck_block(cnn, 2048, 512, stride, self.pre_activation)
if self.pre_activation:
cnn.batch_norm()
cnn.top_layer = tf.nn.relu(cnn.top_layer)
cnn.spatial_mean()
def get_learning_rate(self, global_step, batch_size):
raise NotImplementedError
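# Note (not in the original benchmark code): get_learning_rate is left
# unimplemented for the ImageNet variants. In this experimental SGD port the
# driver supplies its own optimizer (see the GradientDescentOptimizer created
# in TFBenchModel later in this diff), so the benchmark learning-rate
# schedule is never consulted here.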
def create_resnet50_model():
return ResnetModel('resnet50', (3, 4, 6, 3))
def create_resnet50_v2_model():
return ResnetModel('resnet50_v2', (3, 4, 6, 3))
def create_resnet101_model():
return ResnetModel('resnet101', (3, 4, 23, 3))
def create_resnet101_v2_model():
return ResnetModel('resnet101_v2', (3, 4, 23, 3))
def create_resnet152_model():
return ResnetModel('resnet152', (3, 8, 36, 3))
def create_resnet152_v2_model():
return ResnetModel('resnet152_v2', (3, 8, 36, 3))
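# The creator functions above are looked up by model name elsewhere in
# tfbench (model_config.get_model_config, used by TFBenchModel later in this
# diff). A minimal illustrative registry -- the dict name below is
# hypothetical, not part of the real model_config -- could look like:
#
#   _IMAGENET_RESNETS = {
#       'resnet50': create_resnet50_model,
#       'resnet50_v2': create_resnet50_v2_model,
#       'resnet101': create_resnet101_model,
#       'resnet101_v2': create_resnet101_v2_model,
#       'resnet152': create_resnet152_model,
#       'resnet152_v2': create_resnet152_v2_model,
#   }
#   model = _IMAGENET_RESNETS['resnet101']()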
class ResnetCifar10Model(model_lib.Model):
"""Resnet cnn network configuration for Cifar 10 dataset.
V1 model architecture follows the one defined in the paper:
https://arxiv.org/pdf/1512.03385.pdf.
V2 model architecture follows the one defined in the paper:
https://arxiv.org/pdf/1603.05027.pdf.
"""
def __init__(self, model, layer_counts):
self.pre_activation = 'v2' in model
super(ResnetCifar10Model, self).__init__(model, 32, 128, 0.1,
layer_counts)
def add_inference(self, cnn):
if self.layer_counts is None:
raise ValueError(
'Layer counts not specified for %s' % self.get_model())
cnn.use_batch_norm = True
cnn.batch_norm_config = {'decay': 0.9, 'epsilon': 1e-5, 'scale': True}
if self.pre_activation:
cnn.conv(16, 3, 3, 1, 1, use_batch_norm=True)
else:
cnn.conv(16, 3, 3, 1, 1, activation=None, use_batch_norm=True)
for i in xrange(self.layer_counts[0]):
# Stage output: 16 filters at 32 x 32 spatial resolution.
residual_block(cnn, 16, 1, self.pre_activation)
for i in xrange(self.layer_counts[1]):
stride = 2 if i == 0 else 1
# Stage output: 32 filters at 16 x 16 spatial resolution.
residual_block(cnn, 32, stride, self.pre_activation)
for i in xrange(self.layer_counts[2]):
stride = 2 if i == 0 else 1
# Stage output: 64 filters at 8 x 8 spatial resolution.
residual_block(cnn, 64, stride, self.pre_activation)
if self.pre_activation:
cnn.batch_norm()
cnn.top_layer = tf.nn.relu(cnn.top_layer)
cnn.spatial_mean()
def get_learning_rate(self, global_step, batch_size):
num_batches_per_epoch = int(50000 / batch_size)
boundaries = num_batches_per_epoch * np.array(
[82, 123, 300], dtype=np.int64)
boundaries = [x for x in boundaries]
values = [0.1, 0.01, 0.001, 0.0002]
return tf.train.piecewise_constant(global_step, boundaries, values)
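# Worked example of the schedule above, assuming the default CIFAR-10 batch
# size of 128 set in __init__: num_batches_per_epoch = int(50000 / 128) = 390,
# so the learning rate is 0.1 until step 31980 (~epoch 82), 0.01 until step
# 47970 (~epoch 123), 0.001 until step 117000 (~epoch 300), and 0.0002 after
# that.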
def create_resnet20_cifar_model():
return ResnetCifar10Model('resnet20', (3, 3, 3))
def create_resnet20_v2_cifar_model():
return ResnetCifar10Model('resnet20_v2', (3, 3, 3))
def create_resnet32_cifar_model():
return ResnetCifar10Model('resnet32', (5, 5, 5))
def create_resnet32_v2_cifar_model():
return ResnetCifar10Model('resnet32_v2', (5, 5, 5))
def create_resnet44_cifar_model():
return ResnetCifar10Model('resnet44', (7, 7, 7))
def create_resnet44_v2_cifar_model():
return ResnetCifar10Model('resnet44_v2', (7, 7, 7))
def create_resnet56_cifar_model():
return ResnetCifar10Model('resnet56', (9, 9, 9))
def create_resnet56_v2_cifar_model():
return ResnetCifar10Model('resnet56_v2', (9, 9, 9))
def create_resnet110_cifar_model():
return ResnetCifar10Model('resnet110', (18, 18, 18))
def create_resnet110_v2_cifar_model():
return ResnetCifar10Model('resnet110_v2', (18, 18, 18))


@ -0,0 +1,46 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tfbench import model_config
class MockDataset(object):
name = "synthetic"
class TFBenchModel(object):
def __init__(self, batch=64, use_cpus=False):
image_shape = [batch, 224, 224, 3]
labels_shape = [batch]
# Synthetic images should lie within [0, 255].
images = tf.truncated_normal(
image_shape,
dtype=tf.float32,
mean=127,
stddev=60,
name='synthetic_images')
# Minor hack to avoid H2D copy when using synthetic data
self.inputs = tf.contrib.framework.local_variable(
images, name='gpu_cached_images')
self.labels = tf.random_uniform(
labels_shape,
minval=0,
maxval=999,
dtype=tf.int32,
name='synthetic_labels')
self.model = model_config.get_model_config("resnet101", MockDataset())
logits, aux = self.model.build_network(
self.inputs, data_format="NHWC" if use_cpus else "NCHW")
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=self.labels)
self.loss = tf.reduce_mean(loss, name='xentropy-loss')
self.optimizer = tf.train.GradientDescentOptimizer(1e-6)
def get_feed_dict(self):
return {}
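# A minimal usage sketch (not part of the original file): TFBenchModel only
# exposes `loss` and `optimizer`, so the driver is expected to build its own
# training op. The snippet below assumes a plain single-process TF session;
# `train_op` and `loss_value` are names introduced here for illustration.
if __name__ == "__main__":
    model = TFBenchModel(batch=64, use_cpus=False)
    train_op = model.optimizer.minimize(model.loss)
    with tf.Session() as sess:
        # Both initializers are needed: the cached synthetic images are a
        # local variable, while the ResNet weights are global variables.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        loss_value, _ = sess.run(
            [model.loss, train_op], feed_dict=model.get_feed_dict())
        print("one SGD step, loss = {}".format(loss_value))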


@ -0,0 +1,124 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import time
import tensorflow as tf
from tensorflow.python.client import timeline
import ray
logger = logging.getLogger(__name__)
def fetch(oids):
if ray.global_state.use_raylet:
local_sched_client = ray.worker.global_worker.local_scheduler_client
for o in oids:
ray_obj_id = ray.ObjectID(o)
local_sched_client.reconstruct_objects([ray_obj_id], True)
else:
for o in oids:
plasma_id = ray.pyarrow.plasma.ObjectID(o)
ray.worker.global_worker.plasma_client.fetch([plasma_id])
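# Hedged usage sketch: fetch() takes the raw byte ids of Ray objects and asks
# the local scheduler (raylet) or plasma store to pull them onto this node
# before they are read, e.g. by a plasma TensorFlow op. `weight_ids` is an
# illustrative name, and the .id() accessor reflects the pre-0.6 ObjectID API
# assumed here:
#
#   weight_ids = [ray.put(w).id() for w in local_weights]
#   fetch(weight_ids)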
def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
feed_dict = feed_dict or {}
if write_timeline:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
fetches = sess.run(
ops,
options=run_options,
run_metadata=run_metadata,
feed_dict=feed_dict)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
outf = "timeline-{}-{}.json".format(name, os.getpid())
trace_file = open(outf, "w")
trace_file.write(trace.generate_chrome_trace_format())
trace_file.close()
logger.info("Wrote TF timeline to %s", os.path.abspath(outf))
else:
fetches = sess.run(ops, feed_dict=feed_dict)
return fetches
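# Hedged usage sketch: run_timeline wraps an ordinary sess.run call and, when
# requested, dumps a Chrome-trace timeline for that single step. `sess`,
# `train_op` and `model` below are illustrative names, not defined in this
# file:
#
#   outs = run_timeline(
#       sess, [train_op], feed_dict=model.get_feed_dict(),
#       write_timeline=True, name="sgd_step")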
class Timeline(object):
def __init__(self, tid):
self.events = []
self.offset = 0
self.start_time = self.time()
self.tid = tid
def patch_ray(self):
orig_log = ray.worker.log
def custom_log(event_type, kind, *args, **kwargs):
orig_log(event_type, kind, *args, **kwargs)
if kind == ray.worker.LOG_SPAN_START:
self.start(event_type)
elif kind == ray.worker.LOG_SPAN_END:
self.end(event_type)
elif kind == ray.worker.LOG_SPAN_POINT:
self.event(event_type)
ray.worker.log = custom_log
def time(self):
return time.time() + self.offset
def reset(self):
self.events = []
self.start_time = self.time()
def start(self, name):
self.events.append((self.tid, "B", name, self.time()))
def end(self, name):
self.events.append((self.tid, "E", name, self.time()))
def event(self, name):
now = self.time()
self.events.append((self.tid, "B", name, now))
self.events.append((self.tid, "E", name, now + .0001))
def merge(self, other):
if other.start_time < self.start_time:
self.start_time = other.start_time
self.events.extend(other.events)
self.events.sort(key=lambda e: e[3])
def chrome_trace_format(self, filename):
out = []
for tid, ph, name, t in self.events:
ts = int((t - self.start_time) * 1000000)
out.append({
"name": name,
"tid": tid,
"pid": tid,
"ph": ph,
"ts": ts,
})
with open(filename, "w") as f:
f.write(json.dumps(out))
logger.info("Wrote chrome timeline to", filename)
if __name__ == "__main__":
a = Timeline(1)
b = Timeline(2)
a.start("hi")
time.sleep(.1)
b.start("bye")
a.start("hi3")
time.sleep(.1)
a.end("hi3")
b.end("bye")
time.sleep(.1)
a.end("hi")
b.start("b1")
b.end("b1")
a.merge(b)
a.chrome_trace_format("test.json")
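# The resulting test.json is in Chrome trace format and can be inspected by
# loading it into chrome://tracing.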


@ -303,6 +303,9 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/cartpole_lstm.py --stop=75
docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2
# No Xray for PyTorch
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \