ray/rllib/env/wrappers/dm_env_wrapper.py

98 lines
2.7 KiB
Python

import gym
from gym import spaces
import numpy as np
try:
from dm_env import specs
except ImportError:
specs = None
from ray.rllib.utils.annotations import PublicAPI
def _convert_spec_to_space(spec):
if isinstance(spec, dict):
return spaces.Dict({k: _convert_spec_to_space(v) for k, v in spec.items()})
if isinstance(spec, specs.DiscreteArray):
return spaces.Discrete(spec.num_values)
elif isinstance(spec, specs.BoundedArray):
return spaces.Box(
low=np.asscalar(spec.minimum),
high=np.asscalar(spec.maximum),
shape=spec.shape,
dtype=spec.dtype,
)
elif isinstance(spec, specs.Array):
return spaces.Box(
low=-float("inf"), high=float("inf"), shape=spec.shape, dtype=spec.dtype
)
raise NotImplementedError(
(
"Could not convert `Array` spec of type {} to Gym space. "
"Attempted to convert: {}"
).format(type(spec), spec)
)
@PublicAPI
class DMEnv(gym.Env):
"""A `gym.Env` wrapper for the `dm_env` API."""
metadata = {"render.modes": ["rgb_array"]}
def __init__(self, dm_env):
super(DMEnv, self).__init__()
self._env = dm_env
self._prev_obs = None
if specs is None:
raise RuntimeError(
(
"The `specs` module from `dm_env` was not imported. Make sure "
"`dm_env` is installed and visible in the current python "
"environment."
)
)
def step(self, action):
ts = self._env.step(action)
reward = ts.reward
if reward is None:
reward = 0.0
return ts.observation, reward, ts.last(), {"discount": ts.discount}
def reset(self):
ts = self._env.reset()
return ts.observation
def render(self, mode="rgb_array"):
if self._prev_obs is None:
raise ValueError(
"Environment not started. Make sure to reset before rendering."
)
if mode == "rgb_array":
return self._prev_obs
else:
raise NotImplementedError("Render mode '{}' is not supported.".format(mode))
@property
def action_space(self):
spec = self._env.action_spec()
return _convert_spec_to_space(spec)
@property
def observation_space(self):
spec = self._env.observation_spec()
return _convert_spec_to_space(spec)
@property
def reward_range(self):
spec = self._env.reward_spec()
if isinstance(spec, specs.BoundedArray):
return spec.minimum, spec.maximum
return -float("inf"), float("inf")