[rllib] Use RLlib preprocessors in DQN (fixes PongDeterministic-v4) (#1124)
* fix pong
* rename
* update
parent d6062ef8f6
commit 802941994d
6 changed files with 65 additions and 63 deletions
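For context, the common thread of the hunks below is that both A3C and DQN now obtain observation preprocessing from ModelCatalog instead of per-algorithm wrapper code. A minimal usage sketch under that assumption (the env id and the empty options dict are illustrative, not taken from the diff):

import gym

from ray.rllib.models import ModelCatalog

# Wrap any Gym env with the preprocessor the catalog selects for it; for raw
# (210, 160, 3) Atari observations that is AtariPixelPreprocessor.
env = gym.make("PongDeterministic-v4")
env = ModelCatalog.get_preprocessor_as_wrapper(env, options=dict())

obs = env.reset()  # already preprocessed, no longer the raw (210, 160, 3) frame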
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function

 import gym
-from gym.spaces.box import Box
 import logging
 import time

@@ -15,24 +14,11 @@ logger.setLevel(logging.INFO)

 def create_and_wrap(env_creator, options):
     env = env_creator()
-    env = RLLibPreprocessing(env.spec.id, env, options)
+    env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
     env = Diagnostic(env)
     return env


-class RLLibPreprocessing(gym.ObservationWrapper):
-    def __init__(self, env_id, env=None, options=dict()):
-        super(RLLibPreprocessing, self).__init__(env)
-        self.preprocessor = ModelCatalog.get_preprocessor(
-            env_id, env.observation_space.shape, options)
-        self._process_shape = self.preprocessor.transform_shape(
-            env.observation_space.shape)
-        self.observation_space = Box(-1.0, 1.0, self._process_shape)
-
-    def _observation(self, observation):
-        return self.preprocessor.transform(observation)
-
-
 class Diagnostic(gym.Wrapper):
     def __init__(self, env=None):
         super(Diagnostic, self).__init__(env)
@@ -9,6 +9,8 @@ import numpy as np
 from collections import deque
 from gym import spaces

+from ray.rllib.models import ModelCatalog
+

 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env=None, noop_max=30):
@@ -186,7 +188,7 @@ class FrameStack(gym.Wrapper):

         See Also
         --------
-        ray.rllib.dqn.common.atari_wrappers.LazyFrames
+        LazyFrames
         """
         gym.Wrapper.__init__(self, env)
         self.k = k
@@ -211,41 +213,23 @@ class FrameStack(gym.Wrapper):
         return LazyFrames(list(self.frames))


-class ScaledFloatFrame(gym.ObservationWrapper):
-    def _observation(self, obs):
-        # careful! This undoes the memory optimization, use
-        # with smaller replay buffers only.
-        return np.array(obs).astype(np.float32) / 255.0
-
+def wrap_dqn(env, options):
+    """Apply a common set of wrappers for DQN."""
+
+    is_atari = (env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
+
+    if is_atari:
+        env = EpisodicLifeEnv(env)
+        env = NoopResetEnv(env, noop_max=30)
+        if 'NoFrameskip' in env.spec.id:
+            env = MaxAndSkipEnv(env, skip=4)
+        if 'FIRE' in env.unwrapped.get_action_meanings():
+            env = FireResetEnv(env)
+
+    env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
+
+    if is_atari:
+        env = FrameStack(env, 4)
+        env = ClippedRewardsWrapper(env)

-def wrap_dqn(env):
-    """Apply a common set of wrappers for Atari games."""
-    assert 'NoFrameskip' in env.spec.id
-    env = EpisodicLifeEnv(env)
-    env = NoopResetEnv(env, noop_max=30)
-    env = MaxAndSkipEnv(env, skip=4)
-    if 'FIRE' in env.unwrapped.get_action_meanings():
-        env = FireResetEnv(env)
-    env = ProcessFrame80(env)
-    env = FrameStack(env, 4)
-    env = ClippedRewardsWrapper(env)
     return env
-
-
-class A2cProcessFrame(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.observation_space = spaces.Box(low=0, high=255, shape=(80, 80, 1))
-
-    def _step(self, action):
-        ob, reward, done, info = self.env.step(action)
-        return A2cProcessFrame.process(ob), reward, done, info
-
-    def _reset(self):
-        return A2cProcessFrame.process(self.env.reset())
-
-    @staticmethod
-    def process(frame):
-        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
-        frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
-        return frame.reshape(80, 80, 1)
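For orientation, the reworked wrap_dqn above applies the Atari-specific wrappers only when the raw observation shape matches ModelCatalog.ATARI_OBS_SHAPE, and routes every environment through the catalog preprocessor. A rough sketch of calling it, assuming the post-commit module path (the env id is illustrative):

import gym

from ray.rllib.dqn.common.wrappers import wrap_dqn

# Atari envs get EpisodicLife/NoopReset/(MaxAndSkip)/(FireReset), the catalog
# preprocessor, frame stacking and reward clipping; other envs only get the
# catalog preprocessor.
env = wrap_dqn(gym.make("PongDeterministic-v4"), dict())
obs = env.reset()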
@@ -12,8 +12,7 @@ import tensorflow as tf
 import ray
 from ray.rllib.common import Agent, TrainingResult
 from ray.rllib.dqn import logger, models
-from ray.rllib.dqn.common.atari_wrappers_deprecated \
-    import wrap_dqn, ScaledFloatFrame
+from ray.rllib.dqn.common.wrappers import wrap_dqn
 from ray.rllib.dqn.common.schedules import LinearSchedule
 from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer

@@ -104,9 +103,7 @@ DEFAULT_CONFIG = dict(
 class Actor(object):
     def __init__(self, env_creator, config, logdir):
         env = env_creator()
-        # TODO(ekl): replace this with RLlib preprocessors
-        if "NoFrameskip" in env.spec.id:
-            env = ScaledFloatFrame(wrap_dqn(env))
+        env = wrap_dqn(env, config["model"])
         self.env = env
         self.config = config

@@ -152,7 +152,7 @@ class DQNGraph(object):
             with tf.variable_scope("q_func", reuse=True):
                 q_tp1_using_online_net = _build_q_network(
                     self.obs_tp1, num_actions, config)
-            q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
+            q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
             q_tp1_best = tf.reduce_sum(
                 self.q_tp1 * tf.one_hot(
                     q_tp1_best_using_online_net, num_actions), 1)
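Aside from the tf.arg_max to tf.argmax rename, this block is the double-DQN target selection: the online network picks the greedy action at t+1 and that action's value is read from the target network. A self-contained sketch of the same computation (the function and tensor names here are illustrative, not from the diff):

import tensorflow as tf


def double_q_values(q_tp1_online, q_tp1_target, num_actions):
    # Greedy action at t+1 according to the online network.
    best_actions = tf.argmax(q_tp1_online, 1)
    # Value of that action according to the target network, shape [batch].
    return tf.reduce_sum(
        q_tp1_target * tf.one_hot(best_actions, num_actions), 1)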
@@ -34,6 +34,9 @@ class ModelCatalog(object):
         action_op = dist.sample()
     """

+    ATARI_OBS_SHAPE = (210, 160, 3)
+    ATARI_RAM_OBS_SHAPE = (128,)
+
     _registered_preprocessor = dict()

     @staticmethod
@@ -87,33 +90,49 @@ class ModelCatalog(object):
         Args:
             env_name (str): The name of the environment.
             obs_shape (tuple): The shape of the env observation space.
             options (dict): Options to pass to the preprocessor.

         Returns:
             preprocessor (Preprocessor): Preprocessor for the env observations.
         """

-        ATARI_OBS_SHAPE = (210, 160, 3)
-        ATARI_RAM_OBS_SHAPE = (128,)
         for k in options.keys():
             if k not in MODEL_CONFIGS:
                 raise Exception(
                     "Unknown config key `{}`, all keys: {}".format(
                         k, MODEL_CONFIGS))

+        print("Observation shape is {}".format(obs_shape))
+
         if env_name in cls._registered_preprocessor:
             return cls._registered_preprocessor[env_name](options)

-        if obs_shape == ATARI_OBS_SHAPE:
+        if obs_shape == cls.ATARI_OBS_SHAPE:
             print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
             return AtariPixelPreprocessor(options)
-        elif obs_shape == ATARI_RAM_OBS_SHAPE:
+        elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
             print("Assuming Atari ram env, using AtariRamPreprocessor.")
             return AtariRamPreprocessor(options)

         print("Non-atari env, not using any observation preprocessor.")
         return NoPreprocessor(options)
+
+    @classmethod
+    def get_preprocessor_as_wrapper(cls, env, options=dict()):
+        """Returns a preprocessor as a gym observation wrapper.
+
+        Args:
+            env (gym.Env): The gym environment to wrap.
+            options (dict): Options to pass to the preprocessor.
+
+        Returns:
+            wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
+        """
+
+        preprocessor = cls.get_preprocessor(
+            env.spec.id, env.observation_space.shape, options)
+        return _RLlibPreprocessorWrapper(env, preprocessor)

     @classmethod
     def register_preprocessor(cls, env_name, preprocessor_class):
         """Register a preprocessor class for a specific environment.
@@ -125,3 +144,19 @@ class ModelCatalog(object):
             Python class of the distribution.
         """
         cls._registered_preprocessor[env_name] = preprocessor_class
+
+
+class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
+    """Adapts a RLlib preprocessor for use as an observation wrapper."""
+
+    def __init__(self, env, preprocessor):
+        super(_RLlibPreprocessorWrapper, self).__init__(env)
+        self.preprocessor = preprocessor
+
+        from gym.spaces.box import Box
+        self.observation_space = Box(
+            -1.0, 1.0,
+            preprocessor.transform_shape(env.observation_space.shape))
+
+    def _observation(self, observation):
+        return self.preprocessor.transform(observation)
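As a side note, the registration hook visible above means a custom preprocessor only needs the transform_shape and transform methods that _RLlibPreprocessorWrapper calls, plus an __init__ taking the options dict. A hypothetical sketch (FlattenPreprocessor and the env id are made up for illustration):

import numpy as np

from ray.rllib.models import ModelCatalog


class FlattenPreprocessor(object):
    """Hypothetical preprocessor that flattens observations to 1-D float32."""

    def __init__(self, options):
        self.options = options

    def transform_shape(self, obs_shape):
        # Shape of the transformed observation.
        size = 1
        for dim in obs_shape:
            size *= dim
        return (size,)

    def transform(self, observation):
        return np.asarray(observation, dtype=np.float32).reshape(-1)


# After registration, get_preprocessor("CartPole-v0", ...) and
# get_preprocessor_as_wrapper() will instantiate FlattenPreprocessor.
ModelCatalog.register_preprocessor("CartPole-v0", FlattenPreprocessor)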
@@ -100,7 +100,7 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \

 docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
-    --env PongNoFrameskip-v4 \
+    --env PongDeterministic-v4 \
     --alg DQN \
     --stop '{"training_iteration": 2}' \
     --config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}'