[rllib] Support discrete observation spaces such as FrozenLake-v0 (#1140)

* add

* remove transform_shape

* fix test

* fix
This commit is contained in:
Eric Liang 2017-10-23 23:16:52 -07:00 committed by Richard Liaw
parent 0c9817fa76
commit cd9dc398ff
10 changed files with 98 additions and 66 deletions

View file

@ -216,7 +216,8 @@ class FrameStack(gym.Wrapper):
def wrap_dqn(env, options):
"""Apply a common set of wrappers for DQN."""
is_atari = (env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
is_atari = (hasattr(env.observation_space, "shape") and
env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
if is_atari:
env = EpisodicLifeEnv(env)

View file

@ -198,8 +198,8 @@ class Actor(object):
self.dqn_graph.apply_gradients(self.sess, grad)
def stats(self, num_timesteps):
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5)
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5)
exploration = self.exploration.value(num_timesteps)
return (
mean_100ep_reward,

View file

@ -74,10 +74,7 @@ class Worker(object):
self.noise = SharedNoiseTable(noise)
self.env = env_creator()
self.preprocessor = ModelCatalog.get_preprocessor(
self.env.spec.id, self.env.observation_space.shape)
self.preprocessor_shape = self.preprocessor.transform_shape(
self.env.observation_space.shape)
self.preprocessor = ModelCatalog.get_preprocessor(self.env)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
@ -118,7 +115,7 @@ class Worker(object):
noise_inds, returns, sign_returns, lengths = [], [], [], []
# We set eps=0 because we're incrementing only.
task_ob_stat = utils.RunningStat(self.preprocessor_shape, eps=0)
task_ob_stat = utils.RunningStat(self.preprocessor.shape, eps=0)
# Perform some rollouts with noise.
task_tstart = time.time()
@ -169,10 +166,7 @@ class ESAgent(Agent):
}
env = self.env_creator()
preprocessor = ModelCatalog.get_preprocessor(
env.spec.id, env.observation_space.shape)
preprocessor_shape = preprocessor.transform_shape(
env.observation_space.shape)
preprocessor = ModelCatalog.get_preprocessor(env)
self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
@ -180,7 +174,7 @@ class ESAgent(Agent):
**policy_params)
tf_util.initialize()
self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
self.ob_stat = utils.RunningStat(preprocessor_shape, eps=1e-2)
self.ob_stat = utils.RunningStat(preprocessor.shape, eps=1e-2)
# Create the shared noise table.
print("Creating shared noise table.")

View file

@ -144,25 +144,25 @@ class GenericPolicy(Policy):
def _initialize(self, ob_space, ac_space, preprocessor, ac_noise_std):
self.ac_space = ac_space
self.ac_noise_std = ac_noise_std
self.preprocessor_shape = preprocessor.transform_shape(ob_space.shape)
self.preprocessor = preprocessor
with tf.variable_scope(type(self).__name__) as scope:
# Observation normalization.
ob_mean = tf.get_variable(
'ob_mean', self.preprocessor_shape, tf.float32,
'ob_mean', self.preprocessor.shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
ob_std = tf.get_variable(
'ob_std', self.preprocessor_shape, tf.float32,
'ob_std', self.preprocessor.shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
in_mean = tf.placeholder(tf.float32, self.preprocessor_shape)
in_std = tf.placeholder(tf.float32, self.preprocessor_shape)
in_mean = tf.placeholder(tf.float32, self.preprocessor.shape)
in_std = tf.placeholder(tf.float32, self.preprocessor.shape)
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
tf.assign(ob_mean, in_mean),
tf.assign(ob_std, in_std),
])
inputs = tf.placeholder(
tf.float32, [None] + list(self.preprocessor_shape))
tf.float32, [None] + list(self.preprocessor.shape))
# TODO(ekl): we should do clipping in a standard RLlib preprocessor
clipped_inputs = tf.clip_by_value(

View file

@ -7,7 +7,8 @@ import gym
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian)
from ray.rllib.models.preprocessors import (
NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor)
NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor,
OneHotPreprocessor)
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
@ -84,18 +85,25 @@ class ModelCatalog(object):
return FullyConnectedNetwork(inputs, num_outputs, options)
@classmethod
def get_preprocessor(cls, env_name, obs_shape, options=dict()):
def get_preprocessor(cls, env, options=dict()):
"""Returns a suitable processor for the given environment.
Args:
env_name (str): The name of the environment.
obs_shape (tuple): The shape of the env observation space.
env (gym.Env): The gym environment to preprocess.
options (dict): Options to pass to the preprocessor.
Returns:
preprocessor (Preprocessor): Preprocessor for the env observations.
"""
# For older gym versions that don't set shape for Discrete
if not hasattr(env.observation_space, "shape") and \
isinstance(env.observation_space, gym.spaces.Discrete):
env.observation_space.shape = ()
env_name = env.spec.id
obs_shape = env.observation_space.shape
for k in options.keys():
if k not in MODEL_CONFIGS:
raise Exception(
@ -107,15 +115,20 @@ class ModelCatalog(object):
if env_name in cls._registered_preprocessor:
return cls._registered_preprocessor[env_name](options)
if obs_shape == cls.ATARI_OBS_SHAPE:
if obs_shape == ():
print("Using one-hot preprocessor for discrete envs.")
preprocessor = OneHotPreprocessor
elif obs_shape == cls.ATARI_OBS_SHAPE:
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
return AtariPixelPreprocessor(options)
preprocessor = AtariPixelPreprocessor
elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
print("Assuming Atari ram env, using AtariRamPreprocessor.")
return AtariRamPreprocessor(options)
preprocessor = AtariRamPreprocessor
else:
print("Non-atari env, not using any observation preprocessor.")
preprocessor = NoPreprocessor
print("Non-atari env, not using any observation preprocessor.")
return NoPreprocessor(options)
return preprocessor(env.observation_space, options)
@classmethod
def get_preprocessor_as_wrapper(cls, env, options=dict()):
@ -129,8 +142,7 @@ class ModelCatalog(object):
wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
"""
preprocessor = cls.get_preprocessor(
env.spec.id, env.observation_space.shape, options)
preprocessor = cls.get_preprocessor(env, options)
return _RLlibPreprocessorWrapper(env, preprocessor)
@classmethod
@ -154,9 +166,7 @@ class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
self.preprocessor = preprocessor
from gym.spaces.box import Box
self.observation_space = Box(
-1.0, 1.0,
preprocessor.transform_shape(env.observation_space.shape))
self.observation_space = Box(-1.0, 1.0, preprocessor.shape)
def _observation(self, observation):
return self.preprocessor.transform(observation)

View file

@ -6,19 +6,20 @@ import numpy as np
class Preprocessor(object):
"""Defines an abstract observation preprocessor function."""
"""Defines an abstract observation preprocessor function.
def __init__(self, options):
self.options = options
Attributes:
shape (obj): Shape of the preprocessed output.
"""
def __init__(self, obs_space, options):
self._obs_space = obs_space
self._options = options
self._init()
def _init(self):
pass
def transform_shape(self, obs_shape):
"""Returns the preprocessed observation shape."""
raise NotImplementedError
def transform(self, observation):
"""Returns the preprocessed observation."""
raise NotImplementedError
@ -26,32 +27,30 @@ class Preprocessor(object):
class AtariPixelPreprocessor(Preprocessor):
def _init(self):
self.grayscale = self.options.get("grayscale", False)
self.zero_mean = self.options.get("zero_mean", True)
self.dim = self.options.get("dim", 80)
def transform_shape(self, obs_shape):
if self.grayscale:
return (self.dim, self.dim, 1)
self._grayscale = self._options.get("grayscale", False)
self._zero_mean = self._options.get("zero_mean", True)
self._dim = self._options.get("dim", 80)
if self._grayscale:
self.shape = (self._dim, self._dim, 1)
else:
return (self.dim, self.dim, 3)
self.shape = (self._dim, self._dim, 3)
# TODO(ekl) why does this need to return an extra size-1 dim (the [None])
def transform(self, observation):
"""Downsamples images from (210, 160, 3) by the configured factor."""
scaled = observation[25:-25, :, :]
if self.dim < 80:
if self._dim < 80:
scaled = cv2.resize(scaled, (80, 80))
# OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
# If we resize directly we lose pixels that, when mapped to 42x42,
# aren't close enough to the pixel boundary.
scaled = cv2.resize(scaled, (self.dim, self.dim))
if self.grayscale:
scaled = cv2.resize(scaled, (self._dim, self._dim))
if self._grayscale:
scaled = scaled.mean(2)
scaled = scaled.astype(np.float32)
# Rescale needed for maintaining 1 channel
scaled = np.reshape(scaled, [self.dim, self.dim, 1])
if self.zero_mean:
scaled = np.reshape(scaled, [self._dim, self._dim, 1])
if self._zero_mean:
scaled = (scaled - 128) / 128
else:
scaled *= 1.0 / 255.0
@ -60,16 +59,27 @@ class AtariPixelPreprocessor(Preprocessor):
# TODO(rliaw): Also should include the deepmind preprocessor
class AtariRamPreprocessor(Preprocessor):
def transform_shape(self, obs_shape):
return (128,)
def _init(self):
self.shape = (128,)
def transform(self, observation):
return (observation - 128) / 128
class OneHotPreprocessor(Preprocessor):
def _init(self):
assert self._obs_space.shape == ()
self.shape = (self._obs_space.n,)
def transform(self, observation):
arr = np.zeros(self._obs_space.n)
arr[observation] = 1
return arr
class NoPreprocessor(Preprocessor):
def transform_shape(self, obs_shape):
return obs_shape
def _init(self):
self.shape = self._obs_space.shape
def transform(self, observation):
return observation

View file

@ -15,8 +15,7 @@ class BatchedEnv(object):
self.action_space = self.envs[0].action_space
self.batchsize = batchsize
self.preprocessor = ModelCatalog.get_preprocessor(
self.envs[0].spec.id, self.envs[0].observation_space.shape,
options["model"])
self.envs[0], options["model"])
self.extra_frameskip = options.get("extra_frameskip", 1)
assert self.extra_frameskip >= 1

View file

@ -63,12 +63,9 @@ class Runner(object):
self.kl_coeff = tf.placeholder(
name="newkl", shape=(), dtype=tf.float32)
# The shape of the preprocessed observations.
self.preprocessor_shape = self.preprocessor.transform_shape(
self.env.observation_space.shape)
# The input observations.
self.observations = tf.placeholder(
tf.float32, shape=(None,) + self.preprocessor_shape)
tf.float32, shape=(None,) + self.preprocessor.shape)
# Targets of the value function.
self.returns = tf.placeholder(tf.float32, shape=(None,))
# Advantage values in the policy gradient estimator.
@ -142,7 +139,7 @@ class Runner(object):
self.common_policy.loss, self.sess)
if config["observation_filter"] == "MeanStdFilter":
self.observation_filter = MeanStdFilter(
self.preprocessor_shape, clip=None)
self.preprocessor.shape, clip=None)
elif config["observation_filter"] == "NoFilter":
self.observation_filter = NoFilter()
else:

View file

@ -3,12 +3,20 @@ from ray.rllib.models.preprocessors import Preprocessor
class FakePreprocessor(Preprocessor):
def __init__(self, options):
pass
class FakeEnv(object):
def __init__(self):
self.observation_space = lambda: None
self.observation_space.shape = ()
self.spec = lambda: None
self.spec.id = "FakeEnv-v0"
def test_preprocessor():
ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor)
preprocessor = ModelCatalog.get_preprocessor("FakeEnv-v0", (1, 1))
env = FakeEnv()
preprocessor = ModelCatalog.get_preprocessor(env)
assert type(preprocessor) == FakePreprocessor

View file

@ -98,6 +98,19 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
--stop '{"training_iteration": 2}' \
--config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env FrozenLake-v0 \
--alg DQN \
--stop '{"training_iteration": 2}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env FrozenLake-v0 \
--alg PPO \
--stop '{"training_iteration": 2}' \
--config '{"num_sgd_iter": 10, "sgd_batchsize": 64, "timesteps_per_batch": 1000, "num_workers": 1}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v4 \