Initial version of evolution strategies example. (#544)

* Initial commit of evolution strategies example.
* Some small simplifications.
* Update example to use new API.
* Add example to documentation.

parent 9f91eb8c91
commit 3c5375345f

8 changed files with 1222 additions and 0 deletions

doc/source/example-evolution-strategies.rst (new file, 87 lines)

@@ -0,0 +1,87 @@
Evolution Strategies
====================

This document provides a walkthrough of the evolution strategies example.
To run the application, first install some dependencies.

.. code-block:: bash

  pip install tensorflow
  pip install gym

You can view the `code for this example`_.

.. _`code for this example`: https://github.com/ray-project/ray/tree/master/examples/evolution_strategies

The script can be run as follows. Note that the configuration is tuned to work
on the ``Humanoid-v1`` gym environment.

.. code-block:: bash

  python examples/evolution_strategies/evolution_strategies.py

At the heart of this example, we define a ``Worker`` class. These workers have
a method ``do_rollouts``, which is used to perform rollouts of randomly
perturbed policies in a given environment and return the resulting rewards.

.. code-block:: python

  @ray.remote
  class Worker(object):
      def __init__(self, config, policy_params, env_name, noise):
          self.env = # Initialize environment.
          self.policy = # Construct policy.
          # Details omitted.

      def do_rollouts(self, params):
          # Set the network weights.
          self.policy.set_trainable_flat(params)
          perturbation = # Generate a random perturbation to the policy.

          self.policy.set_trainable_flat(params + perturbation)
          # Do rollout with the perturbed policy.

          self.policy.set_trainable_flat(params - perturbation)
          # Do rollout with the perturbed policy.

          # Return the rewards.

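To make the structure concrete, here is a minimal, self-contained sketch of
such a worker. It substitutes a toy linear policy for the example's actual
``policies.MujocoPolicy``, so the ``LinearPolicy`` class and the ``config``
dictionary below are illustrative rather than part of the example code.

.. code-block:: python

  import gym
  import numpy as np
  import ray

  class LinearPolicy(object):
      """Toy stand-in for the example's MujocoPolicy."""
      def __init__(self, ob_dim, ac_dim):
          self.shape = (ac_dim, ob_dim)
          self.num_params = ob_dim * ac_dim
          self.weights = np.zeros(self.shape, dtype=np.float32)

      def set_trainable_flat(self, flat_params):
          self.weights = flat_params.reshape(self.shape).astype(np.float32)

      def rollout(self, env, max_steps=1000):
          ob, total_reward = env.reset(), 0.0
          for _ in range(max_steps):
              ob, reward, done, _ = env.step(self.weights.dot(ob))
              total_reward += reward
              if done:
                  break
          return total_reward

  @ray.remote
  class Worker(object):
      def __init__(self, config, policy_params, env_name, noise):
          self.config = config           # e.g. {"noise_stdev": 0.02}
          self.noise = noise             # the shared noise block (a numpy array)
          # Works for Box action spaces such as Pendulum-v0.
          self.env = gym.make(env_name)
          self.policy = LinearPolicy(self.env.observation_space.shape[0],
                                     self.env.action_space.shape[0])

      def do_rollouts(self, params):
          # Draw a random slice of the shared noise block as the perturbation.
          idx = np.random.randint(0, len(self.noise) - self.policy.num_params + 1)
          perturbation = (self.config["noise_stdev"] *
                          self.noise[idx:idx + self.policy.num_params])

          # Evaluate the positively and negatively perturbed policies.
          self.policy.set_trainable_flat(params + perturbation)
          reward_pos = self.policy.rollout(self.env)
          self.policy.set_trainable_flat(params - perturbation)
          reward_neg = self.policy.rollout(self.env)

          # Return the noise index and both rewards.
          return idx, reward_pos, reward_neg
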
In the main loop, we create a number of actors with this class.

.. code-block:: python

  workers = [Worker.remote(config, policy_params, env_name, noise_id)
             for _ in range(num_workers)]

We then enter an infinite loop in which we use the actors to perform rollouts
and use the rewards from the rollouts to update the policy.

.. code-block:: python

  while True:
      # Get the current policy weights.
      theta = policy.get_trainable_flat()
      # Put the current policy weights in the object store.
      theta_id = ray.put(theta)
      # Use the actors to do rollouts. Note that we pass in the ID of the
      # policy weights.
      rollout_ids = [worker.do_rollouts.remote(theta_id) for worker in workers]
      # Get the results of the rollouts.
      results = ray.get(rollout_ids)
      # Update the policy.
      optimizer.update(...)

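The ``optimizer.update(...)`` call above is elided. In the full example it is
driven by ``utils.compute_centered_ranks`` and an ``optimizers.Adam`` instance;
roughly, the returns of each antithetic pair are rank-transformed and used to
weight the corresponding noise vectors. A hedged numpy sketch of that step
(``returns_n2``, ``noise_inds``, ``noise_block``, and ``stepsize`` are
illustrative names, not the example's exact code):

.. code-block:: python

  import numpy as np

  def centered_ranks(x):
      # Map values to evenly spaced numbers in [-0.5, 0.5], ordered by rank.
      ranks = np.empty(x.size, dtype=np.float32)
      ranks[x.ravel().argsort()] = np.arange(x.size, dtype=np.float32)
      return (ranks / (x.size - 1) - 0.5).reshape(x.shape)

  # returns_n2[i] holds the returns of the +perturbation and -perturbation
  # rollouts for noise index noise_inds[i]; noise_block is the shared array.
  proc_returns = centered_ranks(returns_n2)
  weights = proc_returns[:, 0] - proc_returns[:, 1]
  grad = sum(w * noise_block[idx:idx + theta.size]
             for w, idx in zip(weights, noise_inds)) / returns_n2.size
  theta = theta + stepsize * grad  # the example applies this step through Adam
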
In addition, note that we create a large object representing a shared block of
random noise. We then put the block in the object store so that each ``Worker``
actor can use it without creating its own copy.

.. code-block:: python

  @ray.remote
  def create_shared_noise():
      noise = np.random.randn(250000000)
      return noise

  noise_id = create_shared_noise.remote()

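Because the block lives in the object store, each actor can read slices of it
rather than materializing its own copy of the whole array. A sketch of the
access pattern inside a worker (the example wraps this logic in a
``SharedNoiseTable`` class; ``num_params`` and ``noise_stdev`` are illustrative
names):

.. code-block:: python

  noise = ray.get(noise_id)  # fetched from the object store, not regenerated
  idx = np.random.randint(0, len(noise) - num_params + 1)
  perturbation = noise_stdev * noise[idx:idx + num_params]
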
Recall that the ``noise_id`` argument is passed into the actor constructor.

@@ -30,6 +30,7 @@ Ray
   example-resnet.rst
   example-a3c.rst
   example-lbfgs.rst
+  example-evolution-strategies.rst
   using-ray-with-tensorflow.rst

 .. toctree::

examples/evolution_strategies/evolution_strategies.py (new file, 286 lines)

@@ -0,0 +1,286 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
from collections import namedtuple
import gym
import numpy as np
import ray
import time

import optimizers
import policies
import tabular_logger as tlogger
import tf_util
import utils


Config = namedtuple("Config", [
    "l2coeff", "noise_stdev", "episodes_per_batch", "timesteps_per_batch",
    "calc_obstat_prob", "eval_prob", "snapshot_freq", "return_proc_mode",
    "episode_cutoff_mode"
])

Result = namedtuple("Result", [
    "noise_inds_n", "returns_n2", "sign_returns_n2", "lengths_n2",
    "eval_return", "eval_length", "ob_sum", "ob_sumsq", "ob_count"
])


@ray.remote
def create_shared_noise():
    """Create a large array of noise to be shared by all workers."""
    seed = 123
    count = 250000000
    noise = np.random.RandomState(seed).randn(count).astype(np.float32)
    return noise


class SharedNoiseTable(object):
    def __init__(self, noise):
        self.noise = noise
        assert self.noise.dtype == np.float32

    def get(self, i, dim):
        return self.noise[i:i + dim]

    def sample_index(self, stream, dim):
        return stream.randint(0, len(self.noise) - dim + 1)


@ray.remote
class Worker(object):
    def __init__(self, config, policy_params, env_name, noise,
                 min_task_runtime=0.2):
        self.min_task_runtime = min_task_runtime
        self.config = config
        self.policy_params = policy_params
        self.noise = SharedNoiseTable(noise)

        self.env = gym.make(env_name)
        self.sess = utils.make_session(single_threaded=True)
        self.policy = policies.MujocoPolicy(self.env.observation_space,
                                            self.env.action_space,
                                            **policy_params)
        tf_util.initialize()

        self.rs = np.random.RandomState()

        assert self.policy.needs_ob_stat == (self.config.calc_obstat_prob != 0)

    def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
        if (self.policy.needs_ob_stat and self.config.calc_obstat_prob != 0 and
                self.rs.rand() < self.config.calc_obstat_prob):
            rollout_rews, rollout_len, obs = self.policy.rollout(
                self.env, timestep_limit=timestep_limit, save_obs=True,
                random_stream=self.rs)
            task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
                                   len(obs))
        else:
            rollout_rews, rollout_len = self.policy.rollout(
                self.env, timestep_limit=timestep_limit, random_stream=self.rs)
        return rollout_rews, rollout_len

    def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
        # Set the network weights.
        self.policy.set_trainable_flat(params)

        if self.policy.needs_ob_stat:
            self.policy.set_ob_stat(ob_mean, ob_std)

        if self.config.eval_prob != 0:
            raise NotImplementedError("Eval rollouts are not implemented.")

        noise_inds, returns, sign_returns, lengths = [], [], [], []
        # We set eps=0 because we're incrementing only.
        task_ob_stat = utils.RunningStat(self.env.observation_space.shape, eps=0)

        # Perform some rollouts with noise.
        task_tstart = time.time()
        while (len(noise_inds) == 0 or
               time.time() - task_tstart < self.min_task_runtime):
            noise_idx = self.noise.sample_index(self.rs, self.policy.num_params)
            perturbation = self.config.noise_stdev * self.noise.get(
                noise_idx, self.policy.num_params)

            # These two sampling steps could be done in parallel on different actors
            # letting us update twice as frequently.
            self.policy.set_trainable_flat(params + perturbation)
            rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
                                                                task_ob_stat)

            self.policy.set_trainable_flat(params - perturbation)
            rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
                                                                task_ob_stat)

            noise_inds.append(noise_idx)
            returns.append([rews_pos.sum(), rews_neg.sum()])
            sign_returns.append([np.sign(rews_pos).sum(), np.sign(rews_neg).sum()])
            lengths.append([len_pos, len_neg])

        return Result(
            noise_inds_n=np.array(noise_inds),
            returns_n2=np.array(returns, dtype=np.float32),
            sign_returns_n2=np.array(sign_returns, dtype=np.float32),
            lengths_n2=np.array(lengths, dtype=np.int32),
            eval_return=None,
            eval_length=None,
            ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum),
            ob_sumsq=(None if task_ob_stat.count == 0 else task_ob_stat.sumsq),
            ob_count=task_ob_stat.count)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train an RL agent with evolution strategies.")
    parser.add_argument("--num-workers", default=10, type=int,
                        help=("The number of actors to create in aggregate "
                              "across the cluster."))
    parser.add_argument("--env-name", default="Pendulum-v0", type=str,
                        help="The name of the gym environment to use.")
    parser.add_argument("--stepsize", default=0.01, type=float,
                        help="The stepsize to use.")
    parser.add_argument("--redis-address", default=None, type=str,
                        help="The Redis address of the cluster.")

    args = parser.parse_args()
    num_workers = args.num_workers
    env_name = args.env_name
    stepsize = args.stepsize

    ray.init(redis_address=args.redis_address,
             num_workers=(0 if args.redis_address is None else None))

    # Tell Ray to serialize Config and Result objects.
    ray.register_class(Config)
    ray.register_class(Result)

    config = Config(l2coeff=0.005,
                    noise_stdev=0.02,
                    episodes_per_batch=10000,
                    timesteps_per_batch=100000,
                    calc_obstat_prob=0.01,
                    eval_prob=0,
                    snapshot_freq=20,
                    return_proc_mode="centered_rank",
                    episode_cutoff_mode="env_default")

    policy_params = {
        "ac_bins": "continuous:",
        "ac_noise_std": 0.01,
        "nonlin_type": "tanh",
        "hidden_dims": [256, 256],
        "connection_type": "ff"
    }

    # Create the shared noise table.
    print("Creating shared noise table.")
    noise_id = create_shared_noise.remote()
    noise = SharedNoiseTable(ray.get(noise_id))

    # Create the actors.
    print("Creating actors.")
    workers = [Worker.remote(config, policy_params, env_name, noise_id)
               for _ in range(num_workers)]

    env = gym.make(env_name)
    sess = utils.make_session(single_threaded=False)
    policy = policies.MujocoPolicy(env.observation_space, env.action_space,
                                   **policy_params)
    tf_util.initialize()
    optimizer = optimizers.Adam(policy, stepsize)

    ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)

    episodes_so_far = 0
    timesteps_so_far = 0
    tstart = time.time()

    while True:
        step_tstart = time.time()
        theta = policy.get_trainable_flat()
        assert theta.dtype == np.float32

        # Put the current policy weights in the object store.
        theta_id = ray.put(theta)
        # Use the actors to do rollouts. Note that we pass in the ID of the
        # policy weights.
        rollout_ids = [worker.do_rollouts.remote(
            theta_id,
            ob_stat.mean if policy.needs_ob_stat else None,
            ob_stat.std if policy.needs_ob_stat else None)
            for worker in workers]

        # Get the results of the rollouts.
        results = ray.get(rollout_ids)

        curr_task_results = []
        ob_count_this_batch = 0
        # Loop over the results
        for result in results:
            assert result.eval_length is None, "We aren't doing eval rollouts."
            assert result.noise_inds_n.ndim == 1
            assert result.returns_n2.shape == (len(result.noise_inds_n), 2)
            assert result.lengths_n2.shape == (len(result.noise_inds_n), 2)
            assert result.returns_n2.dtype == np.float32

            result_num_eps = result.lengths_n2.size
            result_num_timesteps = result.lengths_n2.sum()
            episodes_so_far += result_num_eps
            timesteps_so_far += result_num_timesteps

            curr_task_results.append(result)
            # Update ob stats.
            if policy.needs_ob_stat and result.ob_count > 0:
                ob_stat.increment(result.ob_sum, result.ob_sumsq, result.ob_count)
                ob_count_this_batch += result.ob_count

        # Assemble the results.
        noise_inds_n = np.concatenate([r.noise_inds_n for
                                       r in curr_task_results])
        returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results])
        lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
        assert noise_inds_n.shape[0] == returns_n2.shape[0] == lengths_n2.shape[0]
        # Process the returns.
        if config.return_proc_mode == "centered_rank":
            proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
        else:
            raise NotImplementedError(config.return_proc_mode)

        # Compute and take a step.
        g, count = utils.batched_weighted_sum(
            proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
            (noise.get(idx, policy.num_params) for idx in noise_inds_n),
            batch_size=500)
        g /= returns_n2.size
        assert (g.shape == (policy.num_params,) and g.dtype == np.float32 and
                count == len(noise_inds_n))
        update_ratio = optimizer.update(-g + config.l2coeff * theta)

        # Update ob stat (we're never running the policy in the master, but we
        # might be snapshotting the policy).
        if policy.needs_ob_stat:
            policy.set_ob_stat(ob_stat.mean, ob_stat.std)

        step_tend = time.time()
        tlogger.record_tabular("EpRewMean", returns_n2.mean())
        tlogger.record_tabular("EpRewStd", returns_n2.std())
        tlogger.record_tabular("EpLenMean", lengths_n2.mean())

        tlogger.record_tabular("Norm",
                               float(np.square(policy.get_trainable_flat()).sum()))
        tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
        tlogger.record_tabular("UpdateRatio", float(update_ratio))

        tlogger.record_tabular("EpisodesThisIter", lengths_n2.size)
        tlogger.record_tabular("EpisodesSoFar", episodes_so_far)
        tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum())
        tlogger.record_tabular("TimestepsSoFar", timesteps_so_far)

        tlogger.record_tabular("ObCount", ob_count_this_batch)

        tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
        tlogger.record_tabular("TimeElapsed", step_tend - tstart)
        tlogger.dump_tabular()

examples/evolution_strategies/optimizers.py (new file, 57 lines)

@@ -0,0 +1,57 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class Optimizer(object):
    def __init__(self, pi):
        self.pi = pi
        self.dim = pi.num_params
        self.t = 0

    def update(self, globalg):
        self.t += 1
        step = self._compute_step(globalg)
        theta = self.pi.get_trainable_flat()
        ratio = np.linalg.norm(step) / np.linalg.norm(theta)
        self.pi.set_trainable_flat(theta + step)
        return ratio

    def _compute_step(self, globalg):
        raise NotImplementedError


class SGD(Optimizer):
    def __init__(self, pi, stepsize, momentum=0.9):
        Optimizer.__init__(self, pi)
        self.v = np.zeros(self.dim, dtype=np.float32)
        self.stepsize, self.momentum = stepsize, momentum

    def _compute_step(self, globalg):
        self.v = self.momentum * self.v + (1. - self.momentum) * globalg
        step = -self.stepsize * self.v
        return step


class Adam(Optimizer):
    def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
        Optimizer.__init__(self, pi)
        self.stepsize = stepsize
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.m = np.zeros(self.dim, dtype=np.float32)
        self.v = np.zeros(self.dim, dtype=np.float32)

    def _compute_step(self, globalg):
        a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) /
                             (1 - self.beta1 ** self.t))
        self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
        self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
        step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
        return step

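These optimizers interact with a policy only through its flat parameter vector
(num_params, get_trainable_flat, set_trainable_flat). A hedged usage sketch with
a minimal stand-in policy, run from the examples directory; the FlatPolicy class
below is illustrative and not part of the example:

import numpy as np
import optimizers

class FlatPolicy(object):
    # Stand-in exposing only the interface that Optimizer expects.
    def __init__(self, num_params):
        self.num_params = num_params
        self.theta = np.ones(num_params, dtype=np.float32)

    def get_trainable_flat(self):
        return self.theta

    def set_trainable_flat(self, theta):
        self.theta = theta

pi = FlatPolicy(num_params=4)
adam = optimizers.Adam(pi, stepsize=0.01)
gradient = np.ones(4, dtype=np.float32)  # stand-in for -g + l2coeff * theta
update_ratio = adam.update(gradient)     # applies one Adam step to pi
print(update_ratio)                      # norm(step) / norm(theta)
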
examples/evolution_strategies/policies.py (new file, 251 lines)

@ -0,0 +1,251 @@
|
|||
# Code in this file is copied and adapted from
|
||||
# https://github.com/openai/evolution-strategies-starter.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import tf_util as U
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Policy:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args, self.kwargs = args, kwargs
|
||||
self.scope = self._initialize(*args, **kwargs)
|
||||
self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope.name)
|
||||
|
||||
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name)
|
||||
self.num_params = sum(int(np.prod(v.get_shape().as_list())) for v in self.trainable_variables)
|
||||
self._setfromflat = U.SetFromFlat(self.trainable_variables)
|
||||
self._getflat = U.GetFlat(self.trainable_variables)
|
||||
|
||||
logger.info('Trainable variables ({} parameters)'.format(self.num_params))
|
||||
for v in self.trainable_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
|
||||
logger.info('All variables')
|
||||
for v in self.all_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
|
||||
|
||||
placeholders = [tf.placeholder(v.value().dtype, v.get_shape().as_list()) for v in self.all_variables]
|
||||
self.set_all_vars = U.function(
|
||||
inputs=placeholders,
|
||||
outputs=[],
|
||||
updates=[tf.group(*[v.assign(p) for v, p in zip(self.all_variables, placeholders)])]
|
||||
)
|
||||
|
||||
def _initialize(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, filename):
|
||||
assert filename.endswith('.h5')
|
||||
with h5py.File(filename, 'w') as f:
|
||||
for v in self.all_variables:
|
||||
f[v.name] = v.eval()
|
||||
# TODO: it would be nice to avoid pickle, but it's convenient to pass Python objects to _initialize
|
||||
# (like Gym spaces or numpy arrays)
|
||||
f.attrs['name'] = type(self).__name__
|
||||
f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args, self.kwargs), protocol=-1))
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filename, extra_kwargs=None):
|
||||
with h5py.File(filename, 'r') as f:
|
||||
args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring())
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
policy = cls(*args, **kwargs)
|
||||
policy.set_all_vars(*[f[v.name][...] for v in policy.all_variables])
|
||||
return policy
|
||||
|
||||
# === Rollouts/training ===
|
||||
|
||||
def rollout(self, env, *, render=False, timestep_limit=None, save_obs=False, random_stream=None):
|
||||
"""
|
||||
If random_stream is provided, the rollout will take noisy actions with noise drawn from that stream.
|
||||
Otherwise, no action noise will be added.
|
||||
"""
|
||||
env_timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
|
||||
timestep_limit = env_timestep_limit if timestep_limit is None else min(timestep_limit, env_timestep_limit)
|
||||
rews = []
|
||||
t = 0
|
||||
if save_obs:
|
||||
obs = []
|
||||
ob = env.reset()
|
||||
for _ in range(timestep_limit):
|
||||
ac = self.act(ob[None], random_stream=random_stream)[0]
|
||||
if save_obs:
|
||||
obs.append(ob)
|
||||
ob, rew, done, _ = env.step(ac)
|
||||
rews.append(rew)
|
||||
t += 1
|
||||
if render:
|
||||
env.render()
|
||||
if done:
|
||||
break
|
||||
rews = np.array(rews, dtype=np.float32)
|
||||
if save_obs:
|
||||
return rews, t, np.array(obs)
|
||||
return rews, t
|
||||
|
||||
def act(self, ob, random_stream=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_trainable_flat(self, x):
|
||||
self._setfromflat(x)
|
||||
|
||||
def get_trainable_flat(self):
|
||||
return self._getflat()
|
||||
|
||||
@property
|
||||
def needs_ob_stat(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_ob_stat(self, ob_mean, ob_std):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def bins(x, dim, num_bins, name):
|
||||
scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01))
|
||||
scores_nab = tf.reshape(scores, [-1, dim, num_bins])
|
||||
return tf.argmax(scores_nab, 2) # 0 ... num_bins-1
|
||||
|
||||
|
||||
class MujocoPolicy(Policy):
|
||||
def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std, nonlin_type, hidden_dims, connection_type):
|
||||
self.ac_space = ac_space
|
||||
self.ac_bins = ac_bins
|
||||
self.ac_noise_std = ac_noise_std
|
||||
self.hidden_dims = hidden_dims
|
||||
self.connection_type = connection_type
|
||||
|
||||
assert len(ob_space.shape) == len(self.ac_space.shape) == 1
|
||||
assert np.all(np.isfinite(self.ac_space.low)) and np.all(np.isfinite(self.ac_space.high)), \
|
||||
'Action bounds required'
|
||||
|
||||
self.nonlin = {'tanh': tf.tanh, 'relu': tf.nn.relu, 'lrelu': U.lrelu, 'elu': tf.nn.elu}[nonlin_type]
|
||||
|
||||
with tf.variable_scope(type(self).__name__) as scope:
|
||||
# Observation normalization
|
||||
ob_mean = tf.get_variable(
|
||||
'ob_mean', ob_space.shape, tf.float32, tf.constant_initializer(np.nan), trainable=False)
|
||||
ob_std = tf.get_variable(
|
||||
'ob_std', ob_space.shape, tf.float32, tf.constant_initializer(np.nan), trainable=False)
|
||||
in_mean = tf.placeholder(tf.float32, ob_space.shape)
|
||||
in_std = tf.placeholder(tf.float32, ob_space.shape)
|
||||
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
|
||||
tf.assign(ob_mean, in_mean),
|
||||
tf.assign(ob_std, in_std),
|
||||
])
|
||||
|
||||
# Policy network
|
||||
o = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std, -5.0, 5.0))
|
||||
self._act = U.function([o], a)
|
||||
return scope
|
||||
|
||||
def _make_net(self, o):
|
||||
# Process observation
|
||||
if self.connection_type == 'ff':
|
||||
x = o
|
||||
for ilayer, hd in enumerate(self.hidden_dims):
|
||||
x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer), U.normc_initializer(1.0)))
|
||||
else:
|
||||
raise NotImplementedError(self.connection_type)
|
||||
|
||||
# Map to action
|
||||
adim, ahigh, alow = self.ac_space.shape[0], self.ac_space.high, self.ac_space.low
|
||||
assert isinstance(self.ac_bins, str)
|
||||
ac_bin_mode, ac_bin_arg = self.ac_bins.split(':')
|
||||
|
||||
if ac_bin_mode == 'uniform':
|
||||
# Uniformly spaced bins, from ac_space.low to ac_space.high
|
||||
num_ac_bins = int(ac_bin_arg)
|
||||
aidx_na = bins(x, adim, num_ac_bins, 'out') # 0 ... num_ac_bins-1
|
||||
ac_range_1a = (ahigh - alow)[None, :]
|
||||
a = 1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a + alow[None, :]
|
||||
|
||||
elif ac_bin_mode == 'custom':
|
||||
# Custom bins specified as a list of values from -1 to 1
|
||||
# The bins are rescaled to ac_space.low to ac_space.high
|
||||
acvals_k = np.array(list(map(float, ac_bin_arg.split(','))), dtype=np.float32)
|
||||
logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x) for x in acvals_k))
|
||||
assert acvals_k.ndim == 1 and acvals_k[0] == -1 and acvals_k[-1] == 1
|
||||
acvals_ak = (
|
||||
(ahigh - alow)[:, None] / (acvals_k[-1] - acvals_k[0]) * (acvals_k - acvals_k[0])[None, :]
|
||||
+ alow[:, None]
|
||||
)
|
||||
|
||||
aidx_na = bins(x, adim, len(acvals_k), 'out') # values in [0, k-1]
|
||||
a = tf.gather_nd(
    acvals_ak,
    tf.concat([
        tf.tile(np.arange(adim)[None, :, None], [tf.shape(aidx_na)[0], 1, 1]),
        tf.expand_dims(aidx_na, -1)
    ], 2)  # (n, a, 2)
)  # (n, a)
|
||||
elif ac_bin_mode == 'continuous':
|
||||
a = U.dense(x, adim, 'out', U.normc_initializer(0.01))
|
||||
else:
|
||||
raise NotImplementedError(ac_bin_mode)
|
||||
|
||||
return a
|
||||
|
||||
def act(self, ob, random_stream=None):
|
||||
a = self._act(ob)
|
||||
if random_stream is not None and self.ac_noise_std != 0:
|
||||
a += random_stream.randn(*a.shape) * self.ac_noise_std
|
||||
return a
|
||||
|
||||
@property
|
||||
def needs_ob_stat(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def needs_ref_batch(self):
|
||||
return False
|
||||
|
||||
def set_ob_stat(self, ob_mean, ob_std):
|
||||
self._set_ob_mean_std(ob_mean, ob_std)
|
||||
|
||||
def initialize_from(self, filename, ob_stat=None):
|
||||
"""
|
||||
Initializes weights from another policy, which must have the same architecture (variable names),
|
||||
but the weight arrays can be smaller than the current policy.
|
||||
"""
|
||||
with h5py.File(filename, 'r') as f:
|
||||
f_var_names = []
|
||||
f.visititems(lambda name, obj: f_var_names.append(name) if isinstance(obj, h5py.Dataset) else None)
|
||||
assert set(v.name for v in self.all_variables) == set(f_var_names), 'Variable names do not match'
|
||||
|
||||
init_vals = []
|
||||
for v in self.all_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
f_shp = f[v.name].shape
|
||||
assert len(shp) == len(f_shp) and all(a >= b for a, b in zip(shp, f_shp)), \
|
||||
'This policy must have more weights than the policy to load'
|
||||
init_val = v.eval()
|
||||
# ob_mean and ob_std are initialized with nan, so set them manually
|
||||
if 'ob_mean' in v.name:
|
||||
init_val[:] = 0
|
||||
init_mean = init_val
|
||||
elif 'ob_std' in v.name:
|
||||
init_val[:] = 0.001
|
||||
init_std = init_val
|
||||
# Fill in subarray from the loaded policy
|
||||
init_val[tuple([np.s_[:s] for s in f_shp])] = f[v.name]
|
||||
init_vals.append(init_val)
|
||||
self.set_all_vars(*init_vals)
|
||||
|
||||
if ob_stat is not None:
|
||||
ob_stat.set_from_init(init_mean, init_std, init_count=1e5)
|
examples/evolution_strategies/tabular_logger.py (new file, 194 lines)

@ -0,0 +1,194 @@
|
|||
# Code in this file is copied and adapted from
|
||||
# https://github.com/openai/evolution-strategies-starter.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
|
||||
DISABLED = 50
|
||||
|
||||
class TbWriter(object):
|
||||
"""
|
||||
Based on SummaryWriter, but changed to allow for a different prefix
|
||||
and to get rid of multithreading
|
||||
oops, ended up using the same prefix anyway.
|
||||
"""
|
||||
def __init__(self, dir, prefix):
|
||||
self.dir = dir
|
||||
self.step = 1 # Start at 1, because EvWriter automatically generates an object with step=0
|
||||
self.evwriter = pywrap_tensorflow.EventsWriter(compat.as_bytes(os.path.join(dir, prefix)))
|
||||
def write_values(self, key2val):
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=float(v))
|
||||
for (k, v) in key2val.items()])
|
||||
event = event_pb2.Event(wall_time=time.time(), summary=summary)
|
||||
event.step = self.step # is there any reason why you'd want to specify the step?
|
||||
self.evwriter.WriteEvent(event)
|
||||
self.evwriter.Flush()
|
||||
self.step += 1
|
||||
def close(self):
|
||||
self.evwriter.Close()
|
||||
|
||||
# ================================================================
|
||||
# API
|
||||
# ================================================================
|
||||
|
||||
def start(dir):
|
||||
"""
|
||||
dir: directory to put all output files
|
||||
force: if dir already exists, should we delete it, or throw a RuntimeError?
|
||||
"""
|
||||
if _Logger.CURRENT is not _Logger.DEFAULT:
|
||||
sys.stderr.write("WARNING: You asked to start logging (dir=%s), but you never stopped the previous logger (dir=%s).\n"%(dir, _Logger.CURRENT.dir))
|
||||
_Logger.CURRENT = _Logger(dir=dir)
|
||||
|
||||
def stop():
|
||||
if _Logger.CURRENT is _Logger.DEFAULT:
|
||||
sys.stderr.write("WARNING: You asked to stop logging, but you never started any previous logger.\n"%(dir, _Logger.CURRENT.dir))
|
||||
return
|
||||
_Logger.CURRENT.close()
|
||||
_Logger.CURRENT = _Logger.DEFAULT
|
||||
|
||||
def record_tabular(key, val):
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
"""
|
||||
_Logger.CURRENT.record_tabular(key, val)
|
||||
|
||||
def dump_tabular():
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
|
||||
level: int. (see logger.py docs) If the global logger level is higher than
|
||||
the level argument here, don't print to stdout.
|
||||
"""
|
||||
_Logger.CURRENT.dump_tabular()
|
||||
|
||||
def log(*args, level=INFO):
|
||||
"""
|
||||
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
|
||||
"""
|
||||
_Logger.CURRENT.log(*args, level=level)
|
||||
|
||||
def debug(*args):
|
||||
log(*args, level=DEBUG)
|
||||
def info(*args):
|
||||
log(*args, level=INFO)
|
||||
def warn(*args):
|
||||
log(*args, level=WARN)
|
||||
def error(*args):
|
||||
log(*args, level=ERROR)
|
||||
|
||||
def set_level(level):
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
_Logger.CURRENT.set_level(level)
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call start)
|
||||
"""
|
||||
return _Logger.CURRENT.get_dir()
|
||||
|
||||
def get_expt_dir():
|
||||
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n")
|
||||
return get_dir()
|
||||
|
||||
# ================================================================
|
||||
# Backend
|
||||
# ================================================================
|
||||
|
||||
class _Logger(object):
|
||||
DEFAULT = None # A logger with no output files. (See right below class definition)
|
||||
# So that you can still log to the terminal without setting up any output files
|
||||
CURRENT = None # Current logger being used by the free functions above
|
||||
|
||||
def __init__(self, dir=None):
|
||||
self.name2val = OrderedDict() # values this iteration
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.text_outputs = [sys.stdout]
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w"))
|
||||
self.tbwriter = TbWriter(dir=dir, prefix="events")
|
||||
else:
|
||||
self.tbwriter = None
|
||||
|
||||
# Logging API, forwarded
|
||||
# ----------------------------------------
|
||||
def record_tabular(self, key, val):
|
||||
self.name2val[key] = val
|
||||
def dump_tabular(self):
|
||||
# Create strings for printing
|
||||
key2str = OrderedDict()
|
||||
for (key,val) in self.name2val.items():
|
||||
if hasattr(val, "__float__"): valstr = "%-8.3g"%val
|
||||
else: valstr = val
|
||||
key2str[self._truncate(key)]=self._truncate(valstr)
|
||||
keywidth = max(map(len, key2str.keys()))
|
||||
valwidth = max(map(len, key2str.values()))
|
||||
# Write to all text outputs
|
||||
self._write_text("-"*(keywidth+valwidth+7), "\n")
|
||||
for (key,val) in key2str.items():
|
||||
self._write_text("| ", key, " "*(keywidth-len(key)), " | ", val, " "*(valwidth-len(val)), " |\n")
|
||||
self._write_text("-"*(keywidth+valwidth+7), "\n")
|
||||
for f in self.text_outputs:
|
||||
try: f.flush()
|
||||
except OSError: sys.stderr.write('Warning! OSError when flushing.\n')
|
||||
# Write to tensorboard
|
||||
if self.tbwriter is not None:
|
||||
self.tbwriter.write_values(self.name2val)
|
||||
self.name2val.clear()
|
||||
def log(self, *args, level=INFO):
|
||||
if self.level <= level:
|
||||
self._do_log(*args)
|
||||
|
||||
# Configuration
|
||||
# ----------------------------------------
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
|
||||
def close(self):
|
||||
for f in self.text_outputs[1:]: f.close()
|
||||
if self.tbwriter: self.tbwriter.close()
|
||||
|
||||
# Misc
|
||||
# ----------------------------------------
|
||||
def _do_log(self, *args):
|
||||
self._write_text(*args, '\n')
|
||||
for f in self.text_outputs:
|
||||
try: f.flush()
|
||||
except OSError: print('Warning! OSError when flushing.')
|
||||
def _write_text(self, *strings):
|
||||
for f in self.text_outputs:
|
||||
for string in strings:
|
||||
f.write(string)
|
||||
def _truncate(self, s):
|
||||
if len(s) > 33:
|
||||
return s[:30] + "..."
|
||||
else:
|
||||
return s
|
||||
|
||||
_Logger.DEFAULT = _Logger()
|
||||
_Logger.CURRENT = _Logger.DEFAULT
|
examples/evolution_strategies/tf_util.py (new file, 260 lines)

@ -0,0 +1,260 @@
|
|||
# Code in this file is copied and adapted from
|
||||
# https://github.com/openai/evolution-strategies-starter.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import builtins
|
||||
import functools
|
||||
import copy
|
||||
import os
|
||||
|
||||
# ================================================================
|
||||
# Import all names into common namespace
|
||||
# ================================================================
|
||||
|
||||
clip = tf.clip_by_value
|
||||
|
||||
# Make consistent with numpy
|
||||
# ----------------------------------------
|
||||
|
||||
def sum(x, axis=None, keepdims=False):
|
||||
return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
|
||||
def mean(x, axis=None, keepdims=False):
|
||||
return tf.reduce_mean(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
|
||||
def var(x, axis=None, keepdims=False):
|
||||
meanx = mean(x, axis=axis, keepdims=keepdims)
|
||||
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
|
||||
def std(x, axis=None, keepdims=False):
|
||||
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
|
||||
def max(x, axis=None, keepdims=False):
|
||||
return tf.reduce_max(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
|
||||
def min(x, axis=None, keepdims=False):
|
||||
return tf.reduce_min(x, reduction_indices=None if axis is None else [axis], keep_dims = keepdims)
|
||||
def concatenate(arrs, axis=0):
|
||||
return tf.concat(arrs, axis)
|
||||
def argmax(x, axis=None):
|
||||
return tf.argmax(x, dimension=axis)
|
||||
|
||||
def switch(condition, then_expression, else_expression):
|
||||
'''Switches between two operations depending on a scalar value (int or bool).
|
||||
Note that both `then_expression` and `else_expression`
|
||||
should be symbolic tensors of the *same shape*.
|
||||
|
||||
# Arguments
|
||||
condition: scalar tensor.
|
||||
then_expression: TensorFlow operation.
|
||||
else_expression: TensorFlow operation.
|
||||
'''
|
||||
x_shape = copy.copy(then_expression.get_shape())
|
||||
x = tf.cond(tf.cast(condition, 'bool'),
|
||||
lambda: then_expression,
|
||||
lambda: else_expression)
|
||||
x.set_shape(x_shape)
|
||||
return x
|
||||
|
||||
# Extras
|
||||
# ----------------------------------------
|
||||
def l2loss(params):
|
||||
if len(params) == 0:
|
||||
return tf.constant(0.0)
|
||||
else:
|
||||
return tf.add_n([sum(tf.square(p)) for p in params])
|
||||
def lrelu(x, leak=0.2):
|
||||
f1 = 0.5 * (1 + leak)
|
||||
f2 = 0.5 * (1 - leak)
|
||||
return f1 * x + f2 * abs(x)
|
||||
def categorical_sample_logits(X):
|
||||
# https://github.com/tensorflow/tensorflow/issues/456
|
||||
U = tf.random_uniform(tf.shape(X))
|
||||
return argmax(X - tf.log(-tf.log(U)), axis=1)
|
||||
|
||||
# ================================================================
|
||||
# Global session
|
||||
# ================================================================
|
||||
|
||||
def get_session():
|
||||
return tf.get_default_session()
|
||||
|
||||
def single_threaded_session():
|
||||
tf_config = tf.ConfigProto(
|
||||
inter_op_parallelism_threads=1,
|
||||
intra_op_parallelism_threads=1)
|
||||
return tf.Session(config=tf_config)
|
||||
|
||||
ALREADY_INITIALIZED = set()
|
||||
def initialize():
|
||||
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
|
||||
get_session().run(tf.variables_initializer(new_variables))
|
||||
ALREADY_INITIALIZED.update(new_variables)
|
||||
|
||||
|
||||
def eval(expr, feed_dict=None):
|
||||
if feed_dict is None: feed_dict = {}
|
||||
return get_session().run(expr, feed_dict=feed_dict)
|
||||
|
||||
def set_value(v, val):
|
||||
get_session().run(v.assign(val))
|
||||
|
||||
def load_state(fname):
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(get_session(), fname)
|
||||
|
||||
def save_state(fname):
|
||||
os.makedirs(os.path.dirname(fname), exist_ok=True)
|
||||
saver = tf.train.Saver()
|
||||
saver.save(get_session(), fname)
|
||||
|
||||
# ================================================================
|
||||
# Model components
|
||||
# ================================================================
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
def _initializer(shape, dtype=None, partition_info=None): #pylint: disable=W0613
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
|
||||
def dense(x, size, name, weight_init=None, bias=True):
|
||||
w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=weight_init)
|
||||
ret = tf.matmul(x, w)
|
||||
if bias:
|
||||
b = tf.get_variable(name + "/b", [size], initializer=tf.zeros_initializer())
|
||||
return ret + b
|
||||
else:
|
||||
return ret
|
||||
|
||||
# ================================================================
|
||||
# Basic Stuff
|
||||
# ================================================================
|
||||
|
||||
def function(inputs, outputs, updates=None, givens=None):
|
||||
if isinstance(outputs, list):
|
||||
return _Function(inputs, outputs, updates, givens=givens)
|
||||
elif isinstance(outputs, dict):
|
||||
f = _Function(inputs, outputs.values(), updates, givens=givens)
|
||||
return lambda *inputs : dict(zip(outputs.keys(), f(*inputs)))
|
||||
else:
|
||||
f = _Function(inputs, [outputs], updates, givens=givens)
|
||||
return lambda *inputs : f(*inputs)[0]
|
||||
|
||||
class _Function(object):
|
||||
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
|
||||
assert all(len(i.op.inputs)==0 for i in inputs), "inputs should all be placeholders"
|
||||
self.inputs = inputs
|
||||
updates = updates or []
|
||||
self.update_group = tf.group(*updates)
|
||||
self.outputs_update = list(outputs) + [self.update_group]
|
||||
self.givens = {} if givens is None else givens
|
||||
self.check_nan = check_nan
|
||||
def __call__(self, *inputvals):
|
||||
assert len(inputvals) == len(self.inputs)
|
||||
feed_dict = dict(zip(self.inputs, inputvals))
|
||||
feed_dict.update(self.givens)
|
||||
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
||||
if self.check_nan:
|
||||
if any(np.isnan(r).any() for r in results):
|
||||
raise RuntimeError("Nan detected")
|
||||
return results
|
||||
|
||||
# ================================================================
|
||||
# Graph traversal
|
||||
# ================================================================
|
||||
|
||||
VARIABLES = {}
|
||||
|
||||
# ================================================================
|
||||
# Flat vectors
|
||||
# ================================================================
|
||||
|
||||
def var_shape(x):
|
||||
out = [k.value for k in x.get_shape()]
|
||||
assert all(isinstance(a, int) for a in out), \
|
||||
"shape function assumes that shape is fully known"
|
||||
return out
|
||||
|
||||
def numel(x):
|
||||
return intprod(var_shape(x))
|
||||
|
||||
def intprod(x):
|
||||
return int(np.prod(x))
|
||||
|
||||
def flatgrad(loss, var_list):
|
||||
grads = tf.gradients(loss, var_list)
|
||||
return tf.concat([tf.reshape(grad, [numel(v)])
                  for (v, grad) in zip(var_list, grads)], 0)
|
||||
|
||||
class SetFromFlat(object):
|
||||
def __init__(self, var_list, dtype=tf.float32):
|
||||
assigns = []
|
||||
shapes = list(map(var_shape, var_list))
|
||||
total_size = np.sum([intprod(shape) for shape in shapes])
|
||||
|
||||
self.theta = theta = tf.placeholder(dtype,[total_size])
|
||||
start=0
|
||||
assigns = []
|
||||
for (shape,v) in zip(shapes,var_list):
|
||||
size = intprod(shape)
|
||||
assigns.append(tf.assign(v, tf.reshape(theta[start:start+size],shape)))
|
||||
start+=size
|
||||
assert start == total_size
|
||||
self.op = tf.group(*assigns)
|
||||
def __call__(self, theta):
|
||||
get_session().run(self.op, feed_dict={self.theta:theta})
|
||||
|
||||
class GetFlat(object):
|
||||
def __init__(self, var_list):
|
||||
self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0)
|
||||
def __call__(self):
|
||||
return get_session().run(self.op)
|
||||
|
||||
# ================================================================
|
||||
# Misc
|
||||
# ================================================================
|
||||
|
||||
def scope_vars(scope, trainable_only):
|
||||
"""
|
||||
Get variables inside a scope
|
||||
The scope can be specified as a string
|
||||
"""
|
||||
return tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.GLOBAL_VARIABLES,
|
||||
scope=scope if isinstance(scope, str) else scope.name
|
||||
)
|
||||
|
||||
def in_session(f):
|
||||
@functools.wraps(f)
|
||||
def newfunc(*args, **kwargs):
|
||||
with tf.Session():
|
||||
f(*args, **kwargs)
|
||||
return newfunc
|
||||
|
||||
|
||||
_PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
|
||||
def get_placeholder(name, dtype, shape):
|
||||
print("calling get_placeholder", name)
|
||||
if name in _PLACEHOLDER_CACHE:
|
||||
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
|
||||
assert dtype1==dtype and shape1==shape
|
||||
return out
|
||||
else:
|
||||
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
|
||||
_PLACEHOLDER_CACHE[name] = (out,dtype,shape)
|
||||
return out
|
||||
def get_placeholder_cached(name):
|
||||
return _PLACEHOLDER_CACHE[name][0]
|
||||
|
||||
def flattenallbut0(x):
|
||||
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
|
||||
|
||||
def reset():
|
||||
global _PLACEHOLDER_CACHE
|
||||
global VARIABLES
|
||||
_PLACEHOLDER_CACHE = {}
|
||||
VARIABLES = {}
|
||||
tf.reset_default_graph()
|
examples/evolution_strategies/utils.py (new file, 86 lines)

@@ -0,0 +1,86 @@
# Code in this file is copied and adapted from
# https://github.com/openai/evolution-strategies-starter.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def compute_ranks(x):
    """Returns ranks in [0, len(x))

    Note: This is different from scipy.stats.rankdata, which returns ranks in
    [1, len(x)].
    """
    assert x.ndim == 1
    ranks = np.empty(len(x), dtype=int)
    ranks[x.argsort()] = np.arange(len(x))
    return ranks


def compute_centered_ranks(x):
    y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y


def make_session(single_threaded):
    if not single_threaded:
        return tf.InteractiveSession()
    return tf.InteractiveSession(
        config=tf.ConfigProto(inter_op_parallelism_threads=1,
                              intra_op_parallelism_threads=1))


def itergroups(items, group_size):
    assert group_size >= 1
    group = []
    for x in items:
        group.append(x)
        if len(group) == group_size:
            yield tuple(group)
            del group[:]
    if group:
        yield tuple(group)


def batched_weighted_sum(weights, vecs, batch_size):
    total = 0
    num_items_summed = 0
    for batch_weights, batch_vecs in zip(itergroups(weights, batch_size),
                                         itergroups(vecs, batch_size)):
        assert len(batch_weights) == len(batch_vecs) <= batch_size
        total += np.dot(np.asarray(batch_weights, dtype=np.float32),
                        np.asarray(batch_vecs, dtype=np.float32))
        num_items_summed += len(batch_weights)
    return total, num_items_summed


class RunningStat(object):
    def __init__(self, shape, eps):
        self.sum = np.zeros(shape, dtype=np.float32)
        self.sumsq = np.full(shape, eps, dtype=np.float32)
        self.count = eps

    def increment(self, s, ssq, c):
        self.sum += s
        self.sumsq += ssq
        self.count += c

    @property
    def mean(self):
        return self.sum / self.count

    @property
    def std(self):
        return np.sqrt(np.maximum(self.sumsq / self.count - np.square(self.mean),
                                  1e-2))

    def set_from_init(self, init_mean, init_std, init_count):
        self.sum[:] = init_mean * init_count
        self.sumsq[:] = (np.square(init_mean) + np.square(init_std)) * init_count
        self.count = init_count

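The rank-centering transform above is what makes the ES update insensitive to
the scale of the raw returns. A small illustration with arbitrarily chosen
values, run from the examples directory:

# Returns of any magnitude are mapped, by rank, to evenly spaced values
# in [-0.5, 0.5].
import numpy as np
from utils import compute_centered_ranks

returns_n2 = np.array([[1000.0, -3.0],
                       [0.5, 12.0]], dtype=np.float32)
print(compute_centered_ranks(returns_n2))
# Ascending order is -3.0 < 0.5 < 12.0 < 1000.0, so the result is approximately
# [[ 0.5   -0.5  ]
#  [-0.167  0.167]]
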