From cd9dc398ff2110b84573bd1f28a332b91ed73a8e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 23 Oct 2017 23:16:52 -0700 Subject: [PATCH] [rllib] Support discrete observation spaces such as FrozenLake-v0 (#1140) * add * remove transform_shape * fix test * fix --- python/ray/rllib/dqn/common/wrappers.py | 3 +- python/ray/rllib/dqn/dqn.py | 4 +- python/ray/rllib/es/es.py | 14 ++---- python/ray/rllib/es/policies.py | 12 ++--- python/ray/rllib/models/catalog.py | 38 ++++++++------ python/ray/rllib/models/preprocessors.py | 58 +++++++++++++--------- python/ray/rllib/ppo/env.py | 3 +- python/ray/rllib/ppo/runner.py | 7 +-- python/ray/rllib/test/test_catalog.py | 12 ++++- test/jenkins_tests/run_multi_node_tests.sh | 13 +++++ 10 files changed, 98 insertions(+), 66 deletions(-) diff --git a/python/ray/rllib/dqn/common/wrappers.py b/python/ray/rllib/dqn/common/wrappers.py index e7da2a7c5..2f96600b1 100644 --- a/python/ray/rllib/dqn/common/wrappers.py +++ b/python/ray/rllib/dqn/common/wrappers.py @@ -216,7 +216,8 @@ class FrameStack(gym.Wrapper): def wrap_dqn(env, options): """Apply a common set of wrappers for DQN.""" - is_atari = (env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE) + is_atari = (hasattr(env.observation_space, "shape") and + env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE) if is_atari: env = EpisodicLifeEnv(env) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index b05883231..fb97fe742 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -198,8 +198,8 @@ class Actor(object): self.dqn_graph.apply_gradients(self.sess, grad) def stats(self, num_timesteps): - mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1) - mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1) + mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5) + mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5) exploration = self.exploration.value(num_timesteps) return ( mean_100ep_reward, diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index d93a3d942..a25f89f7a 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -74,10 +74,7 @@ class Worker(object): self.noise = SharedNoiseTable(noise) self.env = env_creator() - self.preprocessor = ModelCatalog.get_preprocessor( - self.env.spec.id, self.env.observation_space.shape) - self.preprocessor_shape = self.preprocessor.transform_shape( - self.env.observation_space.shape) + self.preprocessor = ModelCatalog.get_preprocessor(self.env) self.sess = utils.make_session(single_threaded=True) self.policy = policies.GenericPolicy( @@ -118,7 +115,7 @@ class Worker(object): noise_inds, returns, sign_returns, lengths = [], [], [], [] # We set eps=0 because we're incrementing only. - task_ob_stat = utils.RunningStat(self.preprocessor_shape, eps=0) + task_ob_stat = utils.RunningStat(self.preprocessor.shape, eps=0) # Perform some rollouts with noise. 
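         # `task_ob_stat` above is sized by `self.preprocessor.shape`: each
         # preprocessor now derives its output shape from its observation
         # space at construction (e.g. (16,) for FrozenLake-v0's Discrete(16)).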
task_tstart = time.time() @@ -169,10 +166,7 @@ class ESAgent(Agent): } env = self.env_creator() - preprocessor = ModelCatalog.get_preprocessor( - env.spec.id, env.observation_space.shape) - preprocessor_shape = preprocessor.transform_shape( - env.observation_space.shape) + preprocessor = ModelCatalog.get_preprocessor(env) self.sess = utils.make_session(single_threaded=False) self.policy = policies.GenericPolicy( @@ -180,7 +174,7 @@ class ESAgent(Agent): **policy_params) tf_util.initialize() self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"]) - self.ob_stat = utils.RunningStat(preprocessor_shape, eps=1e-2) + self.ob_stat = utils.RunningStat(preprocessor.shape, eps=1e-2) # Create the shared noise table. print("Creating shared noise table.") diff --git a/python/ray/rllib/es/policies.py b/python/ray/rllib/es/policies.py index c152328de..4d445937c 100644 --- a/python/ray/rllib/es/policies.py +++ b/python/ray/rllib/es/policies.py @@ -144,25 +144,25 @@ class GenericPolicy(Policy): def _initialize(self, ob_space, ac_space, preprocessor, ac_noise_std): self.ac_space = ac_space self.ac_noise_std = ac_noise_std - self.preprocessor_shape = preprocessor.transform_shape(ob_space.shape) + self.preprocessor = preprocessor with tf.variable_scope(type(self).__name__) as scope: # Observation normalization. ob_mean = tf.get_variable( - 'ob_mean', self.preprocessor_shape, tf.float32, + 'ob_mean', self.preprocessor.shape, tf.float32, tf.constant_initializer(np.nan), trainable=False) ob_std = tf.get_variable( - 'ob_std', self.preprocessor_shape, tf.float32, + 'ob_std', self.preprocessor.shape, tf.float32, tf.constant_initializer(np.nan), trainable=False) - in_mean = tf.placeholder(tf.float32, self.preprocessor_shape) - in_std = tf.placeholder(tf.float32, self.preprocessor_shape) + in_mean = tf.placeholder(tf.float32, self.preprocessor.shape) + in_std = tf.placeholder(tf.float32, self.preprocessor.shape) self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[ tf.assign(ob_mean, in_mean), tf.assign(ob_std, in_std), ]) inputs = tf.placeholder( - tf.float32, [None] + list(self.preprocessor_shape)) + tf.float32, [None] + list(self.preprocessor.shape)) # TODO(ekl): we should do clipping in a standard RLlib preprocessor clipped_inputs = tf.clip_by_value( diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index ca368e68e..ecbea794a 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -7,7 +7,8 @@ import gym from ray.rllib.models.action_dist import ( Categorical, Deterministic, DiagGaussian) from ray.rllib.models.preprocessors import ( - NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor) + NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor, + OneHotPreprocessor) from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork @@ -84,18 +85,25 @@ class ModelCatalog(object): return FullyConnectedNetwork(inputs, num_outputs, options) @classmethod - def get_preprocessor(cls, env_name, obs_shape, options=dict()): + def get_preprocessor(cls, env, options=dict()): """Returns a suitable processor for the given environment. Args: - env_name (str): The name of the environment. - obs_shape (tuple): The shape of the env observation space. + env (gym.Env): The gym environment to preprocess. options (dict): Options to pass to the preprocessor. Returns: preprocessor (Preprocessor): Preprocessor for the env observations. 
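+
+        Example (illustrative):
+            >>> env = gym.make("FrozenLake-v0")  # Discrete(16) observations
+            >>> ModelCatalog.get_preprocessor(env).shape
+            (16,)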
""" + # For older gym versions that don't set shape for Discrete + if not hasattr(env.observation_space, "shape") and \ + isinstance(env.observation_space, gym.spaces.Discrete): + env.observation_space.shape = () + + env_name = env.spec.id + obs_shape = env.observation_space.shape + for k in options.keys(): if k not in MODEL_CONFIGS: raise Exception( @@ -107,15 +115,20 @@ class ModelCatalog(object): if env_name in cls._registered_preprocessor: return cls._registered_preprocessor[env_name](options) - if obs_shape == cls.ATARI_OBS_SHAPE: + if obs_shape == (): + print("Using one-hot preprocessor for discrete envs.") + preprocessor = OneHotPreprocessor + elif obs_shape == cls.ATARI_OBS_SHAPE: print("Assuming Atari pixel env, using AtariPixelPreprocessor.") - return AtariPixelPreprocessor(options) + preprocessor = AtariPixelPreprocessor elif obs_shape == cls.ATARI_RAM_OBS_SHAPE: print("Assuming Atari ram env, using AtariRamPreprocessor.") - return AtariRamPreprocessor(options) + preprocessor = AtariRamPreprocessor + else: + print("Non-atari env, not using any observation preprocessor.") + preprocessor = NoPreprocessor - print("Non-atari env, not using any observation preprocessor.") - return NoPreprocessor(options) + return preprocessor(env.observation_space, options) @classmethod def get_preprocessor_as_wrapper(cls, env, options=dict()): @@ -129,8 +142,7 @@ class ModelCatalog(object): wrapper (gym.ObservationWrapper): Preprocessor in wrapper form. """ - preprocessor = cls.get_preprocessor( - env.spec.id, env.observation_space.shape, options) + preprocessor = cls.get_preprocessor(env, options) return _RLlibPreprocessorWrapper(env, preprocessor) @classmethod @@ -154,9 +166,7 @@ class _RLlibPreprocessorWrapper(gym.ObservationWrapper): self.preprocessor = preprocessor from gym.spaces.box import Box - self.observation_space = Box( - -1.0, 1.0, - preprocessor.transform_shape(env.observation_space.shape)) + self.observation_space = Box(-1.0, 1.0, preprocessor.shape) def _observation(self, observation): return self.preprocessor.transform(observation) diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index 740358c0b..97ed9e5cd 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -6,19 +6,20 @@ import numpy as np class Preprocessor(object): - """Defines an abstract observation preprocessor function.""" + """Defines an abstract observation preprocessor function. - def __init__(self, options): - self.options = options + Attributes: + shape (obj): Shape of the preprocessed output. 
+ """ + + def __init__(self, obs_space, options): + self._obs_space = obs_space + self._options = options self._init() def _init(self): pass - def transform_shape(self, obs_shape): - """Returns the preprocessed observation shape.""" - raise NotImplementedError - def transform(self, observation): """Returns the preprocessed observation.""" raise NotImplementedError @@ -26,32 +27,30 @@ class Preprocessor(object): class AtariPixelPreprocessor(Preprocessor): def _init(self): - self.grayscale = self.options.get("grayscale", False) - self.zero_mean = self.options.get("zero_mean", True) - self.dim = self.options.get("dim", 80) - - def transform_shape(self, obs_shape): - if self.grayscale: - return (self.dim, self.dim, 1) + self._grayscale = self._options.get("grayscale", False) + self._zero_mean = self._options.get("zero_mean", True) + self._dim = self._options.get("dim", 80) + if self._grayscale: + self.shape = (self._dim, self._dim, 1) else: - return (self.dim, self.dim, 3) + self.shape = (self._dim, self._dim, 3) # TODO(ekl) why does this need to return an extra size-1 dim (the [None]) def transform(self, observation): """Downsamples images from (210, 160, 3) by the configured factor.""" scaled = observation[25:-25, :, :] - if self.dim < 80: + if self._dim < 80: scaled = cv2.resize(scaled, (80, 80)) # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping). # If we resize directly we lose pixels that, when mapped to 42x42, # aren't close enough to the pixel boundary. - scaled = cv2.resize(scaled, (self.dim, self.dim)) - if self.grayscale: + scaled = cv2.resize(scaled, (self._dim, self._dim)) + if self._grayscale: scaled = scaled.mean(2) scaled = scaled.astype(np.float32) # Rescale needed for maintaining 1 channel - scaled = np.reshape(scaled, [self.dim, self.dim, 1]) - if self.zero_mean: + scaled = np.reshape(scaled, [self._dim, self._dim, 1]) + if self._zero_mean: scaled = (scaled - 128) / 128 else: scaled *= 1.0 / 255.0 @@ -60,16 +59,27 @@ class AtariPixelPreprocessor(Preprocessor): # TODO(rliaw): Also should include the deepmind preprocessor class AtariRamPreprocessor(Preprocessor): - def transform_shape(self, obs_shape): - return (128,) + def _init(self): + self.shape = (128,) def transform(self, observation): return (observation - 128) / 128 +class OneHotPreprocessor(Preprocessor): + def _init(self): + assert self._obs_space.shape == () + self.shape = (self._obs_space.n,) + + def transform(self, observation): + arr = np.zeros(self._obs_space.n) + arr[observation] = 1 + return arr + + class NoPreprocessor(Preprocessor): - def transform_shape(self, obs_shape): - return obs_shape + def _init(self): + self.shape = self._obs_space.shape def transform(self, observation): return observation diff --git a/python/ray/rllib/ppo/env.py b/python/ray/rllib/ppo/env.py index fefd4dadb..7638f34c8 100644 --- a/python/ray/rllib/ppo/env.py +++ b/python/ray/rllib/ppo/env.py @@ -15,8 +15,7 @@ class BatchedEnv(object): self.action_space = self.envs[0].action_space self.batchsize = batchsize self.preprocessor = ModelCatalog.get_preprocessor( - self.envs[0].spec.id, self.envs[0].observation_space.shape, - options["model"]) + self.envs[0], options["model"]) self.extra_frameskip = options.get("extra_frameskip", 1) assert self.extra_frameskip >= 1 diff --git a/python/ray/rllib/ppo/runner.py b/python/ray/rllib/ppo/runner.py index 22ad93d85..c63d3ba91 100644 --- a/python/ray/rllib/ppo/runner.py +++ b/python/ray/rllib/ppo/runner.py @@ -63,12 +63,9 @@ class Runner(object): self.kl_coeff = tf.placeholder( 
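             # Scalar placeholder for PPO's adaptive KL penalty coefficient.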
name="newkl", shape=(), dtype=tf.float32) - # The shape of the preprocessed observations. - self.preprocessor_shape = self.preprocessor.transform_shape( - self.env.observation_space.shape) # The input observations. self.observations = tf.placeholder( - tf.float32, shape=(None,) + self.preprocessor_shape) + tf.float32, shape=(None,) + self.preprocessor.shape) # Targets of the value function. self.returns = tf.placeholder(tf.float32, shape=(None,)) # Advantage values in the policy gradient estimator. @@ -142,7 +139,7 @@ class Runner(object): self.common_policy.loss, self.sess) if config["observation_filter"] == "MeanStdFilter": self.observation_filter = MeanStdFilter( - self.preprocessor_shape, clip=None) + self.preprocessor.shape, clip=None) elif config["observation_filter"] == "NoFilter": self.observation_filter = NoFilter() else: diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index e229e2616..671e88f72 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -3,12 +3,20 @@ from ray.rllib.models.preprocessors import Preprocessor class FakePreprocessor(Preprocessor): - def __init__(self, options): pass +class FakeEnv(object): + def __init__(self): + self.observation_space = lambda: None + self.observation_space.shape = () + self.spec = lambda: None + self.spec.id = "FakeEnv-v0" + + def test_preprocessor(): ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor) - preprocessor = ModelCatalog.get_preprocessor("FakeEnv-v0", (1, 1)) + env = FakeEnv() + preprocessor = ModelCatalog.get_preprocessor(env) assert type(preprocessor) == FakePreprocessor diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 2b082ba88..909898f27 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -98,6 +98,19 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}' +docker run --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env FrozenLake-v0 \ + --alg DQN \ + --stop '{"training_iteration": 2}' + +docker run --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env FrozenLake-v0 \ + --alg PPO \ + --stop '{"training_iteration": 2}' \ + --config '{"num_sgd_iter": 10, "sgd_batchsize": 64, "timesteps_per_batch": 1000, "num_workers": 1}' + docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \