import logging from gym.envs.classic_control import CartPoleEnv import numpy as np import time from ray.rllib.examples.env.multi_agent import make_multi_agent from ray.rllib.utils.annotations import override from ray.rllib.utils.error import EnvError logger = logging.getLogger(__name__) class CartPoleCrashing(CartPoleEnv): """A CartPole env that crashes from time to time. Useful for testing faulty sub-env (within a vectorized env) handling by RolloutWorkers. After crashing, the env expects a `reset()` call next (calling `step()` will result in yet another error), which may or may not take a very long time to complete. This simulates the env having to reinitialize some sub-processes, e.g. an external connection. """ def __init__(self, config=None): super().__init__() config = config or {} # Crash probability (in each `step()`). self.p_crash = config.get("p_crash", 0.005) self.p_crash_reset = config.get("p_crash_reset", self.p_crash) self.crash_after_n_steps = config.get("crash_after_n_steps") # Only crash (with prob=p_crash) if on certain worker indices. faulty_indices = config.get("crash_on_worker_indices", None) if faulty_indices and config.worker_index not in faulty_indices: self.p_crash = 0.0 self.p_crash_reset = 0.0 self.crash_after_n_steps = None # Timestep counter for the ongoing episode. self.timesteps = 0 # Time in seconds to initialize (in this c'tor). init_time_s = config.get("init_time_s", 0) time.sleep(init_time_s) # Time in seconds to re-initialize, while `reset()` is called after a crash. self.re_init_time_s = config.get("re_init_time_s", 10) # No env pre-checking? self._skip_env_checking = config.get("skip_env_checking", False) @override(CartPoleEnv) def reset(self): # Reset timestep counter for the new episode. self.timesteps = 0 # Should we crash? if np.random.random() < self.p_crash_reset or ( self.crash_after_n_steps is not None and self.crash_after_n_steps == 0 ): raise EnvError( "Simulated env crash in `reset()`! Feel free to use any " "other exception type here instead." ) return super().reset() @override(CartPoleEnv) def step(self, action): # Increase timestep counter for the ongoing episode. self.timesteps += 1 # Should we crash? if np.random.random() < self.p_crash or ( self.crash_after_n_steps and self.crash_after_n_steps == self.timesteps ): raise EnvError( "Simulated env crash in `step()`! Feel free to use any " "other exception type here instead." ) # No crash. return super().step(action) MultiAgentCartPoleCrashing = make_multi_agent(lambda config: CartPoleCrashing(config))