[rllib] Make Pong-v0 + EvolutionStrategies work by sharing preprocessors with PPO (#848)

* fix by sharing preprocessors

* revert param change

* Update evolution_strategies.py

* Update catalog.py
Eric Liang 2017-08-22 03:51:49 +02:00 committed by Philipp Moritz
parent be4beb19c1
commit c81821b856
10 changed files with 105 additions and 84 deletions
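
The core of the change: instead of each algorithm hard-coding its own per-env preprocessing, both EvolutionStrategies and PolicyGradient (PPO) now fetch a shared Preprocessor from the ModelCatalog. A minimal sketch of the shared flow, using only the interfaces introduced in the diffs below (the env name is just an example):

    import gym
    from ray.rllib.models import ModelCatalog

    env = gym.make("Pong-v0")
    preprocessor = ModelCatalog.get_preprocessor("Pong-v0")

    # Shapes for placeholders and running statistics now come from the
    # preprocessor rather than from the raw Gym observation space.
    input_shape = preprocessor.transform_shape(env.observation_space.shape)

    # Every raw observation is transformed before it reaches the policy.
    ob = preprocessor.transform(env.reset())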

.gitignore (vendored): 2 changes
View file

@@ -1,9 +1,11 @@
# The build output should clearly not be checked in
/python/ray/core
/python/ray/pyarrow_files/pyarrow/
/python/build
/python/dist
/src/common/thirdparty/redis
/src/numbuf/thirdparty/arrow
/src/thirdparty/arrow
# Files generated by flatc should be ignored
/src/common/format/*.py

View file

@@ -13,6 +13,7 @@ import time
import ray
from ray.rllib.common import Algorithm, TrainingResult
from ray.rllib.models import ModelCatalog
from ray.rllib.evolution_strategies import optimizers
from ray.rllib.evolution_strategies import policies
@@ -72,9 +73,14 @@ class Worker(object):
self.noise = SharedNoiseTable(noise)
self.env = gym.make(env_name)
self.preprocessor = ModelCatalog.get_preprocessor(env_name)
self.preprocessor_shape = self.preprocessor.transform_shape(
self.env.observation_space.shape)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
self.env.observation_space, self.env.action_space, **policy_params)
self.env.observation_space, self.env.action_space,
self.preprocessor, **policy_params)
tf_util.initialize()
self.rs = np.random.RandomState()
@@ -88,13 +94,14 @@ class Worker(object):
self.config["calc_obstat_prob"] != 0 and
self.rs.rand() < self.config["calc_obstat_prob"]):
rollout_rews, rollout_len, obs = self.policy.rollout(
self.env, timestep_limit=timestep_limit, save_obs=True,
random_stream=self.rs)
self.env, self.preprocessor, timestep_limit=timestep_limit,
save_obs=True, random_stream=self.rs)
task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
len(obs))
else:
rollout_rews, rollout_len = self.policy.rollout(
self.env, timestep_limit=timestep_limit, random_stream=self.rs)
self.env, self.preprocessor, timestep_limit=timestep_limit,
random_stream=self.rs)
return rollout_rews, rollout_len
def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
@@ -109,8 +116,7 @@ class Worker(object):
noise_inds, returns, sign_returns, lengths = [], [], [], []
# We set eps=0 because we're incrementing only.
task_ob_stat = utils.RunningStat(
self.env.observation_space.shape, eps=0)
task_ob_stat = utils.RunningStat(self.preprocessor_shape, eps=0)
# Perform some rollouts with noise.
task_tstart = time.time()
@@ -161,12 +167,17 @@ class EvolutionStrategies(Algorithm):
}
env = gym.make(env_name)
preprocessor = ModelCatalog.get_preprocessor(env_name)
preprocessor_shape = preprocessor.transform_shape(
env.observation_space.shape)
utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
env.observation_space, env.action_space, **policy_params)
env.observation_space, env.action_space, preprocessor,
**policy_params)
tf_util.initialize()
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
self.ob_stat = utils.RunningStat(preprocessor_shape, eps=1e-2)
# Create the shared noise table.
print("Creating shared noise table.")

View file

@@ -84,8 +84,8 @@ class Policy:
# === Rollouts/training ===
def rollout(self, env, render=False, timestep_limit=None, save_obs=False,
random_stream=None):
def rollout(self, env, preprocessor, render=False, timestep_limit=None,
save_obs=False, random_stream=None):
"""Do a rollout.
If random_stream is provided, the rollout will take noisy actions with
@@ -99,12 +99,15 @@ class Policy:
t = 0
if save_obs:
obs = []
ob = env.reset()
# TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix
# this in the preprocessor instead
ob = preprocessor.transform(env.reset()).squeeze()
for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs:
obs.append(ob)
ob, rew, done, _ = env.step(ac)
ob = preprocessor.transform(ob).squeeze()
rews.append(rew)
t += 1
if render:
@@ -140,26 +143,28 @@ def bins(x, dim, num_bins, name):
class GenericPolicy(Policy):
def _initialize(self, ob_space, ac_space, ac_noise_std):
def _initialize(self, ob_space, ac_space, preprocessor, ac_noise_std):
self.ac_space = ac_space
self.ac_noise_std = ac_noise_std
self.preprocessor_shape = preprocessor.transform_shape(ob_space.shape)
with tf.variable_scope(type(self).__name__) as scope:
# Observation normalization.
ob_mean = tf.get_variable(
'ob_mean', ob_space.shape, tf.float32,
'ob_mean', self.preprocessor_shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
ob_std = tf.get_variable(
'ob_std', ob_space.shape, tf.float32,
'ob_std', self.preprocessor_shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
in_mean = tf.placeholder(tf.float32, ob_space.shape)
in_std = tf.placeholder(tf.float32, ob_space.shape)
in_mean = tf.placeholder(tf.float32, self.preprocessor_shape)
in_std = tf.placeholder(tf.float32, self.preprocessor_shape)
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
tf.assign(ob_mean, in_mean),
tf.assign(ob_std, in_std),
])
inputs = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
inputs = tf.placeholder(
tf.float32, [None] + list(self.preprocessor_shape))
# TODO(ekl): we should do clipping in a standard RLlib preprocessor
clipped_inputs = tf.clip_by_value(
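
Condensed sketch of the new rollout loop: every observation goes through preprocessor.transform() before it reaches the policy, and squeeze() drops the extra size-1 dimension noted in the TODO above. It assumes policy, env, and preprocessor objects constructed as in the diff, and omits timestep_limit, save_obs, and rendering:

    ob = preprocessor.transform(env.reset()).squeeze()
    total_reward, done = 0.0, False
    while not done:
        ac = policy.act(ob[None], random_stream=None)[0]  # batch of one observation
        raw_ob, rew, done, _ = env.step(ac)
        ob = preprocessor.transform(raw_ob).squeeze()
        total_reward += rew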

View file

@@ -6,6 +6,8 @@ import gym
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian)
from ray.rllib.models.preprocessors import (
NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor)
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
from ray.rllib.models.convnet import ConvolutionalNetwork
@@ -83,4 +85,17 @@ class ModelCatalog(object):
preprocessor (Preprocessor): Preprocessor for the env observations.
"""
raise NotImplementedError
if env_name == "Pong-v0":
return AtariPixelPreprocessor()
elif env_name == "Pong-ram-v3":
return AtariRamPreprocessor()
elif env_name == "CartPole-v0" or env_name == "CartPole-v1":
return NoPreprocessor()
elif env_name == "Hopper-v1":
return NoPreprocessor()
elif env_name == "Walker2d-v1":
return NoPreprocessor()
elif env_name == "Humanoid-v1" or env_name == "Pendulum-v0":
return NoPreprocessor()
else:
return AtariPixelPreprocessor()
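
This puts the env-name-to-preprocessor mapping in one place; unknown names fall through to the Atari pixel default. For illustration (the last env name is made up):

    from ray.rllib.models import ModelCatalog

    ModelCatalog.get_preprocessor("Pong-v0")       # AtariPixelPreprocessor
    ModelCatalog.get_preprocessor("Pong-ram-v3")   # AtariRamPreprocessor
    ModelCatalog.get_preprocessor("CartPole-v0")   # NoPreprocessor
    ModelCatalog.get_preprocessor("SomeOtherEnv")  # falls back to AtariPixelPreprocessor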

View file

@@ -1,14 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# TODO(ekl) implement common preprocessors
class Preprocessor(object):
def output_shape(self):
"""Returns the new output shape, or None if unchanged."""
raise NotImplementedError
def preprocess(self, observation):
"""Returns the preprocessed observation."""
raise NotImplementedError

View file

@@ -0,0 +1,41 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Preprocessor(object):
"""Defines an abstract observation preprocessor function."""
def transform_shape(self, obs_shape):
"""Returns the preprocessed observation shape."""
raise NotImplementedError
def transform(self, observation):
"""Returns the preprocessed observation."""
raise NotImplementedError
class AtariPixelPreprocessor(Preprocessor):
def transform_shape(self, obs_shape):
return (80, 80, 3)
# TODO(ekl) why does this need to return an extra size-1 dim (the [None])
def transform(self, observation):
"""Downsamples images from (210, 160, 3) to (80, 80, 3)."""
return (observation[25:-25:2, ::2, :][None] - 128) / 128
class AtariRamPreprocessor(Preprocessor):
def transform_shape(self, obs_shape):
return (128,)
def transform(self, observation):
return (observation - 128) / 128
class NoPreprocessor(Preprocessor):
def transform_shape(self, obs_shape):
return obs_shape
def transform(self, observation):
return observation
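
The base class only requires transform_shape() and transform(), so adding a new preprocessor is a matter of subclassing. A hypothetical example, not part of this commit (module path taken from the catalog import above):

    import numpy as np
    from ray.rllib.models.preprocessors import Preprocessor

    class FlattenPreprocessor(Preprocessor):
        """Hypothetical example: flattens any observation to a 1-D float vector."""

        def transform_shape(self, obs_shape):
            return (int(np.prod(obs_shape)),)

        def transform(self, observation):
            return np.asarray(observation, dtype=np.float32).ravel()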

View file

@@ -46,8 +46,6 @@ class Agent(object):
self.config = config
self.logdir = logdir
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
if preprocessor.shape is None:
preprocessor.shape = self.env.observation_space.shape
if is_remote:
config_proto = tf.ConfigProto()
else:
@@ -62,8 +60,11 @@ class Agent(object):
# Defines the training inputs.
self.kl_coeff = tf.placeholder(
name="newkl", shape=(), dtype=tf.float32)
self.preprocessor_shape = preprocessor.transform_shape(
self.env.observation_space.shape)
self.observations = tf.placeholder(
tf.float32, shape=(None,) + preprocessor.shape)
tf.float32, shape=(None,) + self.preprocessor_shape)
self.advantages = tf.placeholder(tf.float32, shape=(None,))
action_space = self.env.action_space
@@ -121,7 +122,8 @@ class Agent(object):
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss, self.sess)
self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None)
self.observation_filter = MeanStdFilter(
self.preprocessor_shape, clip=None)
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())

View file

@@ -6,31 +6,6 @@ import gym
import numpy as np
class AtariPixelPreprocessor(object):
def __init__(self):
self.shape = (80, 80, 3)
def __call__(self, observation):
"Convert images from (210, 160, 3) to (3, 80, 80) by downsampling."
return (observation[25:-25:2, ::2, :][None] - 128) / 128
class AtariRamPreprocessor(object):
def __init__(self):
self.shape = (128,)
def __call__(self, observation):
return (observation - 128) / 128
class NoPreprocessor(object):
def __init__(self):
self.shape = None
def __call__(self, observation):
return observation
class BatchedEnv(object):
"""This holds multiple gym envs and performs steps on all of them."""
def __init__(self, name, batchsize, preprocessor=None):
@@ -42,7 +17,8 @@ class BatchedEnv(object):
else lambda obs: obs[None])
def reset(self):
observations = [self.preprocessor(env.reset()) for env in self.envs]
observations = [
self.preprocessor.transform(env.reset()) for env in self.envs]
self.shape = observations[0].shape
self.dones = [False for _ in range(self.batchsize)]
return np.vstack(observations)
@@ -58,7 +34,7 @@ class BatchedEnv(object):
observation, reward, done, info = self.envs[i].step(action)
if render:
self.envs[0].render()
observations.append(self.preprocessor(observation))
observations.append(self.preprocessor.transform(observation))
rewards.append(reward)
self.dones[i] = done
return (np.vstack(observations), np.array(rewards, dtype="float32"),
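
BatchedEnv now receives a Preprocessor object and calls its transform() method instead of invoking the preprocessor directly. A rough usage sketch (the BatchedEnv import path is assumed from the policy_gradient imports below):

    from ray.rllib.models import ModelCatalog
    from ray.rllib.policy_gradient.env import BatchedEnv  # module path assumed

    prep = ModelCatalog.get_preprocessor("Pong-v0")
    batched = BatchedEnv("Pong-v0", 4, preprocessor=prep)
    obs = batched.reset()  # stacked, already-preprocessed observations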

View file

@@ -10,9 +10,8 @@ import tensorflow as tf
import ray
from ray.rllib.common import Algorithm, TrainingResult
from ray.rllib.models import ModelCatalog
from ray.rllib.policy_gradient.agent import Agent, RemoteAgent
from ray.rllib.policy_gradient.env import (
NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor)
from ray.rllib.policy_gradient.rollout import collect_samples
from ray.rllib.policy_gradient.utils import shuffle
@@ -75,23 +74,7 @@ class PolicyGradient(Algorithm):
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
# TODO(ekl): preprocessor should be associated with the env elsewhere
if self.env_name == "Pong-v0":
preprocessor = AtariPixelPreprocessor()
elif self.env_name == "Pong-ram-v3":
preprocessor = AtariRamPreprocessor()
elif self.env_name == "CartPole-v0" or self.env_name == "CartPole-v1":
preprocessor = NoPreprocessor()
elif self.env_name == "Hopper-v1":
preprocessor = NoPreprocessor()
elif self.env_name == "Walker2d-v1":
preprocessor = NoPreprocessor()
elif self.env_name == "Humanoid-v1":
preprocessor = NoPreprocessor()
else:
preprocessor = AtariPixelPreprocessor()
self.preprocessor = preprocessor
self.preprocessor = ModelCatalog.get_preprocessor(self.env_name)
self.global_step = 0
self.j = 0
self.kl_coeff = config["kl_coeff"]

View file

@@ -60,8 +60,8 @@ if __name__ == "__main__":
env_name, config, upload_dir=args.upload_dir)
else:
assert False, ("Unknown algorithm, check --alg argument. Valid "
"choices are PolicyGradientPolicyGradient, "
"EvolutionStrategies, DQN and A3C.")
"choices are PolicyGradient, EvolutionStrategies, "
"DQN and A3C.")
result_logger = ray.rllib.common.RLLibLogger(
os.path.join(alg.logdir, "result.json"))