Add policy gradient example. (#344)

* add policy gradient example

* fix typos

* Minor changes plus some documentation.

* Minor fixes.
Philipp Moritz 2017-03-07 23:42:44 -08:00 committed by Robert Nishihara
parent 0de57be085
commit 555dcf35a2
17 changed files with 703 additions and 0 deletions

View file

@@ -0,0 +1,31 @@
Policy Gradient Methods
=======================

This code shows how to do reinforcement learning with policy gradient methods.
View the `code for this example`_.

To run this example, you will need to install `TensorFlow with GPU support`_ (at
least version ``1.0.0``) and a few other dependencies.

.. code-block:: bash

  pip install gym[atari]
  pip install tensorflow

Then install the package as follows.

.. code-block:: bash

  cd ray/examples/policy_gradient/
  python setup.py install

Then you can run the example as follows.

.. code-block:: bash

  python ray/examples/policy_gradient/examples/example.py

This will train an agent on the ``Pong-ram-v3`` Atari environment.

.. _`TensorFlow with GPU support`: https://www.tensorflow.org/install/
.. _`code for this example`: https://github.com/ray-project/ray/tree/master/examples/policy_gradient
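
The environment and the training hyperparameters are set near the top of
``examples/example.py``. As a rough sketch (the exact gym environment IDs
available depend on your installed gym version), you can train on a different
Atari RAM environment by editing ``mdp_name``:

.. code-block:: python

  # In examples/example.py (illustrative): any gym Atari RAM environment should work here.
  mdp_name = "Breakout-ram-v3"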

View file

@@ -27,6 +27,7 @@ learning and reinforcement learning applications.*
   :caption: Examples

   example-hyperopt.rst
   example-policy-gradient.rst
   example-resnet.rst
   example-lbfgs.md
   example-rl-pong.md

View file

@@ -0,0 +1,65 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray

from reinforce.agent import Agent, RemoteAgent
from reinforce.rollout import collect_samples
from reinforce.utils import iterate, shuffle

config = {"kl_coeff": 0.2,
          "num_sgd_iter": 30,
          "sgd_stepsize": 5e-5,
          "sgd_batchsize": 128,
          "entropy_coeff": 0.0,
          "clip_param": 0.3,
          "kl_target": 0.01,
          "timesteps_per_batch": 40000}

ray.init()

mdp_name = "Pong-ram-v3"
# Five remote actors collect rollouts; the local agent (use_gpu=True) runs the SGD updates.
agents = [RemoteAgent(mdp_name, 1, config, False) for _ in range(5)]
agent = Agent(mdp_name, 1, config, True)

kl_coeff = config["kl_coeff"]

for j in range(1000):
  print("== iteration", j)
  weights = agent.get_weights()
  [a.load_weights(weights) for a in agents]
  trajectory, total_reward, traj_len_mean = collect_samples(
      agents, config["timesteps_per_batch"], 0.995, 1.0, 2000)
  print("total reward is ", total_reward)
  print("trajectory length mean is ", traj_len_mean)
  print("timesteps: ", trajectory["dones"].shape[0])
  # Standardize the advantages.
  trajectory["advantages"] = ((trajectory["advantages"] - trajectory["advantages"].mean()) /
                              trajectory["advantages"].std())
  print("Computing policy (optimizer='" + agent.optimizer.get_name() +
        "', iterations=" + str(config["num_sgd_iter"]) +
        ", stepsize=" + str(config["sgd_stepsize"]) + "):")
  names = ["iter", "loss", "kl", "entropy"]
  print(("{:>15}" * len(names)).format(*names))
  trajectory = shuffle(trajectory)
  ppo = agent.ppo
  for i in range(config["num_sgd_iter"]):
    # Test on the current set of rollouts.
    loss, kl, entropy = agent.sess.run(
        [ppo.loss, ppo.mean_kl, ppo.mean_entropy],
        feed_dict={ppo.observations: trajectory["observations"],
                   ppo.advantages: trajectory["advantages"],
                   ppo.actions: trajectory["actions"].squeeze(),
                   ppo.prev_logits: trajectory["logprobs"],
                   ppo.kl_coeff: kl_coeff})
    print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl, entropy))
    # Run SGD for training on the current set of rollouts.
    for batch in iterate(trajectory, config["sgd_batchsize"]):
      agent.sess.run([agent.train_op],
                     feed_dict={ppo.observations: batch["observations"],
                                ppo.advantages: batch["advantages"],
                                ppo.actions: batch["actions"].squeeze(),
                                ppo.prev_logits: batch["logprobs"],
                                ppo.kl_coeff: kl_coeff})
  # Adapt the KL penalty coefficient based on the last measured KL divergence.
  if kl > 2.0 * config["kl_target"]:
    kl_coeff *= 1.5
  elif kl < 0.5 * config["kl_target"]:
    kl_coeff *= 0.5
  print("kl div = ", kl)
  print("kl coeff = ", kl_coeff)

View file

@@ -0,0 +1,41 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import os

import ray
from reinforce.env import BatchedEnv
from reinforce.policy import ProximalPolicyLoss
from reinforce.filter import MeanStdFilter
from reinforce.rollout import rollouts, add_advantage_values


class Agent(object):

  def __init__(self, name, batchsize, config, use_gpu):
    if not use_gpu:
      os.environ["CUDA_VISIBLE_DEVICES"] = ""
    self.env = BatchedEnv(name, batchsize, preprocessor=None)
    self.sess = tf.Session()
    self.ppo = ProximalPolicyLoss(self.env.observation_space, self.env.action_space, config, self.sess)
    self.optimizer = tf.train.AdamOptimizer(config["sgd_stepsize"])
    self.train_op = self.optimizer.minimize(self.ppo.loss)
    self.variables = ray.experimental.TensorFlowVariables(self.ppo.loss, self.sess)
    self.observation_filter = MeanStdFilter(self.env.observation_space.shape, clip=None)
    self.reward_filter = MeanStdFilter((), clip=5.0)
    self.sess.run(tf.global_variables_initializer())

  def get_weights(self):
    return self.variables.get_weights()

  def load_weights(self, weights):
    self.variables.set_weights(weights)

  def compute_trajectory(self, gamma, lam, horizon):
    trajectory = rollouts(self.ppo, self.env, horizon, self.observation_filter, self.reward_filter)
    add_advantage_values(trajectory, gamma, lam, self.reward_filter)
    return trajectory


# The remote (Ray actor) version of Agent, used for distributed rollouts.
RemoteAgent = ray.actor(Agent)

View file

@@ -0,0 +1,58 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np


class Categorical(object):

  def __init__(self, logits):
    self.logits = logits

  def logp(self, x):
    return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)

  def entropy(self):
    a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], keep_dims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
    p0 = ea0 / z0
    return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])

  def kl(self, other):
    a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], keep_dims=True)
    a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1], keep_dims=True)
    ea0 = tf.exp(a0)
    ea1 = tf.exp(a1)
    z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
    z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
    p0 = ea0 / z0
    return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1])

  def sample(self):
    return tf.multinomial(self.logits, 1)


class DiagGaussian(object):

  def __init__(self, flat):
    self.flat = flat
    # TensorFlow 1.0 signature: tf.split(value, num_or_size_splits, axis).
    mean, logstd = tf.split(flat, 2, 1)
    self.mean = mean
    self.logstd = logstd
    self.std = tf.exp(logstd)

  def logp(self, x):
    return (- 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), reduction_indices=[1])
            - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1])
            - tf.reduce_sum(self.logstd, reduction_indices=[1]))

  def kl(self, other):
    assert isinstance(other, DiagGaussian)
    return tf.reduce_sum(other.logstd - self.logstd +
                         (tf.square(self.std) + tf.square(self.mean - other.mean)) /
                         (2.0 * tf.square(other.std)) - 0.5,
                         reduction_indices=[1])

  def entropy(self):
    return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), reduction_indices=[1])

  def sample(self):
    return self.mean + self.std * tf.random_normal(tf.shape(self.mean))

View file

@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym
import numpy as np


def atari_preprocessor(observation):
  "Crop and downsample images from shape (210, 160, 3) to (1, 80, 80, 3)."
  return (observation[25:-25:2, ::2, :][None] - 128.0) / 128.0


def ram_preprocessor(observation):
  return (observation - 128.0) / 128.0


class BatchedEnv(object):
  "A BatchedEnv holds multiple gym environments and performs steps on all of them."

  def __init__(self, name, batchsize, preprocessor=None):
    self.envs = [gym.make(name) for _ in range(batchsize)]
    self.observation_space = self.envs[0].observation_space
    self.action_space = self.envs[0].action_space
    self.batchsize = batchsize
    self.preprocessor = preprocessor if preprocessor else lambda obs: obs[None]

  def reset(self):
    observations = [self.preprocessor(env.reset()) for env in self.envs]
    self.shape = observations[0].shape
    self.dones = [False for _ in range(self.batchsize)]
    return np.vstack(observations)

  def step(self, actions, render=False):
    observations = []
    rewards = []
    for i, action in enumerate(actions):
      # Finished environments keep returning zero observations and rewards.
      if self.dones[i]:
        observations.append(np.zeros(self.shape))
        rewards.append(0.0)
        continue
      observation, reward, done, info = self.envs[i].step(action)
      if render:
        self.envs[0].render()
      observations.append(self.preprocessor(observation))
      rewards.append(reward)
      self.dones[i] = done
    return np.vstack(observations), np.array(rewards, dtype="float32"), np.array(self.dones)

View file

@@ -0,0 +1,136 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

import numpy as np


class NoFilter(object):

  def __init__(self):
    pass

  def __call__(self, x, update=True):
    return np.asarray(x)


# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):

  def __init__(self, shape=None):
    self._n = 0
    self._M = np.zeros(shape)
    self._S = np.zeros(shape)

  def push(self, x):
    x = np.asarray(x)
    # Unvectorized update of the running statistics.
    assert x.shape == self._M.shape, "x.shape = {}, self.shape = {}".format(x.shape, self._M.shape)
    n1 = self._n
    self._n += 1
    if self._n == 1:
      self._M[...] = x
    else:
      delta = x - self._M
      self._M[...] += delta / self._n
      self._S[...] += delta * delta * n1 / self._n

  def update(self, other):
    n1 = self._n
    n2 = other._n
    n = n1 + n2
    delta = self._M - other._M
    delta2 = delta * delta
    M = (n1 * self._M + n2 * other._M) / n
    S = self._S + other._S + delta2 * n1 * n2 / n
    self._n = n
    self._M = M
    self._S = S

  @property
  def n(self):
    return self._n

  @property
  def mean(self):
    return self._M

  @property
  def var(self):
    return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)

  @property
  def std(self):
    return np.sqrt(self.var)

  @property
  def shape(self):
    return self._M.shape


class MeanStdFilter(object):
  """y = (x - mean) / std, using running estimates of the mean and std."""

  def __init__(self, shape, demean=True, destd=True, clip=10.0):
    self.demean = demean
    self.destd = destd
    self.clip = clip
    self.rs = RunningStat(shape)

  def __call__(self, x, update=True):
    x = np.asarray(x)
    if update:
      if len(x.shape) == len(self.rs.shape) + 1:
        # The vectorized case.
        for i in range(x.shape[0]):
          self.rs.push(x[i])
      else:
        # The unvectorized case.
        self.rs.push(x)
    if self.demean:
      x = x - self.rs.mean
    if self.destd:
      x = x / (self.rs.std + 1e-8)
    if self.clip:
      if np.amin(x) < -self.clip or np.amax(x) > self.clip:
        print("Clipping value to " + str(self.clip))
      x = np.clip(x, -self.clip, self.clip)
    return x


def test_running_stat():
  for shp in ((), (3,), (3, 4)):
    li = []
    rs = RunningStat(shp)
    for _ in range(5):
      val = np.random.randn(*shp)
      rs.push(val)
      li.append(val)
      m = np.mean(li, axis=0)
      assert np.allclose(rs.mean, m)
      v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0)
      assert np.allclose(rs.var, v)


def test_combining_stat():
  for shape in [(), (3,), (3, 4)]:
    li = []
    rs1 = RunningStat(shape)
    rs2 = RunningStat(shape)
    rs = RunningStat(shape)
    for _ in range(5):
      val = np.random.randn(*shape)
      rs1.push(val)
      rs.push(val)
      li.append(val)
    for _ in range(9):
      rs2.push(val)
      rs.push(val)
      li.append(val)
    rs1.update(rs2)
    assert np.allclose(rs.mean, rs1.mean)
    assert np.allclose(rs.std, rs1.std)


test_running_stat()
test_combining_stat()
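

# Usage sketch (illustration only, not part of this commit): MeanStdFilter
# normalizes inputs with running estimates of the mean and standard deviation,
# which is how Agent uses it for observations and rewards.
example_filter = MeanStdFilter((3,), clip=5.0)
for _ in range(10):
  example_filter(np.random.randn(3))  # Update the running statistics.
print(example_filter(np.zeros(3), update=False))  # Normalize without updating.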

View file

@@ -0,0 +1,26 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np


def normc_initializer(std=1.0):
  def _initializer(shape, dtype=None, partition_info=None):  # pylint: disable=W0613
    out = np.random.randn(*shape).astype(np.float32)
    out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
    return tf.constant(out)
  return _initializer


def fc_net(inputs, num_classes=10, logstd=False):
  fc1 = slim.fully_connected(inputs, 128, weights_initializer=normc_initializer(1.0), scope="fc1")
  fc2 = slim.fully_connected(fc1, 128, weights_initializer=normc_initializer(1.0), scope="fc2")
  fc3 = slim.fully_connected(fc2, 128, weights_initializer=normc_initializer(1.0), scope="fc3")
  fc4 = slim.fully_connected(fc3, num_classes, weights_initializer=normc_initializer(0.01), activation_fn=None, scope="fc4")
  if logstd:
    # Append a trainable logstd vector for continuous (DiagGaussian) actions.
    logstd = tf.get_variable(name="logstd", shape=[num_classes], initializer=tf.zeros_initializer())
    # TensorFlow 1.0 signature: tf.concat(values, axis).
    return tf.concat([fc4, logstd], 1)
  else:
    return fc4

View file

@@ -0,0 +1,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim


def vision_net(inputs, num_classes=10):
  conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1")
  conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2")
  fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1")
  fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, scope="fc2")
  return tf.squeeze(fc2, [1, 2])

View file

@@ -0,0 +1,55 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym.spaces
import tensorflow as tf
import tensorflow.contrib.slim as slim

from reinforce.models.visionnet import vision_net
from reinforce.models.fcnet import fc_net
from reinforce.distributions import Categorical, DiagGaussian


class ProximalPolicyLoss(object):

  def __init__(self, observation_space, action_space, config, sess):
    assert isinstance(action_space, gym.spaces.Discrete) or isinstance(action_space, gym.spaces.Box)
    # Penalty coefficient for adapting the KL divergence.
    self.kl_coeff = tf.placeholder(name="newkl", shape=(), dtype=tf.float32)
    self.observations = tf.placeholder(tf.float32, shape=(None,) + observation_space.shape)
    self.advantages = tf.placeholder(tf.float32, shape=(None,))

    if isinstance(action_space, gym.spaces.Box):
      # The first half of the logits are the means, the second half the log standard deviations.
      self.action_dim = action_space.shape[0]
      self.logit_dim = 2 * self.action_dim
      self.actions = tf.placeholder(tf.float32, shape=(None, action_space.shape[0]))
      Distribution = DiagGaussian
    elif isinstance(action_space, gym.spaces.Discrete):
      self.action_dim = action_space.n
      self.logit_dim = self.action_dim
      self.actions = tf.placeholder(tf.int64, shape=(None,))
      Distribution = Categorical
    else:
      raise NotImplementedError("action space " + str(type(action_space)) + " currently not supported")

    self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim))
    self.prev_dist = Distribution(self.prev_logits)
    self.curr_logits = fc_net(self.observations, num_classes=self.logit_dim)
    self.curr_dist = Distribution(self.curr_logits)
    self.sampler = self.curr_dist.sample()
    self.entropy = self.curr_dist.entropy()

    # Make loss functions: the clipped surrogate objective with an adaptive KL
    # penalty and an entropy bonus.
    self.ratio = tf.exp(self.curr_dist.logp(self.actions) - self.prev_dist.logp(self.actions))
    self.kl = self.prev_dist.kl(self.curr_dist)
    self.mean_kl = tf.reduce_mean(self.kl)
    self.mean_entropy = tf.reduce_mean(self.entropy)
    self.surr1 = self.ratio * self.advantages
    self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"], 1 + config["clip_param"]) * self.advantages
    self.surr = tf.minimum(self.surr1, self.surr2)
    self.loss = tf.reduce_mean(-self.surr + self.kl_coeff * self.kl - config["entropy_coeff"] * self.entropy)
    self.sess = sess

  def compute_actions(self, observations):
    return self.sess.run([self.sampler, self.curr_logits],
                         feed_dict={self.observations: observations})

View file

@@ -0,0 +1,91 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import ray
from reinforce.filter import NoFilter
from reinforce.utils import flatten, concatenate


def rollouts(policy, env, horizon, observation_filter=NoFilter(), reward_filter=NoFilter()):
  """Perform a batch of rollouts of a policy in an environment.

  Args:
    policy: The policy that will be rolled out. Can be an arbitrary object
      that supports a compute_actions(observation) function.
    env: The environment the rollout is computed in. Needs to support the
      OpenAI gym API and needs to support batches of data.
    horizon: Upper bound for the number of timesteps for each rollout in the
      batch.
    observation_filter: Function that is applied to each of the observations.
    reward_filter: Function that is applied to each of the rewards.

  Returns:
    A trajectory, which is a dictionary with keys "observations",
    "raw_rewards", "actions", "logprobs", "dones". Each value is an array of
    shape (num_timesteps, env.batchsize, shape).
  """
  observation = observation_filter(env.reset())
  done = np.array(env.batchsize * [False])
  t = 0
  observations = []
  raw_rewards = []  # Empirical rewards
  actions = []
  logprobs = []
  dones = []

  while not done.all() and t < horizon:
    action, logprob = policy.compute_actions(observation)
    observations.append(observation[None])
    actions.append(action[None])
    logprobs.append(logprob[None])
    observation, raw_reward, done = env.step(action)
    observation = observation_filter(observation)
    raw_rewards.append(raw_reward[None])
    dones.append(done[None])
    t += 1

  return {"observations": np.vstack(observations),
          "raw_rewards": np.vstack(raw_rewards),
          "actions": np.vstack(actions),
          "logprobs": np.vstack(logprobs),
          "dones": np.vstack(dones)}


def add_advantage_values(trajectory, gamma, lam, reward_filter):
  rewards = trajectory["raw_rewards"]
  dones = trajectory["dones"]
  advantages = np.zeros_like(rewards)
  last_advantage = np.zeros(rewards.shape[1], dtype="float32")

  # Compute discounted sums of future rewards (the advantage estimates) backwards in time.
  for t in reversed(range(len(rewards))):
    delta = rewards[t, :] * (1 - dones[t, :])
    last_advantage = delta + gamma * lam * last_advantage
    advantages[t, :] = last_advantage
    reward_filter(advantages[t, :])

  trajectory["advantages"] = advantages


@ray.remote
def compute_trajectory(policy, env, gamma, lam, horizon, observation_filter, reward_filter):
  trajectory = rollouts(policy, env, horizon, observation_filter, reward_filter)
  add_advantage_values(trajectory, gamma, lam, reward_filter)
  return trajectory


def collect_samples(agents, num_timesteps, gamma, lam, horizon,
                    observation_filter=NoFilter(), reward_filter=NoFilter()):
  num_timesteps_so_far = 0
  trajectories = []
  total_rewards = []
  traj_len_means = []
  while num_timesteps_so_far < num_timesteps:
    trajectory_batch = ray.get([agent.compute_trajectory(gamma, lam, horizon) for agent in agents])
    trajectory = concatenate(trajectory_batch)
    total_rewards.append(trajectory["raw_rewards"].sum(axis=0).mean() / len(agents))
    trajectory = flatten(trajectory)
    not_done = np.logical_not(trajectory["dones"])
    traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
    # Keep only the timesteps where the corresponding environment was not yet done.
    trajectory = {key: val[not_done] for key, val in trajectory.items()}
    num_timesteps_so_far += len(trajectory["dones"])
    trajectories.append(trajectory)
  return concatenate(trajectories), np.mean(total_rewards), np.mean(traj_len_means)
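

# Usage sketch (illustration only, not part of this commit): with a local
# Agent called `agent` (see reinforce.agent), one batch of rollouts can be
# collected and post-processed directly; collect_samples does the same thing
# across several remote agents and flattens the result.
#
#   trajectory = rollouts(agent.ppo, agent.env, 2000,
#                         agent.observation_filter, agent.reward_filter)
#   add_advantage_values(trajectory, gamma=0.995, lam=1.0,
#                        reward_filter=agent.reward_filter)
#   trajectory["observations"].shape  # (num_timesteps, env.batchsize, ...)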

View file

@@ -0,0 +1,30 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import tensorflow as tf


class DataQueue(object):

  def __init__(self, placeholder_dict):
    """Here, placeholder_dict is an OrderedDict mapping names to placeholders."""
    self.placeholder_dict = placeholder_dict
    placeholders = list(placeholder_dict.values())
    shapes = [placeholder.get_shape()[1:].as_list() for placeholder in placeholders]
    dtypes = [placeholder.dtype for placeholder in placeholders]
    self.queue = tf.RandomShuffleQueue(shapes=shapes, dtypes=dtypes, capacity=2000, min_after_dequeue=1000)
    self.enqueue_op = self.queue.enqueue_many(placeholders)

  def thread_main(self, sess, data_iterator):
    for data in data_iterator:
      feed_dict = {placeholder: data[name] for (name, placeholder) in self.placeholder_dict.items()}
      sess.run(self.enqueue_op, feed_dict=feed_dict)

  def start_thread(self, sess, data_iterator, num_threads=1):
    threads = []
    for n in range(num_threads):
      t = threading.Thread(target=self.thread_main, args=(sess, data_iterator))
      t.daemon = True  # Thread will close when parent quits.
      t.start()
      threads.append(t)
    return threads
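

# Usage sketch (illustration only, not part of this commit): the queue is
# constructed from an OrderedDict of batched placeholders, for example ones
# matching the keys produced by reinforce.rollout. Here data_iterator stands
# for any iterable of dictionaries keyed like the placeholders.
#
#   import collections
#   placeholders = collections.OrderedDict(
#       [("observations", tf.placeholder(tf.float32, shape=(None, 128))),
#        ("advantages", tf.placeholder(tf.float32, shape=(None,)))])
#   queue = DataQueue(placeholders)
#   threads = queue.start_thread(tf.Session(), data_iterator, num_threads=1)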

View file

@@ -0,0 +1,44 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def flatten(weights, start=0, stop=2):
  """This method reshapes all values in a dictionary.

  The indices from start to stop will be flattened into a single index.

  Args:
    weights: A dictionary mapping keys to numpy arrays.
    start: The starting index.
    stop: The ending index.
  """
  for key, val in weights.items():
    new_shape = val.shape[0:start] + (-1,) + val.shape[stop:]
    weights[key] = val.reshape(new_shape)
  return weights


def concatenate(weights_list):
  keys = weights_list[0].keys()
  result = {}
  for key in keys:
    result[key] = np.concatenate([l[key] for l in weights_list])
  return result


def shuffle(trajectory):
  # Apply the same random permutation to every array in the trajectory.
  permutation = np.random.permutation(trajectory["dones"].shape[0])
  for key, val in trajectory.items():
    trajectory[key] = val[permutation]
  return trajectory


def iterate(trajectory, batchsize):
  trajectory = shuffle(trajectory)
  curr_index = 0
  # TODO(pcm): This drops some data at the end of the batch.
  while curr_index + batchsize < trajectory["dones"].shape[0]:
    batch = dict()
    for key in trajectory:
      batch[key] = trajectory[key][curr_index:curr_index + batchsize]
    curr_index += batchsize
    yield batch
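

# Usage sketch (illustration only, not part of this commit): iterate shuffles a
# trajectory dictionary and yields aligned minibatches, which is how the SGD
# loop in examples/example.py consumes rollouts.
if __name__ == "__main__":
  trajectory = {"dones": np.zeros(10), "advantages": np.arange(10.0)}
  for batch in iterate(trajectory, 4):
    print(batch["dones"].shape, batch["advantages"].shape)  # Prints (4,) (4,) twice.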

View file

@@ -0,0 +1,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from setuptools import setup, find_packages

setup(name="reinforce",
      version="0.0.1",
      packages=find_packages())

View file

@@ -0,0 +1,58 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import numpy as np
import tensorflow as tf
from numpy.testing import assert_allclose

from reinforce.distributions import Categorical
from reinforce.utils import flatten, concatenate


class DistributionsTest(unittest.TestCase):

  def testCategorical(self):
    num_samples = 100000
    logits = tf.placeholder(tf.float32, shape=(None, 10))
    z = 8 * (np.random.rand(10) - 0.5)
    data = np.tile(z, (num_samples, 1))
    c = Categorical(logits)
    sample_op = c.sample()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    samples = sess.run(sample_op, feed_dict={logits: data})
    counts = np.zeros(10)
    for sample in samples:
      counts[sample] += 1.0
    probs = np.exp(z) / np.sum(np.exp(z))
    self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01)


class UtilsTest(unittest.TestCase):

  def testFlatten(self):
    d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
         "a": np.array([[[5], [-5]], [[6], [-6]]])}
    flat = flatten(d.copy(), start=0, stop=2)
    assert_allclose(d["s"][0][0][:], flat["s"][0][:])
    assert_allclose(d["s"][0][1][:], flat["s"][1][:])
    assert_allclose(d["s"][1][0][:], flat["s"][2][:])
    assert_allclose(d["s"][1][1][:], flat["s"][3][:])
    assert_allclose(d["a"][0][0], flat["a"][0])
    assert_allclose(d["a"][0][1], flat["a"][1])
    assert_allclose(d["a"][1][0], flat["a"][2])
    assert_allclose(d["a"][1][1], flat["a"][3])

  def testConcatenate(self):
    d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
    d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
    d = concatenate([d1, d2])
    assert_allclose(d["s"], np.array([0, 1, 4, 5]))
    assert_allclose(d["a"], np.array([2, 3, 6, 7]))
    D = concatenate([d])
    assert_allclose(D["s"], np.array([0, 1, 4, 5]))
    assert_allclose(D["a"], np.array([2, 3, 6, 7]))


if __name__ == "__main__":
  unittest.main(verbosity=2)