diff --git a/python/ray/rllib/a3c/envs.py b/python/ray/rllib/a3c/envs.py
index a8790ea48..f534087b0 100644
--- a/python/ray/rllib/a3c/envs.py
+++ b/python/ray/rllib/a3c/envs.py
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function
 
 import gym
-from gym.spaces.box import Box
 import logging
 import time
 
@@ -15,24 +14,11 @@ logger.setLevel(logging.INFO)
 
 def create_and_wrap(env_creator, options):
     env = env_creator()
-    env = RLLibPreprocessing(env.spec.id, env, options)
+    env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
     env = Diagnostic(env)
     return env
 
 
-class RLLibPreprocessing(gym.ObservationWrapper):
-    def __init__(self, env_id, env=None, options=dict()):
-        super(RLLibPreprocessing, self).__init__(env)
-        self.preprocessor = ModelCatalog.get_preprocessor(
-            env_id, env.observation_space.shape, options)
-        self._process_shape = self.preprocessor.transform_shape(
-            env.observation_space.shape)
-        self.observation_space = Box(-1.0, 1.0, self._process_shape)
-
-    def _observation(self, observation):
-        return self.preprocessor.transform(observation)
-
-
 class Diagnostic(gym.Wrapper):
     def __init__(self, env=None):
         super(Diagnostic, self).__init__(env)
diff --git a/python/ray/rllib/dqn/common/atari_wrappers_deprecated.py b/python/ray/rllib/dqn/common/wrappers.py
similarity index 83%
rename from python/ray/rllib/dqn/common/atari_wrappers_deprecated.py
rename to python/ray/rllib/dqn/common/wrappers.py
index 37d4125d4..e7da2a7c5 100644
--- a/python/ray/rllib/dqn/common/atari_wrappers_deprecated.py
+++ b/python/ray/rllib/dqn/common/wrappers.py
@@ -9,6 +9,8 @@ import numpy as np
 from collections import deque
 from gym import spaces
 
+from ray.rllib.models import ModelCatalog
+
 
 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env=None, noop_max=30):
@@ -186,7 +188,7 @@ class FrameStack(gym.Wrapper):
 
         See Also
         --------
-        ray.rllib.dqn.common.atari_wrappers.LazyFrames
+        LazyFrames
         """
         gym.Wrapper.__init__(self, env)
         self.k = k
@@ -211,41 +213,23 @@ class FrameStack(gym.Wrapper):
         return LazyFrames(list(self.frames))
 
 
-class ScaledFloatFrame(gym.ObservationWrapper):
-    def _observation(self, obs):
-        # careful! This undoes the memory optimization, use
-        # with smaller replay buffers only.
-        return np.array(obs).astype(np.float32) / 255.0
+def wrap_dqn(env, options):
+    """Apply a common set of wrappers for DQN."""
+    is_atari = (env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
+
+    if is_atari:
+        env = EpisodicLifeEnv(env)
+        env = NoopResetEnv(env, noop_max=30)
+        if 'NoFrameskip' in env.spec.id:
+            env = MaxAndSkipEnv(env, skip=4)
+        if 'FIRE' in env.unwrapped.get_action_meanings():
+            env = FireResetEnv(env)
+
+    env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
+
+    if is_atari:
+        env = FrameStack(env, 4)
+        env = ClippedRewardsWrapper(env)
 
 
-def wrap_dqn(env):
-    """Apply a common set of wrappers for Atari games."""
-    assert 'NoFrameskip' in env.spec.id
-    env = EpisodicLifeEnv(env)
-    env = NoopResetEnv(env, noop_max=30)
-    env = MaxAndSkipEnv(env, skip=4)
-    if 'FIRE' in env.unwrapped.get_action_meanings():
-        env = FireResetEnv(env)
-    env = ProcessFrame80(env)
-    env = FrameStack(env, 4)
-    env = ClippedRewardsWrapper(env)
     return env
-
-
-class A2cProcessFrame(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.observation_space = spaces.Box(low=0, high=255, shape=(80, 80, 1))
-
-    def _step(self, action):
-        ob, reward, done, info = self.env.step(action)
-        return A2cProcessFrame.process(ob), reward, done, info
-
-    def _reset(self):
-        return A2cProcessFrame.process(self.env.reset())
-
-    @staticmethod
-    def process(frame):
-        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
-        frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
-        return frame.reshape(80, 80, 1)
diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py
index a60d4d769..76697e444 100644
--- a/python/ray/rllib/dqn/dqn.py
+++ b/python/ray/rllib/dqn/dqn.py
@@ -12,8 +12,7 @@ import tensorflow as tf
 import ray
 from ray.rllib.common import Agent, TrainingResult
 from ray.rllib.dqn import logger, models
-from ray.rllib.dqn.common.atari_wrappers_deprecated \
-    import wrap_dqn, ScaledFloatFrame
+from ray.rllib.dqn.common.wrappers import wrap_dqn
 from ray.rllib.dqn.common.schedules import LinearSchedule
 from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 
@@ -104,9 +103,7 @@ DEFAULT_CONFIG = dict(
 class Actor(object):
     def __init__(self, env_creator, config, logdir):
         env = env_creator()
-        # TODO(ekl): replace this with RLlib preprocessors
-        if "NoFrameskip" in env.spec.id:
-            env = ScaledFloatFrame(wrap_dqn(env))
+        env = wrap_dqn(env, config["model"])
         self.env = env
         self.config = config
 
diff --git a/python/ray/rllib/dqn/models.py b/python/ray/rllib/dqn/models.py
index f2bc94ff7..28fe422b9 100644
--- a/python/ray/rllib/dqn/models.py
+++ b/python/ray/rllib/dqn/models.py
@@ -152,7 +152,7 @@ class DQNGraph(object):
             with tf.variable_scope("q_func", reuse=True):
                 q_tp1_using_online_net = _build_q_network(
                     self.obs_tp1, num_actions, config)
-            q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
+            q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
             q_tp1_best = tf.reduce_sum(
                 self.q_tp1 * tf.one_hot(
                     q_tp1_best_using_online_net, num_actions), 1)
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index 99e48d67d..ca368e68e 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -34,6 +34,9 @@ class ModelCatalog(object):
         action_op = dist.sample()
     """
 
+    ATARI_OBS_SHAPE = (210, 160, 3)
+    ATARI_RAM_OBS_SHAPE = (128,)
+
     _registered_preprocessor = dict()
 
     @staticmethod
@@ -87,33 +90,49 @@ class ModelCatalog(object):
         Args:
            env_name (str): The name of the environment.
            obs_shape (tuple): The shape of the env observation space.
+           options (dict): Options to pass to the preprocessor.
 
         Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.
        """
 
-        ATARI_OBS_SHAPE = (210, 160, 3)
-        ATARI_RAM_OBS_SHAPE = (128,)
-
         for k in options.keys():
             if k not in MODEL_CONFIGS:
                 raise Exception(
                     "Unknown config key `{}`, all keys: {}".format(
                         k, MODEL_CONFIGS))
 
+        print("Observation shape is {}".format(obs_shape))
+
         if env_name in cls._registered_preprocessor:
             return cls._registered_preprocessor[env_name](options)
 
-        if obs_shape == ATARI_OBS_SHAPE:
+        if obs_shape == cls.ATARI_OBS_SHAPE:
             print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
             return AtariPixelPreprocessor(options)
-        elif obs_shape == ATARI_RAM_OBS_SHAPE:
+        elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
             print("Assuming Atari ram env, using AtariRamPreprocessor.")
             return AtariRamPreprocessor(options)
 
         print("Non-atari env, not using any observation preprocessor.")
         return NoPreprocessor(options)
 
+    @classmethod
+    def get_preprocessor_as_wrapper(cls, env, options=dict()):
+        """Returns a preprocessor as a gym observation wrapper.
+
+        Args:
+            env (gym.Env): The gym environment to wrap.
+            options (dict): Options to pass to the preprocessor.
+
+        Returns:
+            wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
+        """
+
+        preprocessor = cls.get_preprocessor(
+            env.spec.id, env.observation_space.shape, options)
+        return _RLlibPreprocessorWrapper(env, preprocessor)
+
     @classmethod
     def register_preprocessor(cls, env_name, preprocessor_class):
         """Register a preprocessor class for a specific environment.
@@ -125,3 +144,19 @@ class ModelCatalog(object):
             Python class of the distribution.
         """
         cls._registered_preprocessor[env_name] = preprocessor_class
+
+
+class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
+    """Adapts a RLlib preprocessor for use as an observation wrapper."""
+
+    def __init__(self, env, preprocessor):
+        super(_RLlibPreprocessorWrapper, self).__init__(env)
+        self.preprocessor = preprocessor
+
+        from gym.spaces.box import Box
+        self.observation_space = Box(
+            -1.0, 1.0,
+            preprocessor.transform_shape(env.observation_space.shape))
+
+    def _observation(self, observation):
+        return self.preprocessor.transform(observation)
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh
index baf05f8df..08c679df9 100755
--- a/test/jenkins_tests/run_multi_node_tests.sh
+++ b/test/jenkins_tests/run_multi_node_tests.sh
@@ -100,7 +100,7 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
 
 docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
-    --env PongNoFrameskip-v4 \
+    --env PongDeterministic-v4 \
    --alg DQN \
    --stop '{"training_iteration": 2}' \
    --config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}'
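Usage note (illustrative sketch, not part of the diff): the new entry point is ModelCatalog.get_preprocessor_as_wrapper, which both the A3C create_and_wrap path and the DQN wrap_dqn path above now call. A minimal example of how an env creator might use it directly, assuming a Gym Atari install and an empty options dict (only keys listed in MODEL_CONFIGS are accepted):

    import gym

    from ray.rllib.models import ModelCatalog

    # Build the raw env, then wrap it so observations are preprocessed before
    # they reach the agent. For (210, 160, 3) observations the catalog picks
    # AtariPixelPreprocessor automatically.
    env = gym.make("PongDeterministic-v4")
    env = ModelCatalog.get_preprocessor_as_wrapper(env, options=dict())

    obs = env.reset()  # observation already transformed by the preprocessor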