diff --git a/python/ray/rllib/examples/parametric_action_cartpole.py b/python/ray/rllib/examples/parametric_action_cartpole.py
index a36e58b98..3d57c268c 100644
--- a/python/ray/rllib/examples/parametric_action_cartpole.py
+++ b/python/ray/rllib/examples/parametric_action_cartpole.py
@@ -68,13 +68,13 @@ class ParametricActionCartpole(gym.Env):
         self.wrapped = gym.make("CartPole-v0")
         self.observation_space = Dict({
             "action_mask": Box(0, 1, shape=(max_avail_actions, )),
-            "avail_actions": Box(-1, 1, shape=(max_avail_actions, 2)),
+            "avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
             "cart": self.wrapped.observation_space,
         })
 
     def update_avail_actions(self):
-        self.action_assignments = [[0, 0]] * self.action_space.n
-        self.action_mask = [0] * self.action_space.n
+        self.action_assignments = np.array([[0., 0.]] * self.action_space.n)
+        self.action_mask = np.array([0.] * self.action_space.n)
         self.left_idx, self.right_idx = random.sample(
             range(self.action_space.n), 2)
         self.action_assignments[self.left_idx] = self.left_action_embed
diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py
index dd133f2aa..5a2e23d6d 100644
--- a/python/ray/rllib/models/preprocessors.py
+++ b/python/ray/rllib/models/preprocessors.py
@@ -12,6 +12,7 @@ from ray.rllib.utils.annotations import override, PublicAPI
 
 ATARI_OBS_SHAPE = (210, 160, 3)
 ATARI_RAM_OBS_SHAPE = (128, )
+VALIDATION_INTERVAL = 100
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +32,7 @@ class Preprocessor(object):
         self._options = options or {}
         self.shape = self._init_shape(obs_space, options)
         self._size = int(np.product(self.shape))
+        self._i = 0
 
     @PublicAPI
     def _init_shape(self, obs_space, options):
@@ -46,6 +48,23 @@ class Preprocessor(object):
         """Alternative to transform for more efficient flattening."""
         array[offset:offset + self._size] = self.transform(observation)
 
+    def check_shape(self, observation):
+        """Checks the shape of the given observation."""
+        if self._i % VALIDATION_INTERVAL == 0:
+            if type(observation) is list and isinstance(
+                    self._obs_space, gym.spaces.Box):
+                observation = np.array(observation)
+            try:
+                if not self._obs_space.contains(observation):
+                    raise ValueError(
+                        "Observation outside expected value range",
+                        self._obs_space, observation)
+            except AttributeError:
+                raise ValueError(
+                    "Observation for a Box space should be an np.array, "
+                    "not a Python list.", observation)
+        self._i += 1
+
     @property
     @PublicAPI
     def size(self):
@@ -85,6 +104,7 @@ class GenericPixelPreprocessor(Preprocessor):
     @override(Preprocessor)
     def transform(self, observation):
         """Downsamples images from (210, 160, 3) by the configured factor."""
+        self.check_shape(observation)
         scaled = observation[25:-25, :, :]
         if self._dim < 84:
             scaled = cv2.resize(scaled, (84, 84))
@@ -111,6 +131,7 @@ class AtariRamPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        self.check_shape(observation)
         return (observation - 128) / 128
 
 
@@ -121,10 +142,8 @@ class OneHotPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        self.check_shape(observation)
         arr = np.zeros(self._obs_space.n)
-        if not self._obs_space.contains(observation):
-            raise ValueError("Observation outside expected value range",
-                             self._obs_space, observation)
         arr[observation] = 1
         return arr
 
@@ -140,6 +159,7 @@ class NoPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        self.check_shape(observation)
         return observation
 
     @override(Preprocessor)
@@ -169,6 +189,7 @@ class TupleFlatteningPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        self.check_shape(observation)
         array = np.zeros(self.shape)
         self.write(observation, array, 0)
         return array
@@ -201,6 +222,7 @@ class DictFlatteningPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        self.check_shape(observation)
         array = np.zeros(self.shape)
         self.write(observation, array, 0)
         return array
diff --git a/python/ray/rllib/tests/test_avail_actions_qmix.py b/python/ray/rllib/tests/test_avail_actions_qmix.py
index fc45e5d2a..f38cffa5e 100644
--- a/python/ray/rllib/tests/test_avail_actions_qmix.py
+++ b/python/ray/rllib/tests/test_avail_actions_qmix.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 from gym.spaces import Tuple, Discrete, Dict, Box
 
 import ray
@@ -20,7 +21,7 @@ class AvailActionsTestEnv(MultiAgentEnv):
     def __init__(self, env_config):
         self.state = None
         self.avail = env_config["avail_action"]
-        self.action_mask = [0] * 10
+        self.action_mask = np.array([0] * 10)
         self.action_mask[env_config["avail_action"]] = 1
 
     def reset(self):
diff --git a/python/ray/rllib/tests/test_catalog.py b/python/ray/rllib/tests/test_catalog.py
index fc9b71d2c..fe89152c6 100644
--- a/python/ray/rllib/tests/test_catalog.py
+++ b/python/ray/rllib/tests/test_catalog.py
@@ -47,12 +47,12 @@ class ModelCatalogTest(unittest.TestCase):
             def __init__(self):
                 self.observation_space = Tuple(
                     [Discrete(5),
-                     Box(0, 1, shape=(3, ), dtype=np.float32)])
+                     Box(0, 5, shape=(3, ), dtype=np.float32)])
 
         p1 = ModelCatalog.get_preprocessor(TupleEnv())
         self.assertEqual(p1.shape, (8, ))
         self.assertEqual(
-            list(p1.transform((0, [1, 2, 3]))),
+            list(p1.transform((0, np.array([1, 2, 3])))),
             [float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]])
 
     def testCustomPreprocessor(self):
diff --git a/python/ray/rllib/tests/test_checkpoint_restore.py b/python/ray/rllib/tests/test_checkpoint_restore.py
index e1ad5a5a9..3b16ad1dd 100644
--- a/python/ray/rllib/tests/test_checkpoint_restore.py
+++ b/python/ray/rllib/tests/test_checkpoint_restore.py
@@ -6,6 +6,7 @@ from __future__ import print_function
 
 import os
 import shutil
 
+import gym
 import numpy as np
 import ray
@@ -63,9 +64,11 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
     if "DDPG" in alg_name:
         alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
         alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
+        env = gym.make("Pendulum-v0")
     else:
         alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
         alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
+        env = gym.make("CartPole-v0")
 
     for _ in range(3):
         res = alg1.train()
@@ -79,9 +82,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
 
     for _ in range(10):
         if "DDPG" in alg_name:
-            obs = np.random.uniform(size=3)
+            obs = np.clip(
+                np.random.uniform(size=3),
+                env.observation_space.low,
+                env.observation_space.high)
         else:
-            obs = np.random.uniform(size=4)
+            obs = np.clip(
+                np.random.uniform(size=4),
+                env.observation_space.low,
+                env.observation_space.high)
         a1 = get_mean_action(alg1, obs)
         a2 = get_mean_action(alg2, obs)
         print("Checking computed actions", alg1, obs, a1, a2)
diff --git a/python/ray/rllib/tests/test_perf.py b/python/ray/rllib/tests/test_perf.py
new file mode 100644
index 000000000..f437c9628
--- /dev/null
+++ b/python/ray/rllib/tests/test_perf.py
@@ -0,0 +1,36 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gym
+import time
+import unittest
+
+import ray
+from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
+from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
+
+
+class TestPerf(unittest.TestCase):
+    # Tested on Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
+    # 11/23/18: Samples per second 8501.125113727468
+    # 03/01/19: Samples per second 8610.164353268685
+    def testBaselinePerformance(self):
+        for _ in range(20):
+            ev = PolicyEvaluator(
+                env_creator=lambda _: gym.make("CartPole-v0"),
+                policy_graph=MockPolicyGraph,
+                batch_steps=100)
+            start = time.time()
+            count = 0
+            while time.time() - start < 1:
+                count += ev.sample().count
+            print()
+            print("Samples per second {}".format(
+                count / (time.time() - start)))
+            print()
+
+
+if __name__ == "__main__":
+    ray.init(num_cpus=5)
+    unittest.main(verbosity=2)
diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py
index 56cbbca6d..6283a5b66 100644
--- a/python/ray/rllib/tests/test_policy_evaluator.py
+++ b/python/ray/rllib/tests/test_policy_evaluator.py
@@ -166,20 +166,6 @@ class TestPolicyEvaluator(unittest.TestCase):
         self.assertEqual(
             len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
 
-    # 11/23/18: Samples per second 8501.125113727468
-    def testBaselinePerformance(self):
-        ev = PolicyEvaluator(
-            env_creator=lambda _: gym.make("CartPole-v0"),
-            policy_graph=MockPolicyGraph,
-            batch_steps=100)
-        start = time.time()
-        count = 0
-        while time.time() - start < 1:
-            count += ev.sample().count
-        print()
-        print("Samples per second {}".format(count / (time.time() - start)))
-        print()
-
    def testGlobalVarsUpdate(self):
         agent = A2CTrainer(
             env="CartPole-v0",
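
Note on the validation pattern introduced in preprocessors.py above: check_shape
coerces Python lists to np.array for Box spaces because gym's Box.contains (in
gym versions contemporary with this patch) accesses x.shape and raises
AttributeError on plain lists. Below is a minimal standalone sketch of that
interplay, using a hypothetical Box space with illustrative bounds and values
that are not taken from the patch:

    import gym
    import numpy as np

    # Hypothetical Box space; the bounds and shape are illustrative only.
    space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3, ), dtype=np.float32)

    # In-range np.array observations pass Box.contains.
    assert space.contains(np.array([0.1, -0.5, 0.9], dtype=np.float32))

    # Out-of-range observations fail contains; check_shape surfaces this as a
    # ValueError, but only on every VALIDATION_INTERVAL-th transform() call,
    # which keeps the per-step overhead low.
    assert not space.contains(np.array([2.0, 0.0, 0.0], dtype=np.float32))

    # A plain Python list has no .shape, so older gym raises AttributeError
    # here; check_shape converts lists to np.array first (or reports a clear
    # error), which is why the example env and tests in this diff switch
    # their masks and observations to np.array.
    try:
        space.contains([0.1, -0.5, 0.9])
    except AttributeError:
        print("lists are rejected; use np.array observations")

The interval-based check is also why the throughput test moved into its own
file: the comments in test_perf.py record 8501 samples/s on 11/23/18 and
8610 samples/s on 03/01/19, suggesting no measurable regression from the
added validation.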