diff --git a/doc/source/example-policy-gradient.rst b/doc/source/example-policy-gradient.rst new file mode 100644 index 000000000..449a83848 --- /dev/null +++ b/doc/source/example-policy-gradient.rst @@ -0,0 +1,31 @@ +Policy Gradient Methods +======================= + +This code shows how to do reinforcement learning with policy gradient methods. +View the `code for this example`_. + +To run this example, you will need to install `TensorFlow with GPU support`_ (at +least version ``1.0.0``) and a few other dependencies. + +.. code-block:: bash + + pip install gym[atari] + pip install tensorflow + +Then install the package as follows. + +.. code-block:: bash + + cd ray/examples/policy_gradient/ + python setup.py install + +Then you can run the example as follows. + +.. code-block:: bash + + python ray/examples/policy_gradient/examples/example.py + +This will train an agent on an Atari environment. + +.. _`TensorFlow with GPU support`: https://www.tensorflow.org/install/ +.. _`code for this example`: https://github.com/ray-project/ray/tree/master/examples/policy_gradient diff --git a/doc/source/index.rst b/doc/source/index.rst index 08fe213c5..6b33266b7 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -27,6 +27,7 @@ learning and reinforcement learning applications.* :caption: Examples example-hyperopt.rst + example-policy-gradient.rst example-resnet.rst example-lbfgs.md example-rl-pong.md diff --git a/examples/policy_gradient/examples/example.py b/examples/policy_gradient/examples/example.py new file mode 100644 index 000000000..833ceb43b --- /dev/null +++ b/examples/policy_gradient/examples/example.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray + +from reinforce.agent import Agent, RemoteAgent +from reinforce.rollout import collect_samples +from reinforce.utils import iterate, shuffle + +config = {"kl_coeff": 0.2, + "num_sgd_iter": 30, + "sgd_stepsize": 5e-5, + "sgd_batchsize": 128, + "entropy_coeff": 0.0, + "clip_param": 0.3, + "kl_target": 0.01, + "timesteps_per_batch": 40000} + +ray.init() + +mdp_name = "Pong-ram-v3" + +agents = [RemoteAgent(mdp_name, 1, config, False) for _ in range(5)] +agent = Agent(mdp_name, 1, config, True) + +kl_coeff = config["kl_coeff"] + +for j in range(1000): + print("== iteration", j) + weights = agent.get_weights() + [a.load_weights(weights) for a in agents] + trajectory, total_reward, traj_len_mean = collect_samples(agents, config["timesteps_per_batch"], 0.995, 1.0, 2000) + print("total reward is ", total_reward) + print("trajectory length mean is ", traj_len_mean) + print("timesteps: ", trajectory["dones"].shape[0]) + trajectory["advantages"] = (trajectory["advantages"] - trajectory["advantages"].mean()) / trajectory["advantages"].std() + print("Computing policy (optimizer='" + agent.optimizer.get_name() + "', iterations=" + str(config["num_sgd_iter"]) + ", stepsize=" + str(config["sgd_stepsize"]) + "):") + names = ["iter", "loss", "kl", "entropy"] + print(("{:>15}" * len(names)).format(*names)) + trajectory = shuffle(trajectory) + ppo = agent.ppo + for i in range(config["num_sgd_iter"]): + # Test on current set of rollouts + loss, kl, entropy = agent.sess.run([ppo.loss, ppo.mean_kl, ppo.mean_entropy], + feed_dict={ppo.observations: trajectory["observations"], + ppo.advantages: trajectory["advantages"], + ppo.actions: trajectory["actions"].squeeze(), + ppo.prev_logits: trajectory["logprobs"], + ppo.kl_coeff: kl_coeff}) + 
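+        # The loss, KL divergence, and entropy above are evaluated on the full
+        # shuffled batch, so the row printed below describes the policy after i
+        # epochs of minibatch SGD on this batch of rollouts.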
print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl, entropy)) + # Run SGD for training on current set of rollouts + for batch in iterate(trajectory, config["sgd_batchsize"]): + agent.sess.run([agent.train_op], + feed_dict={ppo.observations: batch["observations"], + ppo.advantages: batch["advantages"], + ppo.actions: batch["actions"].squeeze(), + ppo.prev_logits: batch["logprobs"], + ppo.kl_coeff: kl_coeff}) + if kl > 2.0 * config["kl_target"]: + kl_coeff *= 1.5 + elif kl < 0.5 * config["kl_target"]: + kl_coeff *= 0.5 + print("kl div = ", kl) + print("kl coeff = ", kl_coeff) diff --git a/examples/policy_gradient/reinforce/__init__.py b/examples/policy_gradient/reinforce/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/policy_gradient/reinforce/agent.py b/examples/policy_gradient/reinforce/agent.py new file mode 100644 index 000000000..82f313c6d --- /dev/null +++ b/examples/policy_gradient/reinforce/agent.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import os + +import ray + +from reinforce.env import BatchedEnv +from reinforce.policy import ProximalPolicyLoss +from reinforce.filter import MeanStdFilter +from reinforce.rollout import rollouts, add_advantage_values + +class Agent(object): + + def __init__(self, name, batchsize, config, use_gpu): + if not use_gpu: + os.environ["CUDA_VISIBLE_DEVICES"] = "" + self.env = BatchedEnv(name, batchsize, preprocessor=None) + self.sess = tf.Session() + self.ppo = ProximalPolicyLoss(self.env.observation_space, self.env.action_space, config, self.sess) + self.optimizer = tf.train.AdamOptimizer(config["sgd_stepsize"]) + self.train_op = self.optimizer.minimize(self.ppo.loss) + self.variables = ray.experimental.TensorFlowVariables(self.ppo.loss, self.sess) + self.observation_filter = MeanStdFilter(self.env.observation_space.shape, clip=None) + self.reward_filter = MeanStdFilter((), clip=5.0) + self.sess.run(tf.global_variables_initializer()) + + def get_weights(self): + return self.variables.get_weights() + + def load_weights(self, weights): + self.variables.set_weights(weights) + + def compute_trajectory(self, gamma, lam, horizon): + trajectory = rollouts(self.ppo, self.env, horizon, self.observation_filter, self.reward_filter) + add_advantage_values(trajectory, gamma, lam, self.reward_filter) + return trajectory + +RemoteAgent = ray.actor(Agent) diff --git a/examples/policy_gradient/reinforce/distributions.py b/examples/policy_gradient/reinforce/distributions.py new file mode 100644 index 000000000..b4377c507 --- /dev/null +++ b/examples/policy_gradient/reinforce/distributions.py @@ -0,0 +1,58 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import numpy as np + +class Categorical(object): + + def __init__(self, logits): + self.logits = logits + + def logp(self, x): + return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) + + def entropy(self): + a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], keep_dims=True) + ea0 = tf.exp(a0) + z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) + p0 = ea0 / z0 + return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1]) + + def kl(self, other): + a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], keep_dims=True) + a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1], 
keep_dims=True) + ea0 = tf.exp(a0) + ea1 = tf.exp(a1) + z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) + z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True) + p0 = ea0 / z0 + return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1]) + + def sample(self): + return tf.multinomial(self.logits, 1) + +class DiagGaussian(object): + + def __init__(self, flat): + self.flat = flat + mean, logstd = tf.split(1, 2, flat) + self.mean = mean + self.logstd = logstd + self.std = tf.exp(logstd) + + def logp(self, x): + return - 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), reduction_indices=[1]) \ + - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) \ + - tf.reduce_sum(self.logstd, reduction_indices=[1]) + + def kl(self, other): + assert isinstance(other, DiagGaussian) + return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, reduction_indices=[1]) + + def entropy(self): + return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), reduction_indices=[1]) + + def sample(self): + return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) diff --git a/examples/policy_gradient/reinforce/env.py b/examples/policy_gradient/reinforce/env.py new file mode 100644 index 000000000..30649c029 --- /dev/null +++ b/examples/policy_gradient/reinforce/env.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import numpy as np + +def atari_preprocessor(observation): + "Convert images from (210, 160, 3) to (3, 80, 80) by downsampling." + return (observation[25:-25:2,::2,:][None] - 128.0) / 128.8 + +def ram_preprocessor(observation): + return (observation - 128.0) / 128.0 + +class BatchedEnv(object): + "A BatchedEnv holds multiple gym enviroments and performs steps on all of them." 
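+    # Note: sub-environments that have already reported done keep returning zero
+    # observations and zero reward from step() until reset() is called again.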
+ + def __init__(self, name, batchsize, preprocessor=None): + self.envs = [gym.make(name) for _ in range(batchsize)] + self.observation_space = self.envs[0].observation_space + self.action_space = self.envs[0].action_space + self.batchsize = batchsize + self.preprocessor = preprocessor if preprocessor else lambda obs: obs[None] + + def reset(self): + observations = [self.preprocessor(env.reset()) for env in self.envs] + self.shape = observations[0].shape + self.dones = [False for _ in range(self.batchsize)] + return np.vstack(observations) + + def step(self, actions, render=False): + observations = [] + rewards = [] + for i, action in enumerate(actions): + if self.dones[i]: + observations.append(np.zeros(self.shape)) + rewards.append(0.0) + continue + observation, reward, done, info = self.envs[i].step(action) + if render: + self.envs[0].render() + observations.append(self.preprocessor(observation)) + rewards.append(reward) + self.dones[i] = done + return np.vstack(observations), np.array(rewards, dtype="float32"), np.array(self.dones) diff --git a/examples/policy_gradient/reinforce/filter.py b/examples/policy_gradient/reinforce/filter.py new file mode 100644 index 000000000..0b807b3d2 --- /dev/null +++ b/examples/policy_gradient/reinforce/filter.py @@ -0,0 +1,136 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings +import numpy as np + +class NoFilter(object): + + def __init__(self): + pass + + def __call__(self, x, update=True): + return np.asarray(x) + +# http://www.johndcook.com/blog/standard_deviation/ +class RunningStat(object): + + def __init__(self, shape=None): + self._n = 0 + self._M = np.zeros(shape) + self._S = np.zeros(shape) + + def push(self, x): + x = np.asarray(x) + # Unvectorized update of the running statistics. + assert x.shape == self._M.shape, "x.shape = {}, self.shape = {}".format(x.shape, self._M.shape) + n1 = self._n + self._n += 1 + if self._n == 1: + self._M[...] = x + else: + delta = x - self._M + self._M[...] += delta / self._n + self._S[...] += delta * delta * n1 / self._n + + def update(self, other): + n1 = self._n + n2 = other._n + n = n1 + n2 + delta = self._M - other._M + delta2 = delta * delta + M = (n1 * self._M + n2 * other._M) / n + S = self._S + other._S + delta2 * n1 * n2 / n + self._n = n + self._M = M + self._S = S + + @property + def n(self): + return self._n + + @property + def mean(self): + return self._M + + @property + def var(self): + return self._S/(self._n - 1) if self._n > 1 else np.square(self._M) + + @property + def std(self): + return np.sqrt(self.var) + + @property + def shape(self): + return self._M.shape + +class MeanStdFilter(object): + """ + y = (x-mean)/std + using running estimates of mean,std + """ + + def __init__(self, shape, demean=True, destd=True, clip=10.0): + self.demean = demean + self.destd = destd + self.clip = clip + + self.rs = RunningStat(shape) + + def __call__(self, x, update=True): + x = np.asarray(x) + if update: + if len(x.shape) == len(self.rs.shape) + 1: + # The vectorized case. + for i in range(x.shape[0]): + self.rs.push(x[i]) + else: + # The unvectorized case. 
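+                # A single observation: push it as one sample into the running mean/std.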
+ self.rs.push(x) + if self.demean: + x = x - self.rs.mean + if self.destd: + x = x / (self.rs.std+1e-8) + if self.clip: + if np.amin(x) < -self.clip or np.amax(x) > self.clip: + print("Clipping value to " + str(self.clip)) + x = np.clip(x, -self.clip, self.clip) + return x + + +def test_running_stat(): + for shp in ((), (3,), (3,4)): + li = [] + rs = RunningStat(shp) + for _ in range(5): + val = np.random.randn(*shp) + rs.push(val) + li.append(val) + m = np.mean(li, axis=0) + assert np.allclose(rs.mean, m) + v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0) + assert np.allclose(rs.var, v) + +def test_combining_stat(): + for shape in [(), (3,), (3,4)]: + li = [] + rs1 = RunningStat(shape) + rs2 = RunningStat(shape) + rs = RunningStat(shape) + for _ in range(5): + val = np.random.randn(*shape) + rs1.push(val) + rs.push(val) + li.append(val) + for _ in range(9): + rs2.push(val) + rs.push(val) + li.append(val) + rs1.update(rs2) + assert np.allclose(rs.mean, rs1.mean) + assert np.allclose(rs.std, rs1.std) + +test_running_stat() +test_combining_stat() diff --git a/examples/policy_gradient/reinforce/models/__init__.py b/examples/policy_gradient/reinforce/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/policy_gradient/reinforce/models/fcnet.py b/examples/policy_gradient/reinforce/models/fcnet.py new file mode 100644 index 000000000..da0050a81 --- /dev/null +++ b/examples/policy_gradient/reinforce/models/fcnet.py @@ -0,0 +1,26 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.slim as slim + +import numpy as np + +def normc_initializer(std=1.0): + def _initializer(shape, dtype=None, partition_info=None): #pylint: disable=W0613 + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer + +def fc_net(inputs, num_classes=10, logstd=False): + fc1 = slim.fully_connected(inputs, 128, weights_initializer=normc_initializer(1.0), scope="fc1") + fc2 = slim.fully_connected(fc1, 128, weights_initializer=normc_initializer(1.0), scope="fc2") + fc3 = slim.fully_connected(fc2, 128, weights_initializer=normc_initializer(1.0), scope="fc3") + fc4 = slim.fully_connected(fc3, num_classes, weights_initializer=normc_initializer(0.01), activation_fn=None, scope="fc4") + if logstd: + logstd = tf.get_variable(name="logstd", shape=[num_classes], initializer=tf.zeros_initializer) + return tf.concat(1, [fc4, logstd]) + else: + return fc4 diff --git a/examples/policy_gradient/reinforce/models/visionnet.py b/examples/policy_gradient/reinforce/models/visionnet.py new file mode 100644 index 000000000..569eb322b --- /dev/null +++ b/examples/policy_gradient/reinforce/models/visionnet.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.slim as slim + +def vision_net(inputs, num_classes=10): + conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1") + conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2") + fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1") + fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, scope="fc2") + return tf.squeeze(fc2, [1, 2]) diff --git a/examples/policy_gradient/reinforce/policy.py b/examples/policy_gradient/reinforce/policy.py new file mode 
100644
index 000000000..af0d613e9
--- /dev/null
+++ b/examples/policy_gradient/reinforce/policy.py
@@ -0,0 +1,55 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gym.spaces
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+from reinforce.models.visionnet import vision_net
+from reinforce.models.fcnet import fc_net
+from reinforce.distributions import Categorical, DiagGaussian
+
+class ProximalPolicyLoss(object):
+
+    def __init__(self, observation_space, action_space, config, sess):
+        assert isinstance(action_space, gym.spaces.Discrete) or isinstance(action_space, gym.spaces.Box)
+        # Placeholder for the adaptive KL penalty coefficient.
+        self.kl_coeff = tf.placeholder(name="newkl", shape=(), dtype=tf.float32)
+        self.observations = tf.placeholder(tf.float32, shape=(None,) + observation_space.shape)
+        self.advantages = tf.placeholder(tf.float32, shape=(None,))
+
+        if isinstance(action_space, gym.spaces.Box):
+            # The first half of the outputs are the means, the second half the log standard deviations.
+            self.action_dim = action_space.shape[0]
+            self.logit_dim = 2 * self.action_dim
+            self.actions = tf.placeholder(tf.float32, shape=(None, action_space.shape[0]))
+            Distribution = DiagGaussian
+        elif isinstance(action_space, gym.spaces.Discrete):
+            self.action_dim = action_space.n
+            self.logit_dim = self.action_dim
+            self.actions = tf.placeholder(tf.int64, shape=(None,))
+            Distribution = Categorical
+        else:
+            raise NotImplementedError("Action space " + str(type(action_space)) + " is currently not supported.")
+        self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim))
+        self.prev_dist = Distribution(self.prev_logits)
+        self.curr_logits = fc_net(self.observations, num_classes=self.logit_dim)
+        self.curr_dist = Distribution(self.curr_logits)
+        self.sampler = self.curr_dist.sample()
+        self.entropy = self.curr_dist.entropy()
+        # Make loss functions.
+        self.ratio = tf.exp(self.curr_dist.logp(self.actions) - self.prev_dist.logp(self.actions))
+        self.kl = self.prev_dist.kl(self.curr_dist)
+        self.mean_kl = tf.reduce_mean(self.kl)
+        self.mean_entropy = tf.reduce_mean(self.entropy)
+        # PPO clipped surrogate objective: the elementwise minimum of the
+        # unclipped and clipped probability-ratio-weighted advantages.
+        self.surr1 = self.ratio * self.advantages
+        self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"], 1 + config["clip_param"]) * self.advantages
+        self.surr = tf.minimum(self.surr1, self.surr2)
+        # Total loss: negated surrogate plus the adaptive KL penalty minus an entropy bonus.
+        self.loss = tf.reduce_mean(-self.surr + self.kl_coeff * self.kl - config["entropy_coeff"] * self.entropy)
+        self.sess = sess
+
+    def compute_actions(self, observations):
+        return self.sess.run([self.sampler, self.curr_logits], feed_dict={self.observations: observations})
diff --git a/examples/policy_gradient/reinforce/rollout.py b/examples/policy_gradient/reinforce/rollout.py
new file mode 100644
index 000000000..094e71609
--- /dev/null
+++ b/examples/policy_gradient/reinforce/rollout.py
@@ -0,0 +1,91 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import ray
+
+from reinforce.filter import NoFilter
+from reinforce.utils import flatten, concatenate
+
+def rollouts(policy, env, horizon, observation_filter=NoFilter(), reward_filter=NoFilter()):
+    """Perform a batch of rollouts of a policy in an environment.
+
+    Args:
+        policy: The policy that will be rolled out. Can be an arbitrary object
+            that supports a compute_actions(observation) function.
+        env: The environment the rollout is computed in.
Needs to support the + OpenAI gym API and needs to support batches of data. + horizon: Upper bound for the number of timesteps for each rollout in the + batch. + observation_filter: Function that is applied to each of the observations. + reward_filter: Function that is applied to each of the rewards. + + Returns: + A trajectory, which is a dictionary with keys "observations", "rewards", + "orig_rewards", "actions", "logprobs", "dones". Each value is an array of + shape (num_timesteps, env.batchsize, shape). + """ + + observation = observation_filter(env.reset()) + done = np.array(env.batchsize * [False]) + t = 0 + observations = [] + raw_rewards = [] # Empirical rewards + actions = [] + logprobs = [] + dones = [] + + while not done.all() and t < horizon: + action, logprob = policy.compute_actions(observation) + observations.append(observation[None]) + actions.append(action[None]) + logprobs.append(logprob[None]) + observation, raw_reward, done = env.step(action) + observation = observation_filter(observation) + raw_rewards.append(raw_reward[None]) + dones.append(done[None]) + t += 1 + + return {"observations": np.vstack(observations), + "raw_rewards": np.vstack(raw_rewards), + "actions": np.vstack(actions), + "logprobs": np.vstack(logprobs), + "dones": np.vstack(dones)} + +def add_advantage_values(trajectory, gamma, lam, reward_filter): + rewards = trajectory["raw_rewards"] + dones = trajectory["dones"] + advantages = np.zeros_like(rewards) + last_advantage = np.zeros(rewards.shape[1], dtype="float32") + + for t in reversed(range(len(rewards))): + delta = rewards[t,:] * (1 - dones[t,:]) + last_advantage = delta + gamma * lam * last_advantage + advantages[t,:] = last_advantage + reward_filter(advantages[t,:]) + + trajectory["advantages"] = advantages + +@ray.remote +def compute_trajectory(policy, env, gamma, lam, horizon, observation_filter, reward_filter): + trajectory = rollouts(policy, env, horizon, observation_filter, reward_filter) + add_advantage_values(trajectory, gamma, lam, reward_filter) + return trajectory + +def collect_samples(agents, num_timesteps, gamma, lam, horizon, observation_filter=NoFilter(), reward_filter=NoFilter()): + num_timesteps_so_far = 0 + trajectories = [] + total_rewards = [] + traj_len_means = [] + while num_timesteps_so_far < num_timesteps: + trajectory_batch = ray.get([agent.compute_trajectory(gamma, lam, horizon) for agent in agents]) + trajectory = concatenate(trajectory_batch) + total_rewards.append(trajectory["raw_rewards"].sum(axis=0).mean() / len(agents)) + trajectory = flatten(trajectory) + not_done = np.logical_not(trajectory["dones"]) + traj_len_means.append(not_done.sum(axis=0).mean() / len(agents)) + trajectory = {key: val[not_done] for key, val in trajectory.items()} + num_timesteps_so_far += len(trajectory["dones"]) + trajectories.append(trajectory) + return concatenate(trajectories), np.mean(total_rewards), np.mean(traj_len_means) diff --git a/examples/policy_gradient/reinforce/tfutils.py b/examples/policy_gradient/reinforce/tfutils.py new file mode 100644 index 000000000..5fed1a74e --- /dev/null +++ b/examples/policy_gradient/reinforce/tfutils.py @@ -0,0 +1,30 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import threading + +class DataQueue(object): + + def __init__(self, placeholder_dict): + """Here, placeholder_dict is an OrderedDict.""" + placeholders = placeholder_dict.values() + shapes = [placeholder.get_shape()[1:].as_list() for placeholder 
in placeholders]
+        types = [placeholder.dtype for placeholder in placeholders]
+        self.placeholder_dict = placeholder_dict
+        self.queue = tf.RandomShuffleQueue(shapes=shapes, dtypes=types, capacity=2000, min_after_dequeue=1000)
+        self.enqueue_op = self.queue.enqueue_many(placeholders)
+
+    def thread_main(self, sess, data_iterator):
+        for data in data_iterator:
+            feed_dict = {placeholder: data[name] for (name, placeholder) in self.placeholder_dict.items()}
+            sess.run(self.enqueue_op, feed_dict=feed_dict)
+
+    def start_thread(self, sess, data_iterator, num_threads=1):
+        threads = []
+        for n in range(num_threads):
+            t = threading.Thread(target=self.thread_main, args=(sess, data_iterator))
+            t.daemon = True  # Thread will close when parent quits.
+            t.start()
+            threads.append(t)
+        return threads
diff --git a/examples/policy_gradient/reinforce/utils.py b/examples/policy_gradient/reinforce/utils.py
new file mode 100644
index 000000000..21fe93dc5
--- /dev/null
+++ b/examples/policy_gradient/reinforce/utils.py
@@ -0,0 +1,44 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+def flatten(weights, start=0, stop=2):
+    """This method reshapes all values in a dictionary.
+
+    The indices from start to stop will be flattened into a single index.
+
+    Args:
+        weights: A dictionary mapping keys to numpy arrays.
+        start: The starting index.
+        stop: The ending index.
+    """
+    for key, val in weights.items():
+        new_shape = val.shape[0:start] + (-1,) + val.shape[stop:]
+        weights[key] = val.reshape(new_shape)
+    return weights
+
+def concatenate(weights_list):
+    keys = weights_list[0].keys()
+    result = {}
+    for key in keys:
+        result[key] = np.concatenate([l[key] for l in weights_list])
+    return result
+
+def shuffle(trajectory):
+    permutation = np.random.permutation(trajectory["dones"].shape[0])
+    for key, val in trajectory.items():
+        trajectory[key] = val[permutation]
+    return trajectory
+
+def iterate(trajectory, batchsize):
+    trajectory = shuffle(trajectory)
+    curr_index = 0
+    # TODO(pcm): This drops some data at the end of the batch.
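+    # Each call reshuffles the trajectory once and then yields consecutive
+    # fixed-size minibatches; a final remainder smaller than batchsize is skipped.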
+ while curr_index + batchsize < trajectory["dones"].shape[0]: + batch = dict() + for key in trajectory: + batch[key] = trajectory[key][curr_index:curr_index+batchsize] + curr_index += batchsize + yield batch diff --git a/examples/policy_gradient/setup.py b/examples/policy_gradient/setup.py new file mode 100644 index 000000000..bc02ddb21 --- /dev/null +++ b/examples/policy_gradient/setup.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from setuptools import setup, find_packages + +setup(name="reinforce", + version="0.0.1", + packages=find_packages()) diff --git a/examples/policy_gradient/tests/test.py b/examples/policy_gradient/tests/test.py new file mode 100644 index 000000000..988c15061 --- /dev/null +++ b/examples/policy_gradient/tests/test.py @@ -0,0 +1,58 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import numpy as np +import tensorflow as tf +from numpy.testing import assert_allclose + +from reinforce.distributions import Categorical +from reinforce.utils import flatten, concatenate + +class DistibutionsTest(unittest.TestCase): + + def testCategorical(self): + num_samples = 100000 + logits = tf.placeholder(tf.float32, shape=(None, 10)) + z = 8 * (np.random.rand(10) - 0.5) + data = np.tile(z, (num_samples, 1)) + c = Categorical(logits) + sample_op = c.sample() + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + samples = sess.run(sample_op, feed_dict={logits: data}) + counts = np.zeros(10) + for sample in samples: + counts[sample] += 1.0 + probs = np.exp(z) / np.sum(np.exp(z)) + self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01) + +class UtilsTest(unittest.TestCase): + + def testFlatten(self): + d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]), + "a": np.array([[[5], [-5]], [[6], [-6]]])} + flat = flatten(d.copy(), start=0, stop=2) + assert_allclose(d["s"][0][0][:], flat["s"][0][:]) + assert_allclose(d["s"][0][1][:], flat["s"][1][:]) + assert_allclose(d["s"][1][0][:], flat["s"][2][:]) + assert_allclose(d["s"][1][1][:], flat["s"][3][:]) + assert_allclose(d["a"][0][0], flat["a"][0]) + assert_allclose(d["a"][0][1], flat["a"][1]) + assert_allclose(d["a"][1][0], flat["a"][2]) + assert_allclose(d["a"][1][1], flat["a"][3]) + + def testConcatenate(self): + d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])} + d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])} + d = concatenate([d1, d2]) + assert_allclose(d["s"], np.array([0, 1, 4, 5])) + assert_allclose(d["a"], np.array([2, 3, 6, 7])) + + D = concatenate([d]) + assert_allclose(D["s"], np.array([0, 1, 4, 5])) + assert_allclose(D["a"], np.array([2, 3, 6, 7])) + +if __name__ == "__main__": + unittest.main(verbosity=2)
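+
+# After installing the `reinforce` package (see setup.py above), these tests can
+# be run directly with `python examples/policy_gradient/tests/test.py`.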