mirror of https://github.com/vale981/ray (synced 2025-03-06 18:41:40 -05:00)
[rllib] Use RLlib preprocessors in DQN (fixes PongDeterministic-v4) (#1124)
* fix pong
* rename
* update
This commit is contained in:
parent d6062ef8f6
commit 802941994d

6 changed files with 65 additions and 63 deletions
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function
 
 import gym
-from gym.spaces.box import Box
 import logging
 import time
 
@@ -15,24 +14,11 @@ logger.setLevel(logging.INFO)
 
 def create_and_wrap(env_creator, options):
     env = env_creator()
-    env = RLLibPreprocessing(env.spec.id, env, options)
+    env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
     env = Diagnostic(env)
     return env
 
 
-class RLLibPreprocessing(gym.ObservationWrapper):
-    def __init__(self, env_id, env=None, options=dict()):
-        super(RLLibPreprocessing, self).__init__(env)
-        self.preprocessor = ModelCatalog.get_preprocessor(
-            env_id, env.observation_space.shape, options)
-        self._process_shape = self.preprocessor.transform_shape(
-            env.observation_space.shape)
-        self.observation_space = Box(-1.0, 1.0, self._process_shape)
-
-    def _observation(self, observation):
-        return self.preprocessor.transform(observation)
-
-
 class Diagnostic(gym.Wrapper):
     def __init__(self, env=None):
         super(Diagnostic, self).__init__(env)
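For orientation, a minimal usage sketch of the new entry point; the env id and empty options dict are illustrative only, and a non-Atari env simply gets the NoPreprocessor pass-through:

import gym

from ray.rllib.models import ModelCatalog

# Wrap a plain gym env with the catalog-selected preprocessor.
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make("CartPole-v0"), {})
obs = env.reset()  # observation already run through preprocessor.transform()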
@@ -9,6 +9,8 @@ import numpy as np
 from collections import deque
 from gym import spaces
 
+from ray.rllib.models import ModelCatalog
+
 
 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env=None, noop_max=30):
@@ -186,7 +188,7 @@ class FrameStack(gym.Wrapper):
 
         See Also
         --------
-        ray.rllib.dqn.common.atari_wrappers.LazyFrames
+        LazyFrames
         """
         gym.Wrapper.__init__(self, env)
         self.k = k
@@ -211,41 +213,23 @@ class FrameStack(gym.Wrapper):
         return LazyFrames(list(self.frames))
 
 
-class ScaledFloatFrame(gym.ObservationWrapper):
-    def _observation(self, obs):
-        # careful! This undoes the memory optimization, use
-        # with smaller replay buffers only.
-        return np.array(obs).astype(np.float32) / 255.0
-
-
-def wrap_dqn(env):
-    """Apply a common set of wrappers for Atari games."""
-    assert 'NoFrameskip' in env.spec.id
-    env = EpisodicLifeEnv(env)
-    env = NoopResetEnv(env, noop_max=30)
-    env = MaxAndSkipEnv(env, skip=4)
-    if 'FIRE' in env.unwrapped.get_action_meanings():
-        env = FireResetEnv(env)
-    env = ProcessFrame80(env)
-    env = FrameStack(env, 4)
-    env = ClippedRewardsWrapper(env)
-    return env
-
-
-class A2cProcessFrame(gym.Wrapper):
-    def __init__(self, env):
-        gym.Wrapper.__init__(self, env)
-        self.observation_space = spaces.Box(low=0, high=255, shape=(80, 80, 1))
-
-    def _step(self, action):
-        ob, reward, done, info = self.env.step(action)
-        return A2cProcessFrame.process(ob), reward, done, info
-
-    def _reset(self):
-        return A2cProcessFrame.process(self.env.reset())
-
-    @staticmethod
-    def process(frame):
-        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
-        frame = cv2.resize(frame, (80, 80), interpolation=cv2.INTER_AREA)
-        return frame.reshape(80, 80, 1)
+def wrap_dqn(env, options):
+    """Apply a common set of wrappers for DQN."""
+
+    is_atari = (env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
+
+    if is_atari:
+        env = EpisodicLifeEnv(env)
+        env = NoopResetEnv(env, noop_max=30)
+        if 'NoFrameskip' in env.spec.id:
+            env = MaxAndSkipEnv(env, skip=4)
+        if 'FIRE' in env.unwrapped.get_action_meanings():
+            env = FireResetEnv(env)
+    env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
+
+    if is_atari:
+        env = FrameStack(env, 4)
+        env = ClippedRewardsWrapper(env)
+
+    return env
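With the reorganization above, wrap_dqn is no longer Atari-only: the Atari wrappers are applied only when the raw observation shape matches ModelCatalog.ATARI_OBS_SHAPE, and every env is routed through the catalog preprocessor. A minimal sketch of calling it on a non-Atari env (the env id and empty options are chosen for illustration):

import gym

from ray.rllib.dqn.common.wrappers import wrap_dqn

# Skips the Atari-only wrappers, then applies the catalog preprocessor.
env = wrap_dqn(gym.make("CartPole-v0"), {})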
@@ -12,8 +12,7 @@ import tensorflow as tf
 import ray
 from ray.rllib.common import Agent, TrainingResult
 from ray.rllib.dqn import logger, models
-from ray.rllib.dqn.common.atari_wrappers_deprecated \
-    import wrap_dqn, ScaledFloatFrame
+from ray.rllib.dqn.common.wrappers import wrap_dqn
 from ray.rllib.dqn.common.schedules import LinearSchedule
 from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 
@@ -104,9 +103,7 @@ DEFAULT_CONFIG = dict(
 class Actor(object):
     def __init__(self, env_creator, config, logdir):
         env = env_creator()
-        # TODO(ekl): replace this with RLlib preprocessors
-        if "NoFrameskip" in env.spec.id:
-            env = ScaledFloatFrame(wrap_dqn(env))
+        env = wrap_dqn(env, config["model"])
        self.env = env
        self.config = config
 
@@ -152,7 +152,7 @@ class DQNGraph(object):
         with tf.variable_scope("q_func", reuse=True):
             q_tp1_using_online_net = _build_q_network(
                 self.obs_tp1, num_actions, config)
-        q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
+        q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
         q_tp1_best = tf.reduce_sum(
             self.q_tp1 * tf.one_hot(
                 q_tp1_best_using_online_net, num_actions), 1)
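For reference, this block is the double-DQN target-action selection: the argmax is taken under the online network and the resulting actions index into the target network's Q-values. A small numpy sketch of the same arithmetic, with toy values of shape [batch, num_actions]:

import numpy as np

q_tp1_online = np.array([[1.0, 3.0], [2.0, 0.5]])   # online net Q(s', .)
q_tp1_target = np.array([[0.2, 0.9], [1.5, 0.1]])   # target net Q(s', .)

best = q_tp1_online.argmax(axis=1)                   # tf.argmax(..., 1)
onehot = np.eye(q_tp1_online.shape[1])[best]         # tf.one_hot(best, num_actions)
q_tp1_best = (q_tp1_target * onehot).sum(axis=1)     # tf.reduce_sum(..., 1)
print(q_tp1_best)                                    # [0.9 1.5]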
@@ -34,6 +34,9 @@ class ModelCatalog(object):
         action_op = dist.sample()
     """
 
+    ATARI_OBS_SHAPE = (210, 160, 3)
+    ATARI_RAM_OBS_SHAPE = (128,)
+
     _registered_preprocessor = dict()
 
     @staticmethod
@@ -87,33 +90,49 @@ class ModelCatalog(object):
         Args:
             env_name (str): The name of the environment.
             obs_shape (tuple): The shape of the env observation space.
+            options (dict): Options to pass to the preprocessor.
 
         Returns:
             preprocessor (Preprocessor): Preprocessor for the env observations.
         """
 
-        ATARI_OBS_SHAPE = (210, 160, 3)
-        ATARI_RAM_OBS_SHAPE = (128,)
-
         for k in options.keys():
             if k not in MODEL_CONFIGS:
                 raise Exception(
                     "Unknown config key `{}`, all keys: {}".format(
                         k, MODEL_CONFIGS))
 
+        print("Observation shape is {}".format(obs_shape))
+
         if env_name in cls._registered_preprocessor:
             return cls._registered_preprocessor[env_name](options)
 
-        if obs_shape == ATARI_OBS_SHAPE:
+        if obs_shape == cls.ATARI_OBS_SHAPE:
             print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
             return AtariPixelPreprocessor(options)
-        elif obs_shape == ATARI_RAM_OBS_SHAPE:
+        elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
             print("Assuming Atari ram env, using AtariRamPreprocessor.")
             return AtariRamPreprocessor(options)
 
         print("Non-atari env, not using any observation preprocessor.")
         return NoPreprocessor(options)
 
+    @classmethod
+    def get_preprocessor_as_wrapper(cls, env, options=dict()):
+        """Returns a preprocessor as a gym observation wrapper.
+
+        Args:
+            env (gym.Env): The gym environment to wrap.
+            options (dict): Options to pass to the preprocessor.
+
+        Returns:
+            wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
+        """
+
+        preprocessor = cls.get_preprocessor(
+            env.spec.id, env.observation_space.shape, options)
+        return _RLlibPreprocessorWrapper(env, preprocessor)
+
     @classmethod
     def register_preprocessor(cls, env_name, preprocessor_class):
         """Register a preprocessor class for a specific environment.
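A small sketch of the dispatch behavior above, assuming a non-Atari observation shape; the env name, the (4,) shape, and the "bad_key" option exist only for illustration:

from ray.rllib.models import ModelCatalog

# Falls through the Atari shape checks and returns a NoPreprocessor.
prep = ModelCatalog.get_preprocessor("CartPole-v0", (4,), {})

# Unknown option keys are rejected against MODEL_CONFIGS.
try:
    ModelCatalog.get_preprocessor("CartPole-v0", (4,), {"bad_key": 1})
except Exception as e:
    print(e)  # Unknown config key `bad_key`, all keys: ...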
@@ -125,3 +144,19 @@ class ModelCatalog(object):
         Python class of the distribution.
         """
         cls._registered_preprocessor[env_name] = preprocessor_class
+
+
+class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
+    """Adapts a RLlib preprocessor for use as an observation wrapper."""
+
+    def __init__(self, env, preprocessor):
+        super(_RLlibPreprocessorWrapper, self).__init__(env)
+        self.preprocessor = preprocessor
+
+        from gym.spaces.box import Box
+        self.observation_space = Box(
+            -1.0, 1.0,
+            preprocessor.transform_shape(env.observation_space.shape))
+
+    def _observation(self, observation):
+        return self.preprocessor.transform(observation)
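A hedged sketch of how a custom preprocessor could be plugged in through register_preprocessor; the class below is hypothetical and only follows the transform/transform_shape interface that _RLlibPreprocessorWrapper relies on above:

import numpy as np

from ray.rllib.models import ModelCatalog


class FlattenPreprocessor(object):
    """Hypothetical preprocessor: flattens observations to 1-D float arrays."""

    def __init__(self, options):
        self.options = options

    def transform_shape(self, obs_shape):
        # e.g. (4, 2) -> (8,)
        return (int(np.prod(obs_shape)),)

    def transform(self, observation):
        return np.asarray(observation, dtype=np.float32).reshape(-1)


# Registered for a specific env id; get_preprocessor() then returns it for
# that env instead of the shape-based default.
ModelCatalog.register_preprocessor("CartPole-v0", FlattenPreprocessor)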
@@ -100,7 +100,7 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
 
 docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
-    --env PongNoFrameskip-v4 \
+    --env PongDeterministic-v4 \
     --alg DQN \
     --stop '{"training_iteration": 2}' \
     --config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}'