Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] Support discrete observation spaces such as FrozenLake-v0 (#1140)
* add
* remove transform_shape
* fix test
* fix
parent 0c9817fa76
commit cd9dc398ff
10 changed files with 98 additions and 66 deletions
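For context, a minimal sketch of the new call path this change enables (not part of the diff; it assumes gym with FrozenLake-v0 installed, RLlib at this commit, and that ModelCatalog is importable from ray.rllib.models):

    import gym

    from ray.rllib.models import ModelCatalog  # import path assumed

    env = gym.make("FrozenLake-v0")             # Discrete(16) observation space
    prep = ModelCatalog.get_preprocessor(env)   # selects OneHotPreprocessor
    print(prep.shape)                           # (16,): length of the one-hot vector
    obs = env.reset()                           # an integer state index
    print(prep.transform(obs))                  # one-hot numpy array, e.g. [1., 0., ..., 0.]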
@@ -216,7 +216,8 @@ class FrameStack(gym.Wrapper):
 def wrap_dqn(env, options):
     """Apply a common set of wrappers for DQN."""
 
-    is_atari = (env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
+    is_atari = (hasattr(env.observation_space, "shape") and
+                env.observation_space.shape == ModelCatalog.ATARI_OBS_SHAPE)
 
     if is_atari:
         env = EpisodicLifeEnv(env)
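A quick illustration of why the hasattr() guard matters (a sketch, assuming an older gym release in which Discrete spaces do not define a shape attribute):

    import gym

    from ray.rllib.models import ModelCatalog  # import path assumed

    space = gym.spaces.Discrete(16)
    # On such gym versions `space.shape` does not exist, so comparing it
    # directly would raise AttributeError; the hasattr() check short-circuits.
    is_atari = (hasattr(space, "shape") and
                space.shape == ModelCatalog.ATARI_OBS_SHAPE)
    print(is_atari)  # False for a Discrete space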
@@ -198,8 +198,8 @@ class Actor(object):
         self.dqn_graph.apply_gradients(self.sess, grad)
 
     def stats(self, num_timesteps):
-        mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
-        mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
+        mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5)
+        mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5)
         exploration = self.exploration.value(num_timesteps)
         return (
             mean_100ep_reward,
@@ -74,10 +74,7 @@ class Worker(object):
         self.noise = SharedNoiseTable(noise)
 
         self.env = env_creator()
-        self.preprocessor = ModelCatalog.get_preprocessor(
-            self.env.spec.id, self.env.observation_space.shape)
-        self.preprocessor_shape = self.preprocessor.transform_shape(
-            self.env.observation_space.shape)
+        self.preprocessor = ModelCatalog.get_preprocessor(self.env)
 
         self.sess = utils.make_session(single_threaded=True)
         self.policy = policies.GenericPolicy(
@@ -118,7 +115,7 @@ class Worker(object):
 
         noise_inds, returns, sign_returns, lengths = [], [], [], []
         # We set eps=0 because we're incrementing only.
-        task_ob_stat = utils.RunningStat(self.preprocessor_shape, eps=0)
+        task_ob_stat = utils.RunningStat(self.preprocessor.shape, eps=0)
 
         # Perform some rollouts with noise.
         task_tstart = time.time()
@@ -169,10 +166,7 @@ class ESAgent(Agent):
         }
 
         env = self.env_creator()
-        preprocessor = ModelCatalog.get_preprocessor(
-            env.spec.id, env.observation_space.shape)
-        preprocessor_shape = preprocessor.transform_shape(
-            env.observation_space.shape)
+        preprocessor = ModelCatalog.get_preprocessor(env)
 
         self.sess = utils.make_session(single_threaded=False)
         self.policy = policies.GenericPolicy(
@@ -180,7 +174,7 @@ class ESAgent(Agent):
             **policy_params)
         tf_util.initialize()
         self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
-        self.ob_stat = utils.RunningStat(preprocessor_shape, eps=1e-2)
+        self.ob_stat = utils.RunningStat(preprocessor.shape, eps=1e-2)
 
         # Create the shared noise table.
         print("Creating shared noise table.")
@@ -144,25 +144,25 @@ class GenericPolicy(Policy):
     def _initialize(self, ob_space, ac_space, preprocessor, ac_noise_std):
         self.ac_space = ac_space
         self.ac_noise_std = ac_noise_std
-        self.preprocessor_shape = preprocessor.transform_shape(ob_space.shape)
+        self.preprocessor = preprocessor
 
         with tf.variable_scope(type(self).__name__) as scope:
             # Observation normalization.
             ob_mean = tf.get_variable(
-                'ob_mean', self.preprocessor_shape, tf.float32,
+                'ob_mean', self.preprocessor.shape, tf.float32,
                 tf.constant_initializer(np.nan), trainable=False)
             ob_std = tf.get_variable(
-                'ob_std', self.preprocessor_shape, tf.float32,
+                'ob_std', self.preprocessor.shape, tf.float32,
                 tf.constant_initializer(np.nan), trainable=False)
-            in_mean = tf.placeholder(tf.float32, self.preprocessor_shape)
-            in_std = tf.placeholder(tf.float32, self.preprocessor_shape)
+            in_mean = tf.placeholder(tf.float32, self.preprocessor.shape)
+            in_std = tf.placeholder(tf.float32, self.preprocessor.shape)
             self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
                 tf.assign(ob_mean, in_mean),
                 tf.assign(ob_std, in_std),
             ])
 
             inputs = tf.placeholder(
-                tf.float32, [None] + list(self.preprocessor_shape))
+                tf.float32, [None] + list(self.preprocessor.shape))
 
             # TODO(ekl): we should do clipping in a standard RLlib preprocessor
             clipped_inputs = tf.clip_by_value(
@@ -7,7 +7,8 @@ import gym
 from ray.rllib.models.action_dist import (
     Categorical, Deterministic, DiagGaussian)
 from ray.rllib.models.preprocessors import (
-    NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor)
+    NoPreprocessor, AtariRamPreprocessor, AtariPixelPreprocessor,
+    OneHotPreprocessor)
 from ray.rllib.models.fcnet import FullyConnectedNetwork
 from ray.rllib.models.visionnet import VisionNetwork
@@ -84,18 +85,25 @@ class ModelCatalog(object):
         return FullyConnectedNetwork(inputs, num_outputs, options)
 
     @classmethod
-    def get_preprocessor(cls, env_name, obs_shape, options=dict()):
+    def get_preprocessor(cls, env, options=dict()):
         """Returns a suitable processor for the given environment.
 
         Args:
-            env_name (str): The name of the environment.
-            obs_shape (tuple): The shape of the env observation space.
+            env (gym.Env): The gym environment to preprocess.
             options (dict): Options to pass to the preprocessor.
 
         Returns:
             preprocessor (Preprocessor): Preprocessor for the env observations.
         """
 
+        # For older gym versions that don't set shape for Discrete
+        if not hasattr(env.observation_space, "shape") and \
+                isinstance(env.observation_space, gym.spaces.Discrete):
+            env.observation_space.shape = ()
+
+        env_name = env.spec.id
+        obs_shape = env.observation_space.shape
+
         for k in options.keys():
             if k not in MODEL_CONFIGS:
                 raise Exception(
@@ -107,15 +115,20 @@ class ModelCatalog(object):
         if env_name in cls._registered_preprocessor:
             return cls._registered_preprocessor[env_name](options)
 
-        if obs_shape == cls.ATARI_OBS_SHAPE:
+        if obs_shape == ():
+            print("Using one-hot preprocessor for discrete envs.")
+            preprocessor = OneHotPreprocessor
+        elif obs_shape == cls.ATARI_OBS_SHAPE:
             print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
-            return AtariPixelPreprocessor(options)
+            preprocessor = AtariPixelPreprocessor
         elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
             print("Assuming Atari ram env, using AtariRamPreprocessor.")
-            return AtariRamPreprocessor(options)
+            preprocessor = AtariRamPreprocessor
+        else:
+            print("Non-atari env, not using any observation preprocessor.")
+            preprocessor = NoPreprocessor
 
-        print("Non-atari env, not using any observation preprocessor.")
-        return NoPreprocessor(options)
+        return preprocessor(env.observation_space, options)
 
     @classmethod
     def get_preprocessor_as_wrapper(cls, env, options=dict()):
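The selection rule above, illustrated (a sketch, not part of the diff; the environment names are examples and assume the corresponding gym environments are installed):

    import gym

    from ray.rllib.models import ModelCatalog  # import path assumed

    for name in ["FrozenLake-v0", "PongDeterministic-v4", "CartPole-v0"]:
        env = gym.make(name)
        prep = ModelCatalog.get_preprocessor(env)
        # FrozenLake-v0        -> OneHotPreprocessor     (obs_shape == ())
        # PongDeterministic-v4 -> AtariPixelPreprocessor (obs_shape == (210, 160, 3))
        # CartPole-v0          -> NoPreprocessor         (any other shape)
        print(name, type(prep).__name__, prep.shape)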
@@ -129,8 +142,7 @@ class ModelCatalog(object):
             wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
         """
 
-        preprocessor = cls.get_preprocessor(
-            env.spec.id, env.observation_space.shape, options)
+        preprocessor = cls.get_preprocessor(env, options)
         return _RLlibPreprocessorWrapper(env, preprocessor)
 
     @classmethod
@@ -154,9 +166,7 @@ class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
         self.preprocessor = preprocessor
 
         from gym.spaces.box import Box
-        self.observation_space = Box(
-            -1.0, 1.0,
-            preprocessor.transform_shape(env.observation_space.shape))
+        self.observation_space = Box(-1.0, 1.0, preprocessor.shape)
 
     def _observation(self, observation):
         return self.preprocessor.transform(observation)
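A usage sketch of the wrapper path (same assumed import path); get_preprocessor_as_wrapper() now takes the env itself and exposes the preprocessed Box space directly:

    import gym

    from ray.rllib.models import ModelCatalog  # import path assumed

    env = ModelCatalog.get_preprocessor_as_wrapper(gym.make("FrozenLake-v0"))
    print(env.observation_space)  # a Box(-1.0, 1.0) space with shape (16,)
    obs = env.reset()             # already one-hot encoded by the wrapper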
@@ -6,19 +6,20 @@ import numpy as np
 
 
 class Preprocessor(object):
-    """Defines an abstract observation preprocessor function."""
+    """Defines an abstract observation preprocessor function.
 
-    def __init__(self, options):
-        self.options = options
+    Attributes:
+        shape (obj): Shape of the preprocessed output.
+    """
 
-    def transform_shape(self, obs_shape):
-        """Returns the preprocessed observation shape."""
-        raise NotImplementedError
+    def __init__(self, obs_space, options):
+        self._obs_space = obs_space
+        self._options = options
+        self._init()
+
+    def _init(self):
+        pass
 
     def transform(self, observation):
         """Returns the preprocessed observation."""
         raise NotImplementedError
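Under the new base class, a subclass implements _init() (setting self.shape from self._obs_space and self._options) and transform(), instead of the old transform_shape(). A sketch of a hypothetical subclass (FlattenPreprocessor is not part of this commit):

    import numpy as np

    from ray.rllib.models.preprocessors import Preprocessor

    class FlattenPreprocessor(Preprocessor):
        """Hypothetical example: flattens Box observations to 1-D."""

        def _init(self):
            # self._obs_space and self._options are set by Preprocessor.__init__
            self.shape = (int(np.prod(self._obs_space.shape)),)

        def transform(self, observation):
            return np.asarray(observation).reshape(self.shape)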
@@ -26,32 +27,30 @@ class Preprocessor(object):
 
 class AtariPixelPreprocessor(Preprocessor):
     def _init(self):
-        self.grayscale = self.options.get("grayscale", False)
-        self.zero_mean = self.options.get("zero_mean", True)
-        self.dim = self.options.get("dim", 80)
-
-    def transform_shape(self, obs_shape):
-        if self.grayscale:
-            return (self.dim, self.dim, 1)
+        self._grayscale = self._options.get("grayscale", False)
+        self._zero_mean = self._options.get("zero_mean", True)
+        self._dim = self._options.get("dim", 80)
+        if self._grayscale:
+            self.shape = (self._dim, self._dim, 1)
         else:
-            return (self.dim, self.dim, 3)
+            self.shape = (self._dim, self._dim, 3)
 
     # TODO(ekl) why does this need to return an extra size-1 dim (the [None])
     def transform(self, observation):
         """Downsamples images from (210, 160, 3) by the configured factor."""
         scaled = observation[25:-25, :, :]
-        if self.dim < 80:
+        if self._dim < 80:
             scaled = cv2.resize(scaled, (80, 80))
         # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
         # If we resize directly we lose pixels that, when mapped to 42x42,
         # aren't close enough to the pixel boundary.
-        scaled = cv2.resize(scaled, (self.dim, self.dim))
-        if self.grayscale:
+        scaled = cv2.resize(scaled, (self._dim, self._dim))
+        if self._grayscale:
             scaled = scaled.mean(2)
             scaled = scaled.astype(np.float32)
             # Rescale needed for maintaining 1 channel
-            scaled = np.reshape(scaled, [self.dim, self.dim, 1])
-        if self.zero_mean:
+            scaled = np.reshape(scaled, [self._dim, self._dim, 1])
+        if self._zero_mean:
             scaled = (scaled - 128) / 128
         else:
             scaled *= 1.0 / 255.0
@@ -60,16 +59,27 @@ class AtariPixelPreprocessor(Preprocessor):
 
 # TODO(rliaw): Also should include the deepmind preprocessor
 class AtariRamPreprocessor(Preprocessor):
-    def transform_shape(self, obs_shape):
-        return (128,)
+    def _init(self):
+        self.shape = (128,)
 
     def transform(self, observation):
         return (observation - 128) / 128
 
 
+class OneHotPreprocessor(Preprocessor):
+    def _init(self):
+        assert self._obs_space.shape == ()
+        self.shape = (self._obs_space.n,)
+
+    def transform(self, observation):
+        arr = np.zeros(self._obs_space.n)
+        arr[observation] = 1
+        return arr
+
+
 class NoPreprocessor(Preprocessor):
-    def transform_shape(self, obs_shape):
-        return obs_shape
+    def _init(self):
+        self.shape = self._obs_space.shape
 
     def transform(self, observation):
         return observation
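What the new OneHotPreprocessor does in isolation (a sketch; assumes gym is installed):

    import gym

    from ray.rllib.models.preprocessors import OneHotPreprocessor

    space = gym.spaces.Discrete(4)
    if not hasattr(space, "shape"):
        space.shape = ()                  # same fix the catalog applies for older gym
    prep = OneHotPreprocessor(space, {})  # (obs_space, options), per the new __init__
    print(prep.shape)                     # (4,)
    print(prep.transform(2))              # [0. 0. 1. 0.]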
@@ -15,8 +15,7 @@ class BatchedEnv(object):
         self.action_space = self.envs[0].action_space
         self.batchsize = batchsize
         self.preprocessor = ModelCatalog.get_preprocessor(
-            self.envs[0].spec.id, self.envs[0].observation_space.shape,
-            options["model"])
+            self.envs[0], options["model"])
         self.extra_frameskip = options.get("extra_frameskip", 1)
         assert self.extra_frameskip >= 1
@@ -63,12 +63,9 @@ class Runner(object):
         self.kl_coeff = tf.placeholder(
             name="newkl", shape=(), dtype=tf.float32)
 
-        # The shape of the preprocessed observations.
-        self.preprocessor_shape = self.preprocessor.transform_shape(
-            self.env.observation_space.shape)
         # The input observations.
         self.observations = tf.placeholder(
-            tf.float32, shape=(None,) + self.preprocessor_shape)
+            tf.float32, shape=(None,) + self.preprocessor.shape)
         # Targets of the value function.
         self.returns = tf.placeholder(tf.float32, shape=(None,))
         # Advantage values in the policy gradient estimator.
@@ -142,7 +139,7 @@ class Runner(object):
             self.common_policy.loss, self.sess)
         if config["observation_filter"] == "MeanStdFilter":
             self.observation_filter = MeanStdFilter(
-                self.preprocessor_shape, clip=None)
+                self.preprocessor.shape, clip=None)
         elif config["observation_filter"] == "NoFilter":
             self.observation_filter = NoFilter()
         else:
@@ -3,12 +3,20 @@ from ray.rllib.models.preprocessors import Preprocessor
 
 
 class FakePreprocessor(Preprocessor):
     def __init__(self, options):
         pass
 
 
+class FakeEnv(object):
+    def __init__(self):
+        self.observation_space = lambda: None
+        self.observation_space.shape = ()
+        self.spec = lambda: None
+        self.spec.id = "FakeEnv-v0"
+
+
 def test_preprocessor():
     ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor)
-    preprocessor = ModelCatalog.get_preprocessor("FakeEnv-v0", (1, 1))
+    env = FakeEnv()
+    preprocessor = ModelCatalog.get_preprocessor(env)
     assert type(preprocessor) == FakePreprocessor
@@ -98,6 +98,19 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     --stop '{"training_iteration": 2}' \
     --config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}'
 
+docker run --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env FrozenLake-v0 \
+    --alg DQN \
+    --stop '{"training_iteration": 2}'
+
+docker run --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env FrozenLake-v0 \
+    --alg PPO \
+    --stop '{"training_iteration": 2}' \
+    --config '{"num_sgd_iter": 10, "sgd_batchsize": 64, "timesteps_per_batch": 1000, "num_workers": 1}'
+
 docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
     --env PongDeterministic-v4 \