Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00

[rllib] Fix Preprocessor for ATARI (#1066)
* Removing squeeze, fix atari preprocessing
* nit comment
* comments
* jenkins
* Lint

This commit is contained in:
parent 0dcf36c91e
commit cb6dea94bc

5 changed files with 17 additions and 8 deletions
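
Taken together, the hunks below change the preprocessor's output contract: AtariPixelPreprocessor.transform() now returns a single observation with an explicit channel axis instead of an array with a leading batch axis, and the call sites that previously stripped that axis with squeeze() now add a batch axis themselves only where they actually batch. A small NumPy illustration of the before/after shapes (dim=80 and grayscale output are illustrative assumptions, not values taken from the diff):

import numpy as np

dim = 80  # illustrative; the preprocessor's configured dim may differ

gray = np.zeros((dim, dim), dtype=np.float32)   # grayscale frame after mean(2)

# Before: transform() prepended a batch axis, so consumers squeezed it off.
old_out = gray[None]                            # (1, 80, 80)
old_ob = old_out.squeeze(0)                     # (80, 80)

# After: transform() returns one observation with an explicit channel axis;
# consumers add the batch axis only when stacking.
new_out = np.reshape(gray, [dim, dim, 1])       # (80, 80, 1)
batch = np.vstack([new_out[None], new_out[None]])  # (2, 80, 80, 1)

assert old_ob.shape == (dim, dim)
assert new_out.shape == (dim, dim, 1)
assert batch.shape == (2, dim, dim, 1)
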
@@ -30,7 +30,7 @@ class RLLibPreprocessing(gym.ObservationWrapper):
         self.observation_space = Box(-1.0, 1.0, self._process_shape)
 
     def _observation(self, observation):
-        return self.preprocessor.transform(observation).squeeze(0)
+        return self.preprocessor.transform(observation)
 
 
 class Diagnostic(gym.Wrapper):

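With transform() returning an unbatched observation, the gym wrapper can pass it straight through. A minimal sketch of this wrapper pattern under assumed constructor details (the preprocessor object and its output shape are not shown in this hunk; PreprocessedEnv is a hypothetical name, and the _observation hook matches the older gym API used here; newer gym versions use observation() instead):

import gym
from gym.spaces import Box


class PreprocessedEnv(gym.ObservationWrapper):
    """Sketch: route observations through an rllib-style preprocessor."""

    def __init__(self, env, preprocessor, out_shape):
        super(PreprocessedEnv, self).__init__(env)
        self.preprocessor = preprocessor        # assumed: exposes transform()
        self.observation_space = Box(-1.0, 1.0, out_shape)

    def _observation(self, observation):
        # No .squeeze(0): transform() already yields one unbatched observation.
        return self.preprocessor.transform(observation)
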
@@ -99,15 +99,13 @@ class Policy:
         t = 0
         if save_obs:
             obs = []
-        # TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix
-        # this in the preprocessor instead
-        ob = preprocessor.transform(env.reset()).squeeze()
+        ob = preprocessor.transform(env.reset())
         for _ in range(timestep_limit):
             ac = self.act(ob[None], random_stream=random_stream)[0]
             if save_obs:
                 obs.append(ob)
             ob, rew, done, _ = env.step(ac)
-            ob = preprocessor.transform(ob).squeeze()
+            ob = preprocessor.transform(ob)
             rews.append(rew)
             t += 1
             if render:

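The ES rollout keeps the same contract: transform() yields one observation, and the batched act() call adds the axis via ob[None]. A condensed, self-contained sketch of that loop (policy_act, preprocessor, and env are stand-ins for the objects used above, not the actual rllib signatures):

import numpy as np


def rollout_sketch(policy_act, preprocessor, env, timestep_limit):
    """Condensed rollout loop; policy_act expects a batch of observations."""
    rews = []
    ob = preprocessor.transform(env.reset())   # unbatched, no squeeze needed
    for _ in range(timestep_limit):
        ac = policy_act(ob[None])[0]           # add batch axis, unwrap action
        ob, rew, done, _ = env.step(ac)        # classic 4-tuple gym step API
        ob = preprocessor.transform(ob)
        rews.append(rew)
        if done:
            break
    return np.array(rews, dtype=np.float32)
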
@@ -42,12 +42,15 @@ class AtariPixelPreprocessor(Preprocessor):
         scaled = observation[25:-25, :, :]
+        if self.dim < 80:
+            scaled = cv2.resize(scaled, (80, 80))
         # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
         # If we resize directly we lose pixels that, when mapped to 42x42,
         # aren't close enough to the pixel boundary.
         scaled = cv2.resize(scaled, (self.dim, self.dim))
         if self.grayscale:
             scaled = scaled.mean(2)
             scaled = scaled.astype(np.float32)
+            # Rescale needed for maintaining 1 channel
+            scaled = np.reshape(scaled, [self.dim, self.dim, 1])
-            scaled = scaled[None]
         if self.zero_mean:
             scaled = (scaled - 128) / 128
         else:

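For reference, tracing the shapes through the patched transform() above for a standard 210x160x3 ATARI frame with grayscale enabled (the final else branch is cut off by the hunk, so the 1/255 rescale below is an assumption):

import cv2
import numpy as np


def transform_sketch(observation, dim=42, grayscale=True, zero_mean=False):
    """Shape walk-through of the patched preprocessing pipeline."""
    scaled = observation[25:-25, :, :]           # (210, 160, 3) -> (160, 160, 3)
    if dim < 80:
        # Two-step ("mipmap") resize so small targets keep enough detail.
        scaled = cv2.resize(scaled, (80, 80))    # -> (80, 80, 3)
    scaled = cv2.resize(scaled, (dim, dim))      # -> (dim, dim, 3)
    if grayscale:
        scaled = scaled.mean(2)                  # -> (dim, dim)
        scaled = scaled.astype(np.float32)
        # Keep an explicit single channel instead of a leading batch axis.
        scaled = np.reshape(scaled, [dim, dim, 1])   # -> (dim, dim, 1)
    if zero_mean:
        scaled = (scaled - 128) / 128
    else:
        scaled *= 1.0 / 255.0                    # assumed else branch
    return scaled


frame = np.zeros((210, 160, 3), dtype=np.uint8)
assert transform_sketch(frame, dim=42).shape == (42, 42, 1)
assert transform_sketch(frame, dim=80).shape == (80, 80, 1)
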
@@ -22,7 +22,8 @@ class BatchedEnv(object):
 
     def reset(self):
         observations = [
-            self.preprocessor.transform(env.reset()) for env in self.envs]
+            self.preprocessor.transform(env.reset())[None]
+            for env in self.envs]
         self.shape = observations[0].shape
         self.dones = [False for _ in range(self.batchsize)]
         return np.vstack(observations)

@@ -43,7 +44,7 @@ class BatchedEnv(object):
                 break
             if render:
                 self.envs[0].render()
-            observations.append(self.preprocessor.transform(observation))
+            observations.append(self.preprocessor.transform(observation)[None])
             rewards.append(reward)
             self.dones[i] = done
         return (np.vstack(observations), np.array(rewards, dtype="float32"),

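Both BatchedEnv hunks make the same adjustment: each transformed observation is now unbatched, so [None] turns it into a one-row array that np.vstack can stack across environments. A tiny illustration (the (42, 42, 1) observation shape is an assumed example):

import numpy as np

# One transformed observation per environment, e.g. four envs:
obs = [np.zeros((42, 42, 1), dtype=np.float32) for _ in range(4)]

# What reset()/step() now do before stacking:
batched = np.vstack([ob[None] for ob in obs])

assert batched.shape == (4, 42, 42, 1)
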
@@ -84,6 +84,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     --num-iterations 2 \
     --config '{"stepsize": 0.01}'
 
+docker run --shm-size=10G --memory=10G $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v0 \
+    --alg A3C \
+    --num-iterations 2 \
+    --config '{"use_lstm": false}'
+
 docker run --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
     --env CartPole-v0 \

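The added block extends the Jenkins multi-node smoke tests to also run train.py with --alg A3C on CartPole-v0 for two iterations with use_lstm set to false, alongside the existing CartPole runs.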