ray/rllib/env/wrappers/exception_wrapper.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

38 lines
1.1 KiB
Python

import logging
import traceback
import gym
# Module-level logger, named after this module (``__name__``) so its records
# can be configured and filtered through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
class TooManyResetAttemptsException(Exception):
    """Raised when resetting an environment failed too many times.

    Args:
        max_attempts: The attempt limit that was reached. It is interpolated
            into the error message and kept as the ``max_attempts`` attribute.
    """

    def __init__(self, max_attempts: int):
        super().__init__(
            f"Reached the maximum number of attempts ({max_attempts}) "
            "to reset an environment."
        )
        self.max_attempts = max_attempts

    def __reduce__(self):
        # Default exception pickling re-invokes ``cls(*self.args)``, but
        # ``self.args`` holds the formatted message, not ``max_attempts``;
        # unpickling would then nest the message inside itself. Rebuild the
        # exception from the original constructor argument instead.
        return type(self), (self.max_attempts,)
class ResetOnExceptionWrapper(gym.Wrapper):
    """Gym wrapper that recovers from exceptions raised by the wrapped env.

    Exceptions raised by the wrapped env's ``reset()`` are logged and the
    reset is retried, up to ``max_reset_attempts`` calls in total. Exceptions
    raised by the wrapped env's ``step()`` are logged and answered by
    resetting the env through this wrapper's retrying ``reset()``.

    Args:
        env: The environment to wrap.
        max_reset_attempts: Maximum number of calls to ``env.reset()`` made by
            a single call to this wrapper's ``reset()`` before giving up.
    """

    def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
        super().__init__(env)
        self.max_reset_attempts = max_reset_attempts

    def reset(self, **kwargs):
        """Reset the wrapped env, retrying whenever the reset raises.

        Args:
            **kwargs: Forwarded unchanged to ``self.env.reset()``.

        Returns:
            Whatever the wrapped env's ``reset()`` returns on the first
            successful attempt.

        Raises:
            TooManyResetAttemptsException: If every one of the
                ``max_reset_attempts`` attempts raised (raised immediately
                when ``max_reset_attempts`` <= 0). The last underlying error,
                if any, is attached as ``__cause__``.
        """
        attempt = 0
        last_error = None
        while attempt < self.max_reset_attempts:
            try:
                return self.env.reset(**kwargs)
            except Exception as err:
                attempt += 1
                # ``err`` is unbound when the except clause ends, so keep our
                # own reference for exception chaining below.
                last_error = err
                logger.exception(
                    "Environment reset failed (attempt %s of %s).",
                    attempt,
                    self.max_reset_attempts,
                )
        raise TooManyResetAttemptsException(self.max_reset_attempts) from last_error

    def step(self, action):
        """Step the wrapped env; on any exception, log it and reset the env.

        Args:
            action: Forwarded unchanged to ``self.env.step()``.

        Returns:
            The wrapped env's ``step()`` result when it succeeds. Otherwise a
            4-tuple ``(obs, 0.0, False, {"__terminated__": True})`` where
            ``obs`` is the return value of ``self.reset()``.

        Raises:
            TooManyResetAttemptsException: If the recovery reset fails.
        """
        try:
            return self.env.step(action)
        except Exception:
            logger.exception("Environment step failed; resetting the environment.")
            # NOTE(review): ``done`` is False even though a new episode has
            # started; the "__terminated__" info key is the only signal of the
            # reset. Presumably downstream code checks that key — confirm
            # before changing this return shape.
            return self.reset(), 0.0, False, {"__terminated__": True}