mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[rllib] validate observation in NoPreprocessor (#4546)
This commit is contained in:
parent
f9b8e77e3b
commit
da5a471485
7 changed files with 79 additions and 25 deletions
|
@ -68,13 +68,13 @@ class ParametricActionCartpole(gym.Env):
|
|||
self.wrapped = gym.make("CartPole-v0")
|
||||
self.observation_space = Dict({
|
||||
"action_mask": Box(0, 1, shape=(max_avail_actions, )),
|
||||
"avail_actions": Box(-1, 1, shape=(max_avail_actions, 2)),
|
||||
"avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
|
||||
"cart": self.wrapped.observation_space,
|
||||
})
|
||||
|
||||
def update_avail_actions(self):
|
||||
self.action_assignments = [[0, 0]] * self.action_space.n
|
||||
self.action_mask = [0] * self.action_space.n
|
||||
self.action_assignments = np.array([[0., 0.]] * self.action_space.n)
|
||||
self.action_mask = np.array([0.] * self.action_space.n)
|
||||
self.left_idx, self.right_idx = random.sample(
|
||||
range(self.action_space.n), 2)
|
||||
self.action_assignments[self.left_idx] = self.left_action_embed
|
||||
|
|
|
@ -12,6 +12,7 @@ from ray.rllib.utils.annotations import override, PublicAPI
|
|||
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128, )
|
||||
VALIDATION_INTERVAL = 100
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -31,6 +32,7 @@ class Preprocessor(object):
|
|||
self._options = options or {}
|
||||
self.shape = self._init_shape(obs_space, options)
|
||||
self._size = int(np.product(self.shape))
|
||||
self._i = 0
|
||||
|
||||
@PublicAPI
|
||||
def _init_shape(self, obs_space, options):
|
||||
|
@ -46,6 +48,23 @@ class Preprocessor(object):
|
|||
"""Alternative to transform for more efficient flattening."""
|
||||
array[offset:offset + self._size] = self.transform(observation)
|
||||
|
||||
def check_shape(self, observation):
    """Periodically validate an observation against the observation space.

    Despite the name, this checks membership in ``self._obs_space`` through
    the space's ``contains`` method rather than only comparing shapes. The
    check runs only on calls where the internal counter ``self._i`` is a
    multiple of ``VALIDATION_INTERVAL``; all other calls just advance the
    counter.

    Args:
        observation: The raw observation handed to the preprocessor. A
            Python list is converted to an ``np.array`` for the check when
            the space is a ``gym.spaces.Box``. The caller's object is not
            modified.

    Raises:
        ValueError: If ``contains`` returns a falsy value for the
            observation, or if ``contains`` itself raises an
            ``AttributeError`` (presumably a list reaching code that
            expects an array, e.g. nested inside another space -- verify
            against the gym version in use).
    """
    if self._i % VALIDATION_INTERVAL == 0:
        # isinstance instead of an exact type comparison, so that list
        # subclasses receive the same array conversion as plain lists.
        if isinstance(observation, list) and isinstance(
                self._obs_space, gym.spaces.Box):
            observation = np.array(observation)
        try:
            if not self._obs_space.contains(observation):
                raise ValueError(
                    "Observation outside expected value range",
                    self._obs_space, observation)
        except AttributeError:
            raise ValueError(
                "Observation for a Box space should be an np.array, "
                "not a Python list.", observation)
    # Reached only when no error was raised above, so a failed validation
    # leaves the counter untouched and the next call validates again.
    # NOTE(review): the increment is assumed to sit outside the interval
    # check; inside it, validation would only ever run on the first call.
    self._i += 1
|
||||
|
||||
@property
|
||||
@PublicAPI
|
||||
def size(self):
|
||||
|
@ -85,6 +104,7 @@ class GenericPixelPreprocessor(Preprocessor):
|
|||
@override(Preprocessor)
|
||||
def transform(self, observation):
|
||||
"""Downsamples images from (210, 160, 3) by the configured factor."""
|
||||
self.check_shape(observation)
|
||||
scaled = observation[25:-25, :, :]
|
||||
if self._dim < 84:
|
||||
scaled = cv2.resize(scaled, (84, 84))
|
||||
|
@ -111,6 +131,7 @@ class AtariRamPreprocessor(Preprocessor):
|
|||
|
||||
@override(Preprocessor)
def transform(self, observation):
    """Run check_shape on the observation, then return (obs - 128) / 128."""
    self.check_shape(observation)
    # Shift by 128 first, then scale by the same constant.
    centered = observation - 128
    return centered / 128
|
||||
|
||||
|
||||
|
@ -121,10 +142,8 @@ class OneHotPreprocessor(Preprocessor):
|
|||
|
||||
@override(Preprocessor)
|
||||
def transform(self, observation):
|
||||
self.check_shape(observation)
|
||||
arr = np.zeros(self._obs_space.n)
|
||||
if not self._obs_space.contains(observation):
|
||||
raise ValueError("Observation outside expected value range",
|
||||
self._obs_space, observation)
|
||||
arr[observation] = 1
|
||||
return arr
|
||||
|
||||
|
@ -140,6 +159,7 @@ class NoPreprocessor(Preprocessor):
|
|||
|
||||
@override(Preprocessor)
def transform(self, observation):
    """Return the observation unchanged after passing it to check_shape."""
    # No preprocessing is applied; check_shape is the only work done here.
    self.check_shape(observation)
    return observation
|
||||
|
||||
@override(Preprocessor)
|
||||
|
@ -169,6 +189,7 @@ class TupleFlatteningPreprocessor(Preprocessor):
|
|||
|
||||
@override(Preprocessor)
def transform(self, observation):
    """Return a new array of shape ``self.shape`` filled by ``self.write``.

    The observation is first passed to ``check_shape``; the output buffer
    starts as zeros and ``write`` is asked to fill it from offset 0.
    """
    self.check_shape(observation)
    flat = np.zeros(self.shape)
    self.write(observation, flat, 0)
    return flat
|
||||
|
@ -201,6 +222,7 @@ class DictFlatteningPreprocessor(Preprocessor):
|
|||
|
||||
@override(Preprocessor)
def transform(self, observation):
    """Return a new array of shape ``self.shape`` filled by ``self.write``.

    The observation is first passed to ``check_shape``; the output buffer
    starts as zeros and ``write`` is asked to fill it from offset 0.
    """
    self.check_shape(observation)
    flat = np.zeros(self.shape)
    self.write(observation, flat, 0)
    return flat
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from gym.spaces import Tuple, Discrete, Dict, Box
|
||||
|
||||
import ray
|
||||
|
@ -20,7 +21,7 @@ class AvailActionsTestEnv(MultiAgentEnv):
|
|||
def __init__(self, env_config):
|
||||
self.state = None
|
||||
self.avail = env_config["avail_action"]
|
||||
self.action_mask = [0] * 10
|
||||
self.action_mask = np.array([0] * 10)
|
||||
self.action_mask[env_config["avail_action"]] = 1
|
||||
|
||||
def reset(self):
|
||||
|
|
|
@ -47,12 +47,12 @@ class ModelCatalogTest(unittest.TestCase):
|
|||
def __init__(self):
|
||||
self.observation_space = Tuple(
|
||||
[Discrete(5),
|
||||
Box(0, 1, shape=(3, ), dtype=np.float32)])
|
||||
Box(0, 5, shape=(3, ), dtype=np.float32)])
|
||||
|
||||
p1 = ModelCatalog.get_preprocessor(TupleEnv())
|
||||
self.assertEqual(p1.shape, (8, ))
|
||||
self.assertEqual(
|
||||
list(p1.transform((0, [1, 2, 3]))),
|
||||
list(p1.transform((0, np.array([1, 2, 3])))),
|
||||
[float(x) for x in [1, 0, 0, 0, 0, 1, 2, 3]])
|
||||
|
||||
def testCustomPreprocessor(self):
|
||||
|
|
|
@ -6,6 +6,7 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import shutil
|
||||
import gym
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
|
@ -63,9 +64,11 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
|
|||
if "DDPG" in alg_name:
|
||||
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
env = gym.make("Pendulum-v0")
|
||||
else:
|
||||
alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
|
||||
alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
|
||||
env = gym.make("CartPole-v0")
|
||||
|
||||
for _ in range(3):
|
||||
res = alg1.train()
|
||||
|
@ -79,9 +82,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
|
|||
|
||||
for _ in range(10):
|
||||
if "DDPG" in alg_name:
|
||||
obs = np.random.uniform(size=3)
|
||||
obs = np.clip(
|
||||
np.random.uniform(size=3),
|
||||
env.observation_space.low,
|
||||
env.observation_space.high)
|
||||
else:
|
||||
obs = np.random.uniform(size=4)
|
||||
obs = np.clip(
|
||||
np.random.uniform(size=4),
|
||||
env.observation_space.low,
|
||||
env.observation_space.high)
|
||||
a1 = get_mean_action(alg1, obs)
|
||||
a2 = get_mean_action(alg2, obs)
|
||||
print("Checking computed actions", alg1, obs, a1, a2)
|
||||
|
|
36
python/ray/rllib/tests/test_perf.py
Normal file
36
python/ray/rllib/tests/test_perf.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
|
||||
|
||||
|
||||
class TestPerf(unittest.TestCase):
    # Tested on Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
    # 11/23/18: Samples per second 8501.125113727468
    # 03/01/19: Samples per second 8610.164353268685
    def testBaselinePerformance(self):
        # Prints a samples-per-second figure for each of 20 runs; nothing is
        # asserted, the output is meant to be compared by hand against the
        # reference numbers above. (A comment rather than a docstring, so
        # unittest's verbose test description stays as it was.)
        for _ in range(20):
            evaluator = PolicyEvaluator(
                env_creator=lambda _: gym.make("CartPole-v0"),
                policy_graph=MockPolicyGraph,
                batch_steps=100)
            began = time.time()
            num_samples = 0
            # Keep sampling until one second of wall-clock time has elapsed.
            while True:
                if time.time() - began >= 1:
                    break
                num_samples += evaluator.sample().count
            print()
            rate = num_samples / (time.time() - began)
            print("Samples per second {}".format(rate))
            print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: initialise Ray with 5 CPUs, then let unittest
    # discover and run the tests in this module with verbose output.
    ray.init(num_cpus=5)
    unittest.main(verbosity=2)
|
|
@ -166,20 +166,6 @@ class TestPolicyEvaluator(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
|
||||
|
||||
# 11/23/18: Samples per second 8501.125113727468
def testBaselinePerformance(self):
    # Single throughput measurement: sample for one second of wall-clock
    # time and print the resulting rate. Nothing is asserted.
    evaluator = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=100)
    began = time.time()
    num_samples = 0
    while True:
        if time.time() - began >= 1:
            break
        num_samples += evaluator.sample().count
    print()
    rate = num_samples / (time.time() - began)
    print("Samples per second {}".format(rate))
    print()
|
||||
|
||||
def testGlobalVarsUpdate(self):
|
||||
agent = A2CTrainer(
|
||||
env="CartPole-v0",
|
||||
|
|
Loading…
Add table
Reference in a new issue