[rllib] validate observation in NoPreprocessor (#4546)

This commit is contained in:
Jones Wong 2019-04-07 16:11:50 -07:00 committed by Eric Liang
parent f9b8e77e3b
commit da5a471485
7 changed files with 79 additions and 25 deletions

View file

@ -68,13 +68,13 @@ class ParametricActionCartpole(gym.Env):
self.wrapped = gym.make("CartPole-v0")
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(max_avail_actions, )),
"avail_actions": Box(-1, 1, shape=(max_avail_actions, 2)),
"avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
"cart": self.wrapped.observation_space,
})
def update_avail_actions(self):
self.action_assignments = [[0, 0]] * self.action_space.n
self.action_mask = [0] * self.action_space.n
self.action_assignments = np.array([[0., 0.]] * self.action_space.n)
self.action_mask = np.array([0.] * self.action_space.n)
self.left_idx, self.right_idx = random.sample(
range(self.action_space.n), 2)
self.action_assignments[self.left_idx] = self.left_action_embed

View file

@ -12,6 +12,7 @@ from ray.rllib.utils.annotations import override, PublicAPI
# Shape constants for Atari image and Atari RAM observations (per their
# names; the image shape matches the pixel preprocessor's docstring).
ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128, )
# Preprocessor.check_shape only validates an observation on calls where its
# internal call counter is a multiple of this value.
VALIDATION_INTERVAL = 100
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@ -31,6 +32,7 @@ class Preprocessor(object):
self._options = options or {}
self.shape = self._init_shape(obs_space, options)
self._size = int(np.product(self.shape))
self._i = 0
@PublicAPI
def _init_shape(self, obs_space, options):
@ -46,6 +48,23 @@ class Preprocessor(object):
"""Alternative to transform for more efficient flattening."""
array[offset:offset + self._size] = self.transform(observation)
def check_shape(self, observation):
    """Checks the shape of the given observation.

    Validation is sampled, not exhaustive: the observation is only tested
    against ``self._obs_space`` on calls where the internal call counter
    ``self._i`` is a multiple of ``VALIDATION_INTERVAL``. The counter is
    incremented on every call that returns normally.

    Arguments:
        observation: The observation to validate. A Python list (or a
            list subclass) is converted to an np.array before the check
            when the observation space is a gym.spaces.Box.

    Raises:
        ValueError: If the space reports that it does not contain the
            observation, or if the containment test itself raises an
            AttributeError.
    """
    if self._i % VALIDATION_INTERVAL == 0:
        # isinstance rather than an exact type match, so that list
        # subclasses are converted as well.
        if isinstance(observation, list) and isinstance(
                self._obs_space, gym.spaces.Box):
            observation = np.array(observation)
        try:
            if not self._obs_space.contains(observation):
                raise ValueError(
                    "Observation outside expected value range",
                    self._obs_space, observation)
        except AttributeError:
            # Presumably raised when contains() touches array-only
            # attributes on a plain Python container -- confirm.
            raise ValueError(
                "Observation for a Box space should be an np.array, "
                "not a Python list.", observation)
    self._i += 1
@property
@PublicAPI
def size(self):
@ -85,6 +104,7 @@ class GenericPixelPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
"""Downsamples images from (210, 160, 3) by the configured factor."""
self.check_shape(observation)
scaled = observation[25:-25, :, :]
if self._dim < 84:
scaled = cv2.resize(scaled, (84, 84))
@ -111,6 +131,7 @@ class AtariRamPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Validate the observation, then return (observation - 128) / 128."""
    self.check_shape(observation)
    shifted = observation - 128
    return shifted / 128
@ -121,10 +142,8 @@ class OneHotPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Encode the observation as a one-hot vector of length obs_space.n.

    Raises:
        ValueError: If the observation space does not contain the
            observation.
    """
    self.check_shape(observation)
    one_hot = np.zeros(self._obs_space.n)
    is_member = self._obs_space.contains(observation)
    if not is_member:
        raise ValueError("Observation outside expected value range",
                         self._obs_space, observation)
    # Mark the slot selected by the observation.
    one_hot[observation] = 1
    return one_hot
@ -140,6 +159,7 @@ class NoPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Run check_shape on the observation and hand it back unmodified."""
    self.check_shape(observation)
    passthrough = observation
    return passthrough
@override(Preprocessor)
@ -169,6 +189,7 @@ class TupleFlatteningPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Validate the observation and write it into a fresh array.

    The output starts as zeros of shape ``self.shape`` and is filled by
    ``self.write`` beginning at offset 0.
    """
    self.check_shape(observation)
    flat = np.zeros(self.shape)
    self.write(observation, flat, 0)
    return flat
@ -201,6 +222,7 @@ class DictFlatteningPreprocessor(Preprocessor):
@override(Preprocessor)
def transform(self, observation):
    """Validate the observation and write it into a fresh array.

    The output starts as zeros of shape ``self.shape`` and is filled by
    ``self.write`` beginning at offset 0.
    """
    self.check_shape(observation)
    out = np.zeros(self.shape)
    self.write(observation, out, 0)
    return out

View file

@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from gym.spaces import Tuple, Discrete, Dict, Box
import ray
@ -20,7 +21,7 @@ class AvailActionsTestEnv(MultiAgentEnv):
def __init__(self, env_config):
    """Set up the environment from its config.

    Arguments:
        env_config: Mapping that must provide "avail_action", the index
            of the single entry set to 1 in the action mask.
    """
    self.state = None
    self.avail = env_config["avail_action"]
    # Build the 10-entry mask directly as an np.array; the earlier plain
    # list assignment was overwritten before ever being read.
    self.action_mask = np.array([0] * 10)
    self.action_mask[env_config["avail_action"]] = 1
def reset(self):

View file

@ -47,12 +47,12 @@ class ModelCatalogTest(unittest.TestCase):
def __init__(self):
self.observation_space = Tuple(
[Discrete(5),
Box(0, 1, shape=(3, ), dtype=np.float32)])
Box(0, 5, shape=(3, ), dtype=np.float32)])
p1 = ModelCatalog.get_preprocessor(TupleEnv())
self.assertEqual(p1.shape, (8, ))
self.assertEqual(
list(p1.transform((0, [1, 2, 3]))),
list(p1.transform((0, np.array([1, 2, 3])))),
[float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]])
def testCustomPreprocessor(self):

View file

@ -6,6 +6,7 @@ from __future__ import print_function
import os
import shutil
import gym
import numpy as np
import ray
@ -63,9 +64,11 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
if "DDPG" in alg_name:
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
env = gym.make("Pendulum-v0")
else:
alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
env = gym.make("CartPole-v0")
for _ in range(3):
res = alg1.train()
@ -79,9 +82,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
for _ in range(10):
if "DDPG" in alg_name:
obs = np.random.uniform(size=3)
obs = np.clip(
np.random.uniform(size=3),
env.observation_space.low,
env.observation_space.high)
else:
obs = np.random.uniform(size=4)
obs = np.clip(
np.random.uniform(size=4),
env.observation_space.low,
env.observation_space.high)
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)

View file

@ -0,0 +1,36 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import time
import unittest
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
# Sampling-throughput benchmark for PolicyEvaluator on CartPole-v0.
class TestPerf(unittest.TestCase):
# Historical measurements, kept for comparison across changes:
# Tested on Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
# 11/23/18: Samples per second 8501.125113727468
# 03/01/19: Samples per second 8610.164353268685
def testBaselinePerformance(self):
# NOTE(review): indentation is not visible in this view; presumably each
# of the 20 iterations builds a fresh evaluator and times it -- confirm.
for _ in range(20):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
batch_steps=100)
start = time.time()
count = 0
# Keep sampling until one second of wall-clock time has elapsed,
# accumulating the count reported by each sampled batch.
while time.time() - start < 1:
count += ev.sample().count
print()
# Throughput is the accumulated count divided by the elapsed time.
print("Samples per second {}".format(
count / (time.time() - start)))
print()
# Script entry point: initialise Ray with 5 CPUs, then run the unittest
# runner at verbosity level 2.
if __name__ == "__main__":
ray.init(num_cpus=5)
unittest.main(verbosity=2)

View file

@ -166,20 +166,6 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
# Historical measurement, kept for comparison:
# 11/23/18: Samples per second 8501.125113727468
# Sampling-throughput benchmark for a PolicyEvaluator on CartPole-v0.
def testBaselinePerformance(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
batch_steps=100)
start = time.time()
count = 0
# Keep sampling until one second of wall-clock time has elapsed,
# accumulating the count reported by each sampled batch.
# NOTE(review): indentation is not visible in this view; the prints
# presumably run once, after the loop -- confirm.
while time.time() - start < 1:
count += ev.sample().count
print()
# Throughput is the accumulated count divided by the elapsed time.
print("Samples per second {}".format(count / (time.time() - start)))
print()
def testGlobalVarsUpdate(self):
agent = A2CTrainer(
env="CartPole-v0",