ray/rllib/env/wrappers/exception_wrapper.py

38 lines
1.1 KiB
Python
Raw Normal View History

import logging
import traceback
import gym
logger = logging.getLogger(__name__)
class TooManyResetAttemptsException(Exception):
def __init__(self, max_attempts: int):
super().__init__(
f"Reached the maximum number of attempts ({max_attempts}) "
f"to reset an environment.")
class ResetOnExceptionWrapper(gym.Wrapper):
def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
super().__init__(env)
self.max_reset_attempts = max_reset_attempts
def reset(self, **kwargs):
attempt = 0
while attempt < self.max_reset_attempts:
try:
return self.env.reset(**kwargs)
except Exception:
logger.error(traceback.format_exc())
attempt += 1
else:
raise TooManyResetAttemptsException(self.max_reset_attempts)
def step(self, action):
try:
return self.env.step(action)
except Exception:
logger.error(traceback.format_exc())
return self.reset(), 0.0, False, {"__terminated__": True}