mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] Fix Preprocessor for ATARI (#1066)
* Removing squeeze, fix atari preprocessing * nit comment * comments * jenkins * Lint
This commit is contained in:
parent
0dcf36c91e
commit
cb6dea94bc
5 changed files with 17 additions and 8 deletions
|
@ -30,7 +30,7 @@ class RLLibPreprocessing(gym.ObservationWrapper):
|
||||||
self.observation_space = Box(-1.0, 1.0, self._process_shape)
|
self.observation_space = Box(-1.0, 1.0, self._process_shape)
|
||||||
|
|
||||||
def _observation(self, observation):
|
def _observation(self, observation):
|
||||||
return self.preprocessor.transform(observation).squeeze(0)
|
return self.preprocessor.transform(observation)
|
||||||
|
|
||||||
|
|
||||||
class Diagnostic(gym.Wrapper):
|
class Diagnostic(gym.Wrapper):
|
||||||
|
|
|
@ -99,15 +99,13 @@ class Policy:
|
||||||
t = 0
|
t = 0
|
||||||
if save_obs:
|
if save_obs:
|
||||||
obs = []
|
obs = []
|
||||||
# TODO(ekl) the squeeze() is needed for Pong-v0, but we should fix
|
ob = preprocessor.transform(env.reset())
|
||||||
# this in the preprocessor instead
|
|
||||||
ob = preprocessor.transform(env.reset()).squeeze()
|
|
||||||
for _ in range(timestep_limit):
|
for _ in range(timestep_limit):
|
||||||
ac = self.act(ob[None], random_stream=random_stream)[0]
|
ac = self.act(ob[None], random_stream=random_stream)[0]
|
||||||
if save_obs:
|
if save_obs:
|
||||||
obs.append(ob)
|
obs.append(ob)
|
||||||
ob, rew, done, _ = env.step(ac)
|
ob, rew, done, _ = env.step(ac)
|
||||||
ob = preprocessor.transform(ob).squeeze()
|
ob = preprocessor.transform(ob)
|
||||||
rews.append(rew)
|
rews.append(rew)
|
||||||
t += 1
|
t += 1
|
||||||
if render:
|
if render:
|
||||||
|
|
|
@ -42,12 +42,15 @@ class AtariPixelPreprocessor(Preprocessor):
|
||||||
scaled = observation[25:-25, :, :]
|
scaled = observation[25:-25, :, :]
|
||||||
if self.dim < 80:
|
if self.dim < 80:
|
||||||
scaled = cv2.resize(scaled, (80, 80))
|
scaled = cv2.resize(scaled, (80, 80))
|
||||||
|
# OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
|
||||||
|
# If we resize directly we lose pixels that, when mapped to 42x42,
|
||||||
|
# aren't close enough to the pixel boundary.
|
||||||
scaled = cv2.resize(scaled, (self.dim, self.dim))
|
scaled = cv2.resize(scaled, (self.dim, self.dim))
|
||||||
if self.grayscale:
|
if self.grayscale:
|
||||||
scaled = scaled.mean(2)
|
scaled = scaled.mean(2)
|
||||||
scaled = scaled.astype(np.float32)
|
scaled = scaled.astype(np.float32)
|
||||||
|
# Rescale needed for maintaining 1 channel
|
||||||
scaled = np.reshape(scaled, [self.dim, self.dim, 1])
|
scaled = np.reshape(scaled, [self.dim, self.dim, 1])
|
||||||
scaled = scaled[None]
|
|
||||||
if self.zero_mean:
|
if self.zero_mean:
|
||||||
scaled = (scaled - 128) / 128
|
scaled = (scaled - 128) / 128
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -22,7 +22,8 @@ class BatchedEnv(object):
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
observations = [
|
observations = [
|
||||||
self.preprocessor.transform(env.reset()) for env in self.envs]
|
self.preprocessor.transform(env.reset())[None]
|
||||||
|
for env in self.envs]
|
||||||
self.shape = observations[0].shape
|
self.shape = observations[0].shape
|
||||||
self.dones = [False for _ in range(self.batchsize)]
|
self.dones = [False for _ in range(self.batchsize)]
|
||||||
return np.vstack(observations)
|
return np.vstack(observations)
|
||||||
|
@ -43,7 +44,7 @@ class BatchedEnv(object):
|
||||||
break
|
break
|
||||||
if render:
|
if render:
|
||||||
self.envs[0].render()
|
self.envs[0].render()
|
||||||
observations.append(self.preprocessor.transform(observation))
|
observations.append(self.preprocessor.transform(observation)[None])
|
||||||
rewards.append(reward)
|
rewards.append(reward)
|
||||||
self.dones[i] = done
|
self.dones[i] = done
|
||||||
return (np.vstack(observations), np.array(rewards, dtype="float32"),
|
return (np.vstack(observations), np.array(rewards, dtype="float32"),
|
||||||
|
|
|
@ -84,6 +84,13 @@ docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||||
--num-iterations 2 \
|
--num-iterations 2 \
|
||||||
--config '{"stepsize": 0.01}'
|
--config '{"stepsize": 0.01}'
|
||||||
|
|
||||||
|
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||||
|
python /ray/python/ray/rllib/train.py \
|
||||||
|
--env CartPole-v0 \
|
||||||
|
--alg A3C \
|
||||||
|
--num-iterations 2 \
|
||||||
|
--config '{"use_lstm": false}'
|
||||||
|
|
||||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||||
python /ray/python/ray/rllib/train.py \
|
python /ray/python/ray/rllib/train.py \
|
||||||
--env CartPole-v0 \
|
--env CartPole-v0 \
|
||||||
|
|
Loading…
Add table
Reference in a new issue