Initial version of evolution strategies example. (#544)

* Initial commit of evolution strategies example.

* Some small simplifications.

* Update example to use new API.

* Add example to documentation.
Robert Nishihara 2017-05-14 17:53:51 -07:00 committed by Philipp Moritz
parent 9f91eb8c91
commit 3c5375345f
8 changed files with 1222 additions and 0 deletions

View file: example-evolution-strategies.rst

@@ -0,0 +1,87 @@
Evolution Strategies
====================

This document provides a walkthrough of the evolution strategies example.
To run the application, first install some dependencies.

.. code-block:: bash

  pip install tensorflow
  pip install gym

You can view the `code for this example`_.

.. _`code for this example`: https://github.com/ray-project/ray/tree/master/examples/evolution_strategies

The script can be run as follows. Note that the configuration is tuned to work
on the ``Humanoid-v1`` gym environment.

.. code-block:: bash

  python examples/evolution_strategies/evolution_strategies.py

At the heart of this example, we define a ``Worker`` class. These workers have
a ``do_rollouts`` method, which is used to perform rollouts with randomly
perturbed policies in a given environment. A condensed excerpt of the full
implementation follows the sketch below.

.. code-block:: python

  @ray.remote
  class Worker(object):
      def __init__(self, config, policy_params, env_name, noise):
          self.env = # Initialize environment.
          self.policy = # Construct policy.
          # Details omitted.

      def do_rollouts(self, params):
          # Set the network weights.
          self.policy.set_trainable_flat(params)
          perturbation = # Generate a random perturbation to the policy.

          self.policy.set_trainable_flat(params + perturbation)
          # Do rollout with the perturbed policy.

          self.policy.set_trainable_flat(params - perturbation)
          # Do rollout with the perturbed policy.

          # Return the rewards.
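
In the full implementation (``examples/evolution_strategies/evolution_strategies.py``
in this commit), the perturbation is not generated from scratch: each worker
samples an offset into a shared noise table and scales that slice by
``config.noise_stdev``. The relevant lines of ``Worker.do_rollouts`` are
condensed here.

.. code-block:: python

  # Condensed from Worker.do_rollouts in this commit.
  noise_idx = self.noise.sample_index(self.rs, self.policy.num_params)
  perturbation = self.config.noise_stdev * self.noise.get(
      noise_idx, self.policy.num_params)

  # Evaluate the positively and negatively perturbed policies.
  self.policy.set_trainable_flat(params + perturbation)
  rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
                                                      task_ob_stat)
  self.policy.set_trainable_flat(params - perturbation)
  rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
                                                      task_ob_stat)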

In the main loop, we create a number of actors with this class.

.. code-block:: python

  workers = [Worker.remote(config, policy_params, env_name, noise_id)
             for _ in range(num_workers)]

We then enter an infinite loop in which we use the actors to perform rollouts
and use the rewards from the rollouts to update the policy.

.. code-block:: python

  while True:
      # Get the current policy weights.
      theta = policy.get_trainable_flat()
      # Put the current policy weights in the object store.
      theta_id = ray.put(theta)
      # Use the actors to do rollouts. Note that we pass in the ID of the
      # policy weights.
      rollout_ids = [worker.do_rollouts.remote(theta_id)
                     for worker in workers]
      # Get the results of the rollouts.
      results = ray.get(rollout_ids)
      # Update the policy.
      optimizer.update(...)
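
The ``optimizer.update(...)`` call above is abbreviated. In the full script,
the returns of the positive and negative perturbations are converted to
centered ranks and combined with the corresponding noise slices to form a
gradient estimate, which is then passed to an Adam optimizer along with an L2
term. The relevant lines of the main loop look like this.

.. code-block:: python

  # Process the returns (centered ranks are robust to return scaling).
  proc_returns_n2 = utils.compute_centered_ranks(returns_n2)

  # Combine the rank differences with the noise slices to estimate the
  # gradient, then take an Adam step with L2 regularization.
  g, count = utils.batched_weighted_sum(
      proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
      (noise.get(idx, policy.num_params) for idx in noise_inds_n),
      batch_size=500)
  g /= returns_n2.size
  update_ratio = optimizer.update(-g + config.l2coeff * theta)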

In addition, note that we create a large object representing a shared block of
random noise. We then put the block in the object store so that each ``Worker``
actor can use it without creating its own copy.

.. code-block:: python

  @ray.remote
  def create_shared_noise():
      noise = np.random.randn(250000000)
      return noise

  noise_id = create_shared_noise.remote()

Recall that the ``noise_id`` argument is passed into the actor constructor.
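
Because ``create_shared_noise`` is a remote function, its return value lives in
the object store and ``noise_id`` is an object ID. The driver wraps the array
in a ``SharedNoiseTable``, and each ``Worker`` does the same in its
constructor, where Ray resolves the ID to the underlying array.

.. code-block:: python

  # On the driver.
  noise_id = create_shared_noise.remote()
  noise = SharedNoiseTable(ray.get(noise_id))

  # Inside Worker.__init__, the noise_id argument arrives as the array itself:
  #     self.noise = SharedNoiseTable(noise)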

View file: index.rst

@@ -30,6 +30,7 @@ Ray
   example-resnet.rst
   example-a3c.rst
   example-lbfgs.rst
   example-evolution-strategies.rst
   using-ray-with-tensorflow.rst

.. toctree::

View file: examples/evolution_strategies/evolution_strategies.py

@@ -0,0 +1,286 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from collections import namedtuple
import gym
import numpy as np
import ray
import time
import optimizers
import policies
import tabular_logger as tlogger
import tf_util
import utils
Config = namedtuple("Config", [
"l2coeff", "noise_stdev", "episodes_per_batch", "timesteps_per_batch",
"calc_obstat_prob", "eval_prob", "snapshot_freq", "return_proc_mode",
"episode_cutoff_mode"
])
Result = namedtuple("Result", [
"noise_inds_n", "returns_n2", "sign_returns_n2", "lengths_n2",
"eval_return", "eval_length", "ob_sum", "ob_sumsq", "ob_count"
])
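# Naming convention (inferred from the assertions in the main loop below): the
# "_n" suffix marks 1-D arrays with one entry per sampled noise index, and
# "_n2" marks (n, 2) arrays whose columns hold the results for the positively
# and negatively perturbed policies respectively.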
@ray.remote
def create_shared_noise():
"""Create a large array of noise to be shared by all workers."""
seed = 123
count = 250000000
noise = np.random.RandomState(seed).randn(count).astype(np.float32)
return noise
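# Note: 250 million float32 entries is roughly 1 GB, which is why the table is
# created once, placed in the object store, and shared by all workers instead
# of being regenerated per actor.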
class SharedNoiseTable(object):
def __init__(self, noise):
self.noise = noise
assert self.noise.dtype == np.float32
def get(self, i, dim):
return self.noise[i:i + dim]
def sample_index(self, stream, dim):
return stream.randint(0, len(self.noise) - dim + 1)
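# A perturbation of dimension dim is identified by a single integer offset i:
# get(i, dim) returns the slice noise[i:i + dim], and sample_index draws i
# uniformly from [0, len(noise) - dim]. For example (hypothetical sizes), with
# a table of length 10 and dim=4, valid offsets are 0 through 6, and get(6, 4)
# returns noise[6:10].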
@ray.remote
class Worker(object):
def __init__(self, config, policy_params, env_name, noise,
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.policy_params = policy_params
self.noise = SharedNoiseTable(noise)
self.env = gym.make(env_name)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.MujocoPolicy(self.env.observation_space,
self.env.action_space,
**policy_params)
tf_util.initialize()
self.rs = np.random.RandomState()
assert self.policy.needs_ob_stat == (self.config.calc_obstat_prob != 0)
def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
if (self.policy.needs_ob_stat and self.config.calc_obstat_prob != 0 and
self.rs.rand() < self.config.calc_obstat_prob):
rollout_rews, rollout_len, obs = self.policy.rollout(
self.env, timestep_limit=timestep_limit, save_obs=True,
random_stream=self.rs)
task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
len(obs))
else:
rollout_rews, rollout_len = self.policy.rollout(
self.env, timestep_limit=timestep_limit, random_stream=self.rs)
return rollout_rews, rollout_len
def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
# Set the network weights.
self.policy.set_trainable_flat(params)
if self.policy.needs_ob_stat:
self.policy.set_ob_stat(ob_mean, ob_std)
if self.config.eval_prob != 0:
raise NotImplementedError("Eval rollouts are not implemented.")
noise_inds, returns, sign_returns, lengths = [], [], [], []
# We set eps=0 because we're incrementing only.
task_ob_stat = utils.RunningStat(self.env.observation_space.shape, eps=0)
# Perform some rollouts with noise.
task_tstart = time.time()
while (len(noise_inds) == 0 or
time.time() - task_tstart < self.min_task_runtime):
noise_idx = self.noise.sample_index(self.rs, self.policy.num_params)
perturbation = self.config.noise_stdev * self.noise.get(
noise_idx, self.policy.num_params)
# These two sampling steps could be done in parallel on different actors
# letting us update twice as frequently.
self.policy.set_trainable_flat(params + perturbation)
rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
task_ob_stat)
self.policy.set_trainable_flat(params - perturbation)
rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
task_ob_stat)
noise_inds.append(noise_idx)
returns.append([rews_pos.sum(), rews_neg.sum()])
sign_returns.append([np.sign(rews_pos).sum(), np.sign(rews_neg).sum()])
lengths.append([len_pos, len_neg])
return Result(
noise_inds_n=np.array(noise_inds),
returns_n2=np.array(returns, dtype=np.float32),
sign_returns_n2=np.array(sign_returns, dtype=np.float32),
lengths_n2=np.array(lengths, dtype=np.int32),
eval_return=None,
eval_length=None,
ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum),
ob_sumsq=(None if task_ob_stat.count == 0 else task_ob_stat.sumsq),
ob_count=task_ob_stat.count)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train an RL agent using evolution strategies.")
parser.add_argument("--num-workers", default=10, type=int,
help=("The number of actors to create in aggregate "
"across the cluster."))
parser.add_argument("--env-name", default="Pendulum-v0", type=str,
help="The name of the gym environment to use.")
parser.add_argument("--stepsize", default=0.01, type=float,
help="The stepsize to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
args = parser.parse_args()
num_workers = args.num_workers
env_name = args.env_name
stepsize = args.stepsize
ray.init(redis_address=args.redis_address,
num_workers=(0 if args.redis_address is None else None))
# Tell Ray to serialize Config and Result objects.
ray.register_class(Config)
ray.register_class(Result)
config = Config(l2coeff=0.005,
noise_stdev=0.02,
episodes_per_batch=10000,
timesteps_per_batch=100000,
calc_obstat_prob=0.01,
eval_prob=0,
snapshot_freq=20,
return_proc_mode="centered_rank",
episode_cutoff_mode="env_default")
policy_params = {
"ac_bins": "continuous:",
"ac_noise_std": 0.01,
"nonlin_type": "tanh",
"hidden_dims": [256, 256],
"connection_type": "ff"
}
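# Note: "continuous:" selects the unbinned action head in
# policies.MujocoPolicy._make_net; the alternative "uniform:<num_bins>" and
# "custom:<values>" modes discretize each action dimension into bins instead.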
# Create the shared noise table.
print("Creating shared noise table.")
noise_id = create_shared_noise.remote()
noise = SharedNoiseTable(ray.get(noise_id))
# Create the actors.
print("Creating actors.")
workers = [Worker.remote(config, policy_params, env_name, noise_id)
for _ in range(num_workers)]
env = gym.make(env_name)
sess = utils.make_session(single_threaded=False)
policy = policies.MujocoPolicy(env.observation_space, env.action_space,
**policy_params)
tf_util.initialize()
optimizer = optimizers.Adam(policy, stepsize)
ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
episodes_so_far = 0
timesteps_so_far = 0
tstart = time.time()
while True:
step_tstart = time.time()
theta = policy.get_trainable_flat()
assert theta.dtype == np.float32
# Put the current policy weights in the object store.
theta_id = ray.put(theta)
# Use the actors to do rollouts. Note that we pass in the ID of the policy
# weights.
rollout_ids = [worker.do_rollouts.remote(
theta_id,
ob_stat.mean if policy.needs_ob_stat else None,
ob_stat.std if policy.needs_ob_stat else None)
for worker in workers]
# Get the results of the rollouts.
results = ray.get(rollout_ids)
curr_task_results = []
ob_count_this_batch = 0
# Loop over the results
for result in results:
assert result.eval_length is None, "We aren't doing eval rollouts."
assert result.noise_inds_n.ndim == 1
assert result.returns_n2.shape == (len(result.noise_inds_n), 2)
assert result.lengths_n2.shape == (len(result.noise_inds_n), 2)
assert result.returns_n2.dtype == np.float32
result_num_eps = result.lengths_n2.size
result_num_timesteps = result.lengths_n2.sum()
episodes_so_far += result_num_eps
timesteps_so_far += result_num_timesteps
curr_task_results.append(result)
# Update ob stats.
if policy.needs_ob_stat and result.ob_count > 0:
ob_stat.increment(result.ob_sum, result.ob_sumsq, result.ob_count)
ob_count_this_batch += result.ob_count
# Assemble the results.
noise_inds_n = np.concatenate([r.noise_inds_n for
r in curr_task_results])
returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results])
lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
assert noise_inds_n.shape[0] == returns_n2.shape[0] == lengths_n2.shape[0]
# Process the returns.
if config.return_proc_mode == "centered_rank":
proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
else:
raise NotImplementedError(config.return_proc_mode)
# Compute and take a step.
g, count = utils.batched_weighted_sum(
proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
(noise.get(idx, policy.num_params) for idx in noise_inds_n),
batch_size=500)
g /= returns_n2.size
assert (g.shape == (policy.num_params,) and g.dtype == np.float32 and
count == len(noise_inds_n))
update_ratio = optimizer.update(-g + config.l2coeff * theta)
# Update ob stat (we're never running the policy in the master, but we
# might be snapshotting the policy).
if policy.needs_ob_stat:
policy.set_ob_stat(ob_stat.mean, ob_stat.std)
step_tend = time.time()
tlogger.record_tabular("EpRewMean", returns_n2.mean())
tlogger.record_tabular("EpRewStd", returns_n2.std())
tlogger.record_tabular("EpLenMean", lengths_n2.mean())
tlogger.record_tabular("Norm",
float(np.square(policy.get_trainable_flat()).sum()))
tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
tlogger.record_tabular("UpdateRatio", float(update_ratio))
tlogger.record_tabular("EpisodesThisIter", lengths_n2.size)
tlogger.record_tabular("EpisodesSoFar", episodes_so_far)
tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum())
tlogger.record_tabular("TimestepsSoFar", timesteps_so_far)
tlogger.record_tabular("ObCount", ob_count_this_batch)
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
tlogger.record_tabular("TimeElapsed", step_tend - tstart)
tlogger.dump_tabular()

View file: examples/evolution_strategies/optimizers.py

@@ -0,0 +1,57 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class Optimizer(object):
def __init__(self, pi):
self.pi = pi
self.dim = pi.num_params
self.t = 0
def update(self, globalg):
self.t += 1
step = self._compute_step(globalg)
theta = self.pi.get_trainable_flat()
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
self.pi.set_trainable_flat(theta + step)
return ratio
def _compute_step(self, globalg):
raise NotImplementedError
class SGD(Optimizer):
def __init__(self, pi, stepsize, momentum=0.9):
Optimizer.__init__(self, pi)
self.v = np.zeros(self.dim, dtype=np.float32)
self.stepsize, self.momentum = stepsize, momentum
def _compute_step(self, globalg):
self.v = self.momentum * self.v + (1. - self.momentum) * globalg
step = -self.stepsize * self.v
return step
class Adam(Optimizer):
def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
Optimizer.__init__(self, pi)
self.stepsize = stepsize
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = np.zeros(self.dim, dtype=np.float32)
self.v = np.zeros(self.dim, dtype=np.float32)
def _compute_step(self, globalg):
a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) /
(1 - self.beta1 ** self.t))
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
return step
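# Note on Adam._compute_step above: the factor
#     a = stepsize * sqrt(1 - beta2**t) / (1 - beta1**t)
# folds the standard Adam bias correction for the first- and second-moment
# estimates into the effective step size, so m and v themselves never need an
# explicit correction.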

View file: examples/evolution_strategies/policies.py

@@ -0,0 +1,251 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import pickle
import h5py
import numpy as np
import tensorflow as tf
import tf_util as U
logger = logging.getLogger(__name__)
class Policy:
def __init__(self, *args, **kwargs):
self.args, self.kwargs = args, kwargs
self.scope = self._initialize(*args, **kwargs)
self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope.name)
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name)
self.num_params = sum(int(np.prod(v.get_shape().as_list())) for v in self.trainable_variables)
self._setfromflat = U.SetFromFlat(self.trainable_variables)
self._getflat = U.GetFlat(self.trainable_variables)
logger.info('Trainable variables ({} parameters)'.format(self.num_params))
for v in self.trainable_variables:
shp = v.get_shape().as_list()
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
logger.info('All variables')
for v in self.all_variables:
shp = v.get_shape().as_list()
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
placeholders = [tf.placeholder(v.value().dtype, v.get_shape().as_list()) for v in self.all_variables]
self.set_all_vars = U.function(
inputs=placeholders,
outputs=[],
updates=[tf.group(*[v.assign(p) for v, p in zip(self.all_variables, placeholders)])]
)
def _initialize(self, *args, **kwargs):
raise NotImplementedError
def save(self, filename):
assert filename.endswith('.h5')
with h5py.File(filename, 'w') as f:
for v in self.all_variables:
f[v.name] = v.eval()
# TODO: it would be nice to avoid pickle, but it's convenient to pass Python objects to _initialize
# (like Gym spaces or numpy arrays)
f.attrs['name'] = type(self).__name__
f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args, self.kwargs), protocol=-1))
@classmethod
def Load(cls, filename, extra_kwargs=None):
with h5py.File(filename, 'r') as f:
args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring())
if extra_kwargs:
kwargs.update(extra_kwargs)
policy = cls(*args, **kwargs)
policy.set_all_vars(*[f[v.name][...] for v in policy.all_variables])
return policy
# === Rollouts/training ===
def rollout(self, env, *, render=False, timestep_limit=None, save_obs=False, random_stream=None):
"""
If random_stream is provided, the rollout will take noisy actions with noise drawn from that stream.
Otherwise, no action noise will be added.
"""
env_timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
timestep_limit = env_timestep_limit if timestep_limit is None else min(timestep_limit, env_timestep_limit)
rews = []
t = 0
if save_obs:
obs = []
ob = env.reset()
for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs:
obs.append(ob)
ob, rew, done, _ = env.step(ac)
rews.append(rew)
t += 1
if render:
env.render()
if done:
break
rews = np.array(rews, dtype=np.float32)
if save_obs:
return rews, t, np.array(obs)
return rews, t
def act(self, ob, random_stream=None):
raise NotImplementedError
def set_trainable_flat(self, x):
self._setfromflat(x)
def get_trainable_flat(self):
return self._getflat()
@property
def needs_ob_stat(self):
raise NotImplementedError
def set_ob_stat(self, ob_mean, ob_std):
raise NotImplementedError
def bins(x, dim, num_bins, name):
scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01))
scores_nab = tf.reshape(scores, [-1, dim, num_bins])
return tf.argmax(scores_nab, 2) # 0 ... num_bins-1
class MujocoPolicy(Policy):
def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std, nonlin_type, hidden_dims, connection_type):
self.ac_space = ac_space
self.ac_bins = ac_bins
self.ac_noise_std = ac_noise_std
self.hidden_dims = hidden_dims
self.connection_type = connection_type
assert len(ob_space.shape) == len(self.ac_space.shape) == 1
assert np.all(np.isfinite(self.ac_space.low)) and np.all(np.isfinite(self.ac_space.high)), \
'Action bounds required'
self.nonlin = {'tanh': tf.tanh, 'relu': tf.nn.relu, 'lrelu': U.lrelu, 'elu': tf.nn.elu}[nonlin_type]
with tf.variable_scope(type(self).__name__) as scope:
# Observation normalization
ob_mean = tf.get_variable(
'ob_mean', ob_space.shape, tf.float32, tf.constant_initializer(np.nan), trainable=False)
ob_std = tf.get_variable(
'ob_std', ob_space.shape, tf.float32, tf.constant_initializer(np.nan), trainable=False)
in_mean = tf.placeholder(tf.float32, ob_space.shape)
in_std = tf.placeholder(tf.float32, ob_space.shape)
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
tf.assign(ob_mean, in_mean),
tf.assign(ob_std, in_std),
])
# Policy network
o = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std, -5.0, 5.0))
self._act = U.function([o], a)
return scope
def _make_net(self, o):
# Process observation
if self.connection_type == 'ff':
x = o
for ilayer, hd in enumerate(self.hidden_dims):
x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer), U.normc_initializer(1.0)))
else:
raise NotImplementedError(self.connection_type)
# Map to action
adim, ahigh, alow = self.ac_space.shape[0], self.ac_space.high, self.ac_space.low
assert isinstance(self.ac_bins, str)
ac_bin_mode, ac_bin_arg = self.ac_bins.split(':')
if ac_bin_mode == 'uniform':
# Uniformly spaced bins, from ac_space.low to ac_space.high
num_ac_bins = int(ac_bin_arg)
aidx_na = bins(x, adim, num_ac_bins, 'out') # 0 ... num_ac_bins-1
ac_range_1a = (ahigh - alow)[None, :]
a = 1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a + alow[None, :]
elif ac_bin_mode == 'custom':
# Custom bins specified as a list of values from -1 to 1
# The bins are rescaled to ac_space.low to ac_space.high
acvals_k = np.array(list(map(float, ac_bin_arg.split(','))), dtype=np.float32)
logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x) for x in acvals_k))
assert acvals_k.ndim == 1 and acvals_k[0] == -1 and acvals_k[-1] == 1
acvals_ak = (
(ahigh - alow)[:, None] / (acvals_k[-1] - acvals_k[0]) * (acvals_k - acvals_k[0])[None, :]
+ alow[:, None]
)
aidx_na = bins(x, adim, len(acvals_k), 'out') # values in [0, k-1]
a = tf.gather_nd(
acvals_ak,
tf.concat([
tf.tile(np.arange(adim)[None, :, None], [tf.shape(aidx_na)[0], 1, 1]),
tf.expand_dims(aidx_na, -1)
], 2) # (n,a,2)
) # (n,a)
elif ac_bin_mode == 'continuous':
a = U.dense(x, adim, 'out', U.normc_initializer(0.01))
else:
raise NotImplementedError(ac_bin_mode)
return a
def act(self, ob, random_stream=None):
a = self._act(ob)
if random_stream is not None and self.ac_noise_std != 0:
a += random_stream.randn(*a.shape) * self.ac_noise_std
return a
@property
def needs_ob_stat(self):
return True
@property
def needs_ref_batch(self):
return False
def set_ob_stat(self, ob_mean, ob_std):
self._set_ob_mean_std(ob_mean, ob_std)
def initialize_from(self, filename, ob_stat=None):
"""
Initializes weights from another policy, which must have the same architecture (variable names),
but the weight arrays can be smaller than the current policy.
"""
with h5py.File(filename, 'r') as f:
f_var_names = []
f.visititems(lambda name, obj: f_var_names.append(name) if isinstance(obj, h5py.Dataset) else None)
assert set(v.name for v in self.all_variables) == set(f_var_names), 'Variable names do not match'
init_vals = []
for v in self.all_variables:
shp = v.get_shape().as_list()
f_shp = f[v.name].shape
assert len(shp) == len(f_shp) and all(a >= b for a, b in zip(shp, f_shp)), \
'This policy must have more weights than the policy to load'
init_val = v.eval()
# ob_mean and ob_std are initialized with nan, so set them manually
if 'ob_mean' in v.name:
init_val[:] = 0
init_mean = init_val
elif 'ob_std' in v.name:
init_val[:] = 0.001
init_std = init_val
# Fill in subarray from the loaded policy
init_val[tuple([np.s_[:s] for s in f_shp])] = f[v.name]
init_vals.append(init_val)
self.set_all_vars(*init_vals)
if ob_stat is not None:
ob_stat.set_from_init(init_mean, init_std, init_count=1e5)

View file: examples/evolution_strategies/tabular_logger.py

@@ -0,0 +1,194 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import os
import shutil
import sys
import time
import tensorflow as tf
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util import compat
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
class TbWriter(object):
"""
Based on SummaryWriter, but changed to allow for a different prefix and to
get rid of the multithreading (in the end the same prefix is used anyway).
"""
def __init__(self, dir, prefix):
self.dir = dir
self.step = 1 # Start at 1, because EvWriter automatically generates an object with step=0
self.evwriter = pywrap_tensorflow.EventsWriter(compat.as_bytes(os.path.join(dir, prefix)))
def write_values(self, key2val):
summary = tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=float(v))
for (k, v) in key2val.items()])
event = event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = self.step # is there any reason why you'd want to specify the step?
self.evwriter.WriteEvent(event)
self.evwriter.Flush()
self.step += 1
def close(self):
self.evwriter.Close()
# ================================================================
# API
# ================================================================
def start(dir):
"""
dir: directory to put all output files
"""
if _Logger.CURRENT is not _Logger.DEFAULT:
sys.stderr.write("WARNING: You asked to start logging (dir=%s), but you never stopped the previous logger (dir=%s).\n"%(dir, _Logger.CURRENT.dir))
_Logger.CURRENT = _Logger(dir=dir)
def stop():
if _Logger.CURRENT is _Logger.DEFAULT:
sys.stderr.write("WARNING: You asked to stop logging, but you never started any previous logger.\n"%(dir, _Logger.CURRENT.dir))
return
_Logger.CURRENT.close()
_Logger.CURRENT = _Logger.DEFAULT
def record_tabular(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
"""
_Logger.CURRENT.record_tabular(key, val)
def dump_tabular():
"""
Write all of the diagnostics from the current iteration
"""
_Logger.CURRENT.dump_tabular()
def log(*args, level=INFO):
"""
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
"""
_Logger.CURRENT.log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
def info(*args):
log(*args, level=INFO)
def warn(*args):
log(*args, level=WARN)
def error(*args):
log(*args, level=ERROR)
def set_level(level):
"""
Set logging threshold on current logger.
"""
_Logger.CURRENT.set_level(level)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
"""
return _Logger.CURRENT.get_dir()
def get_expt_dir():
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n")
return get_dir()
# ================================================================
# Backend
# ================================================================
class _Logger(object):
DEFAULT = None # A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output files
CURRENT = None # Current logger being used by the free functions above
def __init__(self, dir=None):
self.name2val = OrderedDict() # values this iteration
self.level = INFO
self.dir = dir
self.text_outputs = [sys.stdout]
if dir is not None:
os.makedirs(dir, exist_ok=True)
self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w"))
self.tbwriter = TbWriter(dir=dir, prefix="events")
else:
self.tbwriter = None
# Logging API, forwarded
# ----------------------------------------
def record_tabular(self, key, val):
self.name2val[key] = val
def dump_tabular(self):
# Create strings for printing
key2str = OrderedDict()
for (key,val) in self.name2val.items():
if hasattr(val, "__float__"): valstr = "%-8.3g"%val
else: valstr = val
key2str[self._truncate(key)]=self._truncate(valstr)
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write to all text outputs
self._write_text("-"*(keywidth+valwidth+7), "\n")
for (key,val) in key2str.items():
self._write_text("| ", key, " "*(keywidth-len(key)), " | ", val, " "*(valwidth-len(val)), " |\n")
self._write_text("-"*(keywidth+valwidth+7), "\n")
for f in self.text_outputs:
try: f.flush()
except OSError: sys.stderr.write('Warning! OSError when flushing.\n')
# Write to tensorboard
if self.tbwriter is not None:
self.tbwriter.write_values(self.name2val)
self.name2val.clear()
def log(self, *args, level=INFO):
if self.level <= level:
self._do_log(*args)
# Configuration
# ----------------------------------------
def set_level(self, level):
self.level = level
def get_dir(self):
return self.dir
def close(self):
for f in self.text_outputs[1:]: f.close()
if self.tbwriter: self.tbwriter.close()
# Misc
# ----------------------------------------
def _do_log(self, *args):
self._write_text(*args, '\n')
for f in self.text_outputs:
try: f.flush()
except OSError: print('Warning! OSError when flushing.')
def _write_text(self, *strings):
for f in self.text_outputs:
for string in strings:
f.write(string)
def _truncate(self, s):
if len(s) > 33:
return s[:30] + "..."
else:
return s
_Logger.DEFAULT = _Logger()
_Logger.CURRENT = _Logger.DEFAULT
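# Minimal usage sketch (the evolution strategies script relies on the default
# stdout-only logger, so it never calls start()):
#
#     import tabular_logger as tlogger
#     tlogger.start("/tmp/es_logs")  # hypothetical directory; also writes
#                                    # log.txt and a TensorBoard events file
#     tlogger.record_tabular("EpRewMean", 123.4)
#     tlogger.dump_tabular()
#     tlogger.stop()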

View file: examples/evolution_strategies/tf_util.py

@@ -0,0 +1,260 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import builtins
import functools
import copy
import os
# ================================================================
# Import all names into common namespace
# ================================================================
clip = tf.clip_by_value
# Make consistent with numpy
# ----------------------------------------
def sum(x, axis=None, keepdims=False):
return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
def mean(x, axis=None, keepdims=False):
return tf.reduce_mean(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
def var(x, axis=None, keepdims=False):
meanx = mean(x, axis=axis, keepdims=keepdims)
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
def std(x, axis=None, keepdims=False):
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
def max(x, axis=None, keepdims=False):
return tf.reduce_max(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
def min(x, axis=None, keepdims=False):
return tf.reduce_min(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
def concatenate(arrs, axis=0):
return tf.concat(arrs, axis)
def argmax(x, axis=None):
return tf.argmax(x, dimension=axis)
def switch(condition, then_expression, else_expression):
'''Switches between two operations depending on a scalar value (int or bool).
Note that both `then_expression` and `else_expression`
should be symbolic tensors of the *same shape*.
# Arguments
condition: scalar tensor.
then_expression: TensorFlow operation.
else_expression: TensorFlow operation.
'''
x_shape = copy.copy(then_expression.get_shape())
x = tf.cond(tf.cast(condition, 'bool'),
lambda: then_expression,
lambda: else_expression)
x.set_shape(x_shape)
return x
# Extras
# ----------------------------------------
def l2loss(params):
if len(params) == 0:
return tf.constant(0.0)
else:
return tf.add_n([sum(tf.square(p)) for p in params])
def lrelu(x, leak=0.2):
f1 = 0.5 * (1 + leak)
f2 = 0.5 * (1 - leak)
return f1 * x + f2 * abs(x)
def categorical_sample_logits(X):
# https://github.com/tensorflow/tensorflow/issues/456
U = tf.random_uniform(tf.shape(X))
return argmax(X - tf.log(-tf.log(U)), axis=1)
# ================================================================
# Global session
# ================================================================
def get_session():
return tf.get_default_session()
def single_threaded_session():
tf_config = tf.ConfigProto(
inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1)
return tf.Session(config=tf_config)
ALREADY_INITIALIZED = set()
def initialize():
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
get_session().run(tf.variables_initializer(new_variables))
ALREADY_INITIALIZED.update(new_variables)
def eval(expr, feed_dict=None):
if feed_dict is None: feed_dict = {}
return get_session().run(expr, feed_dict=feed_dict)
def set_value(v, val):
get_session().run(v.assign(val))
def load_state(fname):
saver = tf.train.Saver()
saver.restore(get_session(), fname)
def save_state(fname):
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(get_session(), fname)
# ================================================================
# Model components
# ================================================================
def normc_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None): #pylint: disable=W0613
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def dense(x, size, name, weight_init=None, bias=True):
w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=weight_init)
ret = tf.matmul(x, w)
if bias:
b = tf.get_variable(name + "/b", [size], initializer=tf.zeros_initializer())
return ret + b
else:
return ret
# ================================================================
# Basic Stuff
# ================================================================
def function(inputs, outputs, updates=None, givens=None):
if isinstance(outputs, list):
return _Function(inputs, outputs, updates, givens=givens)
elif isinstance(outputs, dict):
f = _Function(inputs, outputs.values(), updates, givens=givens)
return lambda *inputs : dict(zip(outputs.keys(), f(*inputs)))
else:
f = _Function(inputs, [outputs], updates, givens=givens)
return lambda *inputs : f(*inputs)[0]
class _Function(object):
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
assert all(len(i.op.inputs)==0 for i in inputs), "inputs should all be placeholders"
self.inputs = inputs
updates = updates or []
self.update_group = tf.group(*updates)
self.outputs_update = list(outputs) + [self.update_group]
self.givens = {} if givens is None else givens
self.check_nan = check_nan
def __call__(self, *inputvals):
assert len(inputvals) == len(self.inputs)
feed_dict = dict(zip(self.inputs, inputvals))
feed_dict.update(self.givens)
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
if self.check_nan:
if any(np.isnan(r).any() for r in results):
raise RuntimeError("Nan detected")
return results
# ================================================================
# Graph traversal
# ================================================================
VARIABLES = {}
# ================================================================
# Flat vectors
# ================================================================
def var_shape(x):
out = [k.value for k in x.get_shape()]
assert all(isinstance(a, int) for a in out), \
"shape function assumes that shape is fully known"
return out
def numel(x):
return intprod(var_shape(x))
def intprod(x):
return int(np.prod(x))
def flatgrad(loss, var_list):
grads = tf.gradients(loss, var_list)
return tf.concat([tf.reshape(grad, [numel(v)])
for (v, grad) in zip(var_list, grads)], 0)
class SetFromFlat(object):
def __init__(self, var_list, dtype=tf.float32):
assigns = []
shapes = list(map(var_shape, var_list))
total_size = np.sum([intprod(shape) for shape in shapes])
self.theta = theta = tf.placeholder(dtype,[total_size])
start=0
assigns = []
for (shape,v) in zip(shapes,var_list):
size = intprod(shape)
assigns.append(tf.assign(v, tf.reshape(theta[start:start+size],shape)))
start+=size
assert start == total_size
self.op = tf.group(*assigns)
def __call__(self, theta):
get_session().run(self.op, feed_dict={self.theta:theta})
class GetFlat(object):
def __init__(self, var_list):
self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0)
def __call__(self):
return get_session().run(self.op)
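# SetFromFlat and GetFlat form the round trip behind
# Policy.set_trainable_flat / Policy.get_trainable_flat: GetFlat concatenates
# every variable in var_list into one flat vector, and SetFromFlat scatters a
# flat vector of the same total size back into the individual variables.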
# ================================================================
# Misc
# ================================================================
def scope_vars(scope, trainable_only):
"""
Get variables inside a scope
The scope can be specified as a string
"""
return tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.GLOBAL_VARIABLES,
scope=scope if isinstance(scope, str) else scope.name
)
def in_session(f):
@functools.wraps(f)
def newfunc(*args, **kwargs):
with tf.Session():
f(*args, **kwargs)
return newfunc
_PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
def get_placeholder(name, dtype, shape):
print("calling get_placeholder", name)
if name in _PLACEHOLDER_CACHE:
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
assert dtype1==dtype and shape1==shape
return out
else:
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
_PLACEHOLDER_CACHE[name] = (out,dtype,shape)
return out
def get_placeholder_cached(name):
return _PLACEHOLDER_CACHE[name][0]
def flattenallbut0(x):
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
def reset():
global _PLACEHOLDER_CACHE
global VARIABLES
_PLACEHOLDER_CACHE = {}
VARIABLES = {}
tf.reset_default_graph()

View file: examples/evolution_strategies/utils.py

@@ -0,0 +1,86 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def compute_ranks(x):
"""Returns ranks in [0, len(x))
Note: This is different from scipy.stats.rankdata, which returns ranks in
[1, len(x)].
"""
assert x.ndim == 1
ranks = np.empty(len(x), dtype=int)
ranks[x.argsort()] = np.arange(len(x))
return ranks
def compute_centered_ranks(x):
y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
y /= (x.size - 1)
y -= 0.5
return y
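# Worked example: for x = np.array([[10., 30.], [20., 40.]]) the flattened
# ranks are [0, 2, 1, 3], so compute_centered_ranks(x) returns
# [[-0.5, 0.1667], [-0.1667, 0.5]] (ranks divided by x.size - 1, then shifted
# by -0.5 so the values are centered in [-0.5, 0.5]).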
def make_session(single_threaded):
if not single_threaded:
return tf.InteractiveSession()
return tf.InteractiveSession(
config=tf.ConfigProto(inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1))
def itergroups(items, group_size):
assert group_size >= 1
group = []
for x in items:
group.append(x)
if len(group) == group_size:
yield tuple(group)
del group[:]
if group:
yield tuple(group)
def batched_weighted_sum(weights, vecs, batch_size):
total = 0
num_items_summed = 0
for batch_weights, batch_vecs in zip(itergroups(weights, batch_size),
itergroups(vecs, batch_size)):
assert len(batch_weights) == len(batch_vecs) <= batch_size
total += np.dot(np.asarray(batch_weights, dtype=np.float32),
np.asarray(batch_vecs, dtype=np.float32))
num_items_summed += len(batch_weights)
return total, num_items_summed
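# Example: batched_weighted_sum([1., 2.], [np.ones(3), 2 * np.ones(3)],
# batch_size=500) returns (array([5., 5., 5.]), 2): the weighted sum of the
# vectors and the number of items that were summed.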
class RunningStat(object):
def __init__(self, shape, eps):
self.sum = np.zeros(shape, dtype=np.float32)
self.sumsq = np.full(shape, eps, dtype=np.float32)
self.count = eps
def increment(self, s, ssq, c):
self.sum += s
self.sumsq += ssq
self.count += c
@property
def mean(self):
return self.sum / self.count
@property
def std(self):
return np.sqrt(np.maximum(self.sumsq / self.count - np.square(self.mean),
1e-2))
def set_from_init(self, init_mean, init_std, init_count):
self.sum[:] = init_mean * init_count
self.sumsq[:] = (np.square(init_mean) + np.square(init_std)) * init_count
self.count = init_count