[rllib] Fix Preprocessor for ATARI (#1066)

* Removing squeeze, fix atari preprocessing

* nit comment

* comments

* jenkins

* Lint
This commit is contained in:
Richard Liaw 2017-10-03 18:45:02 -07:00 committed by Eric Liang
parent 0dcf36c91e
commit cb6dea94bc
5 changed files with 17 additions and 8 deletions

View file

@@ -30,7 +30,7 @@ class RLLibPreprocessing(gym.ObservationWrapper):
self.observation_space = Box(-1.0, 1.0, self._process_shape) self.observation_space = Box(-1.0, 1.0, self._process_shape)
def _observation(self, observation): def _observation(self, observation):
return self.preprocessor.transform(observation).squeeze(0) return self.preprocessor.transform(observation)
class Diagnostic(gym.Wrapper): class Diagnostic(gym.Wrapper):

View file

@@ -99,15 +99,13 @@ class Policy:
t = 0 t = 0
if save_obs: if save_obs:
obs = [] obs = []
# TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix ob = preprocessor.transform(env.reset())
# this in the preprocessor instead
ob = preprocessor.transform(env.reset()).squeeze()
for _ in range(timestep_limit): for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0] ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs: if save_obs:
obs.append(ob) obs.append(ob)
ob, rew, done, _ = env.step(ac) ob, rew, done, _ = env.step(ac)
ob = preprocessor.transform(ob).squeeze() ob = preprocessor.transform(ob)
rews.append(rew) rews.append(rew)
t += 1 t += 1
if render: if render:

View file

@@ -42,12 +42,15 @@ class AtariPixelPreprocessor(Preprocessor):
scaled = observation[25:-25, :, :] scaled = observation[25:-25, :, :]
if self.dim < 80: if self.dim < 80:
scaled = cv2.resize(scaled, (80, 80)) scaled = cv2.resize(scaled, (80, 80))
# OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
# If we resize directly we lose pixels that, when mapped to 42x42,
# aren't close enough to the pixel boundary.
scaled = cv2.resize(scaled, (self.dim, self.dim)) scaled = cv2.resize(scaled, (self.dim, self.dim))
if self.grayscale: if self.grayscale:
scaled = scaled.mean(2) scaled = scaled.mean(2)
scaled = scaled.astype(np.float32) scaled = scaled.astype(np.float32)
# Rescale needed for maintaining 1 channel
scaled = np.reshape(scaled, [self.dim, self.dim, 1]) scaled = np.reshape(scaled, [self.dim, self.dim, 1])
scaled = scaled[None]
if self.zero_mean: if self.zero_mean:
scaled = (scaled - 128) / 128 scaled = (scaled - 128) / 128
else: else:

View file

@@ -22,7 +22,8 @@ class BatchedEnv(object):
def reset(self): def reset(self):
observations = [ observations = [
self.preprocessor.transform(env.reset()) for env in self.envs] self.preprocessor.transform(env.reset())[None]
for env in self.envs]
self.shape = observations[0].shape self.shape = observations[0].shape
self.dones = [False for _ in range(self.batchsize)] self.dones = [False for _ in range(self.batchsize)]
return np.vstack(observations) return np.vstack(observations)
@@ -43,7 +44,7 @@ class BatchedEnv(object):
break break
if render: if render:
self.envs[0].render() self.envs[0].render()
observations.append(self.preprocessor.transform(observation)) observations.append(self.preprocessor.transform(observation)[None])
rewards.append(reward) rewards.append(reward)
self.dones[i] = done self.dones[i] = done
return (np.vstack(observations), np.array(rewards, dtype="float32"), return (np.vstack(observations), np.array(rewards, dtype="float32"),

View file

@@ -84,6 +84,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
--num-iterations 2 \ --num-iterations 2 \
--config '{"stepsize": 0.01}' --config '{"stepsize": 0.01}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v0 \
--alg A3C \
--num-iterations 2 \
--config '{"use_lstm": false}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \ python /ray/python/ray/rllib/train.py \
--env CartPole-v0 \ --env CartPole-v0 \