mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
![]() |
import random
|
||
|
import unittest
|
||
|
|
||
|
import gym
|
||
|
from ray.rllib.env.wrappers.exception_wrapper import ResetOnExceptionWrapper, \
|
||
|
TooManyResetAttemptsException
|
||
|
|
||
|
|
||
|
class TestResetOnExceptionWrapper(unittest.TestCase):
    """Tests for ``ResetOnExceptionWrapper``.

    Two toy environments are wrapped:

    * one whose ``step`` raises intermittently -- running it through the
      wrapper must not let any exception escape;
    * one whose ``reset`` always raises -- running it through the wrapper
      must end in ``TooManyResetAttemptsException``.
    """

    def test_unstable_env(self):
        """Intermittent ``step`` failures must not escape the wrapper."""

        class UnstableEnv(gym.Env):
            observation_space = gym.spaces.Discrete(2)
            action_space = gym.spaces.Discrete(2)

            def __init__(self, seed=0):
                super().__init__()
                # Private, seeded RNG: the failure pattern is reproducible
                # from run to run and the global ``random`` state is left
                # untouched.
                self._rng = random.Random(seed)
                # How many times ``step`` has raised; lets the test confirm
                # that the exception path was really exercised.
                self.num_step_errors = 0

            def step(self, action):
                if self._rng.choice([True, False]):
                    self.num_step_errors += 1
                    raise ValueError("An error from a unstable environment.")
                return self.observation_space.sample(), 0.0, False, {}

            def reset(self):
                return self.observation_space.sample()

        unstable_env = UnstableEnv()
        env = ResetOnExceptionWrapper(unstable_env)

        # Deliberately not wrapped in try/except: if an exception escapes,
        # unittest reports it as an error together with its full traceback.
        self._run_for_100_steps(env)

        # Guard against a vacuous pass: at least one ``step`` call must have
        # raised for the wrapper to have had anything to recover from.
        self.assertGreater(unstable_env.num_step_errors, 0)

    def test_very_unstable_env(self):
        """A ``reset`` that always fails must surface as
        ``TooManyResetAttemptsException``."""

        class VeryUnstableEnv(gym.Env):
            observation_space = gym.spaces.Discrete(2)
            action_space = gym.spaces.Discrete(2)

            def step(self, action):
                return self.observation_space.sample(), 0.0, False, {}

            def reset(self):
                raise ValueError("An error from a very unstable environment.")

        env = ResetOnExceptionWrapper(VeryUnstableEnv())

        with self.assertRaises(TooManyResetAttemptsException):
            self._run_for_100_steps(env)

    @staticmethod
    def _run_for_100_steps(env):
        """Reset ``env`` once, then take 100 randomly sampled actions.

        Step results are discarded; callers only care whether an exception
        escapes.
        """
        env.reset()
        for _ in range(100):
            env.step(env.action_space.sample())
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest in verbose mode
    # and pass pytest's return code back to the shell as the exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|