[rllib] Fix Preprocessor for ATARI (#1066)

* Remove squeeze, fix Atari preprocessing

* nit comment

* comments

* jenkins

* Lint
Richard Liaw 2017-10-03 18:45:02 -07:00 committed by Eric Liang
parent 0dcf36c91e
commit cb6dea94bc
5 changed files with 17 additions and 8 deletions


@@ -30,7 +30,7 @@ class RLLibPreprocessing(gym.ObservationWrapper):
self.observation_space = Box(-1.0, 1.0, self._process_shape)
def _observation(self, observation):
- return self.preprocessor.transform(observation).squeeze(0)
+ return self.preprocessor.transform(observation)
class Diagnostic(gym.Wrapper):

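To make the new shape contract concrete, here is a small illustrative sketch (hypothetical 42x42 grayscale shapes, not rllib code): the old preprocessor output carried a leading batch axis, which is why the wrapper had to squeeze it away, while the fixed preprocessor returns a single frame with an explicit channel axis that can be passed through unchanged.

# Illustrative sketch with hypothetical 42x42 grayscale shapes.
import numpy as np

old_output = np.zeros((1, 42, 42))   # old transform(): leading batch axis
new_output = np.zeros((42, 42, 1))   # new transform(): explicit channel axis

assert old_output.squeeze(0).shape == (42, 42)  # why the wrapper used to call squeeze(0)
assert new_output.shape == (42, 42, 1)          # now returned as-is by _observation()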

@@ -99,15 +99,13 @@ class Policy:
t = 0
if save_obs:
obs = []
- # TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix
- # this in the preprocessor instead
- ob = preprocessor.transform(env.reset()).squeeze()
+ ob = preprocessor.transform(env.reset())
for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs:
obs.append(ob)
ob, rew, done, _ = env.step(ac)
- ob = preprocessor.transform(ob).squeeze()
+ ob = preprocessor.transform(ob)
rews.append(rew)
t += 1
if render:

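The rollout above relies on the same contract: transform() now returns an unbatched observation, and the batch axis is added with ob[None] only at the point where the policy is queried. A minimal sketch, assuming a hypothetical policy and preprocessor with the interfaces used above:

import numpy as np

class DummyPolicy(object):
    # Hypothetical stand-in: returns one action per row of the batch.
    def act(self, obs_batch, random_stream=None):
        return np.zeros(len(obs_batch), dtype=np.int64)

def rollout(policy, env, preprocessor, timestep_limit):
    rews = []
    ob = preprocessor.transform(env.reset())  # e.g. shape (42, 42, 1); no squeeze() needed
    for _ in range(timestep_limit):
        ac = policy.act(ob[None])[0]          # batch axis added at the call site
        ob, rew, done, _ = env.step(ac)
        ob = preprocessor.transform(ob)
        rews.append(rew)
        if done:
            break
    return np.array(rews, dtype=np.float32)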

@@ -42,12 +42,15 @@ class AtariPixelPreprocessor(Preprocessor):
scaled = observation[25:-25, :, :]
if self.dim < 80:
scaled = cv2.resize(scaled, (80, 80))
+ # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
+ # If we resize directly we lose pixels that, when mapped to 42x42,
+ # aren't close enough to the pixel boundary.
scaled = cv2.resize(scaled, (self.dim, self.dim))
if self.grayscale:
scaled = scaled.mean(2)
scaled = scaled.astype(np.float32)
+ # Reshape needed to maintain 1 channel
+ scaled = np.reshape(scaled, [self.dim, self.dim, 1])
- scaled = scaled[None]
if self.zero_mean:
scaled = (scaled - 128) / 128
else:

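Putting the hunk above together, the transform path now keeps an explicit channel dimension instead of prepending a batch axis. A rough sketch of the whole method (not the exact rllib class), assuming dim=42 with grayscale and zero_mean enabled:

import cv2
import numpy as np

def transform(observation, dim=42):
    scaled = observation[25:-25, :, :]          # crop the scoreboard rows
    if dim < 80:
        scaled = cv2.resize(scaled, (80, 80))   # resize by half first ("mipmapping")
    scaled = cv2.resize(scaled, (dim, dim))     # then down to the target size
    scaled = scaled.mean(2)                     # grayscale: average the RGB channels
    scaled = scaled.astype(np.float32)
    scaled = np.reshape(scaled, [dim, dim, 1])  # keep an explicit channel axis
    return (scaled - 128) / 128                 # zero_mean scaling to roughly [-1, 1]

frame = np.random.randint(0, 255, size=(210, 160, 3), dtype=np.uint8)
assert transform(frame).shape == (42, 42, 1)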

@@ -22,7 +22,8 @@ class BatchedEnv(object):
def reset(self):
observations = [
- self.preprocessor.transform(env.reset()) for env in self.envs]
+ self.preprocessor.transform(env.reset())[None]
+ for env in self.envs]
self.shape = observations[0].shape
self.dones = [False for _ in range(self.batchsize)]
return np.vstack(observations)
@@ -43,7 +44,7 @@ class BatchedEnv(object):
break
if render:
self.envs[0].render()
- observations.append(self.preprocessor.transform(observation))
+ observations.append(self.preprocessor.transform(observation)[None])
rewards.append(reward)
self.dones[i] = done
return (np.vstack(observations), np.array(rewards, dtype="float32"),

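The [None] added in the two hunks above is what lets np.vstack build a proper batch: each transformed frame now has shape (dim, dim, 1), and vstack concatenates along the first axis, so the batch axis has to be added explicitly per observation. A small sketch with hypothetical shapes:

import numpy as np

frames = [np.zeros((42, 42, 1), dtype=np.float32) for _ in range(4)]

glued   = np.vstack(frames)                     # (168, 42, 1): frames stacked along height
batched = np.vstack([f[None] for f in frames])  # (4, 42, 42, 1): a proper batch

assert glued.shape == (168, 42, 1)
assert batched.shape == (4, 42, 42, 1)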

@@ -84,6 +84,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
--num-iterations 2 \
--config '{"stepsize": 0.01}'
+ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
+ python /ray/python/ray/rllib/train.py \
+ --env CartPole-v0 \
+ --alg A3C \
+ --num-iterations 2 \
+ --config '{"use_lstm": false}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v0 \