ray/rllib/env/dm_env_wrapper.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py; otherwise it fails with a py2/print-related error.
2020-01-09 00:15:48 -08:00

94 lines
2.6 KiB
Python

import gym
from gym import spaces
import numpy as np
try:
from dm_env import specs
except ImportError:
specs = None
def _convert_spec_to_space(spec):
    """Recursively convert a `dm_env` spec into the equivalent Gym space.

    Args:
        spec: A `dm_env.specs` array spec (`DiscreteArray`, `BoundedArray`
            or `Array`), or a dict / list / tuple nesting such specs.

    Returns:
        A Gym space: `spaces.Dict` for dicts, `spaces.Tuple` for lists and
        tuples, `spaces.Discrete` for `DiscreteArray`, and `spaces.Box` for
        `BoundedArray` and (unbounded) `Array` specs.

    Raises:
        NotImplementedError: If `spec` is not one of the supported types.
    """
    if isinstance(spec, dict):
        return spaces.Dict(
            {k: _convert_spec_to_space(v)
             for k, v in spec.items()})
    if isinstance(spec, (list, tuple)):
        # Nested sequences of specs map onto a Gym Tuple space.
        return spaces.Tuple(
            tuple(_convert_spec_to_space(s) for s in spec))
    # Check the most specific spec types first: in dm_env, DiscreteArray
    # derives from BoundedArray, which derives from Array.
    if isinstance(spec, specs.DiscreteArray):
        return spaces.Discrete(spec.num_values)
    if isinstance(spec, specs.BoundedArray):
        # `minimum`/`maximum` may be scalars or arrays broadcastable to
        # `spec.shape`. `np.asscalar` (removed in NumPy 1.23) only handled
        # size-1 values, so expand both bounds to full arrays of the spec's
        # dtype and let the Box infer its shape from them.
        low = np.broadcast_to(spec.minimum, spec.shape).astype(spec.dtype)
        high = np.broadcast_to(spec.maximum, spec.shape).astype(spec.dtype)
        return spaces.Box(low=low, high=high, dtype=spec.dtype)
    if isinstance(spec, specs.Array):
        return spaces.Box(
            low=-float("inf"),
            high=float("inf"),
            shape=spec.shape,
            dtype=spec.dtype)
    raise NotImplementedError(
        ("Could not convert `Array` spec of type {} to Gym space. "
         "Attempted to convert: {}").format(type(spec), spec))
class DMEnv(gym.Env):
    """A `gym.Env` wrapper for the `dm_env` API.

    Translates `dm_env` time steps into Gym's
    `(observation, reward, done, info)` tuples and exposes the wrapped
    environment's specs as Gym spaces.
    """

    metadata = {"render.modes": ["rgb_array"]}

    def __init__(self, dm_env):
        """Wrap a `dm_env` environment.

        Args:
            dm_env: The `dm_env` environment instance to wrap.

        Raises:
            RuntimeError: If the `dm_env.specs` module could not be imported.
        """
        # Fail fast, before touching any state, if dm_env is unavailable.
        if specs is None:
            raise RuntimeError((
                "The `specs` module from `dm_env` was not imported. Make sure "
                "`dm_env` is installed and visible in the current python "
                "environment."))
        super(DMEnv, self).__init__()
        self._env = dm_env
        # Most recent observation from `reset()`/`step()`; `render()` returns
        # it. Stays None until the first `reset()`.
        self._prev_obs = None

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action forwarded unchanged to the wrapped env's `step()`.

        Returns:
            Tuple `(observation, reward, done, info)`. `done` is the time
            step's `last()` flag, and `info` holds its `discount`.
        """
        ts = self._env.step(action)
        # Remember the observation so `render()` can return it.
        self._prev_obs = ts.observation
        reward = ts.reward
        # dm_env may report a None reward (e.g. on a FIRST time step); Gym
        # callers expect a number.
        if reward is None:
            reward = 0.
        return ts.observation, reward, ts.last(), {"discount": ts.discount}

    def reset(self):
        """Reset the wrapped environment and return its first observation."""
        ts = self._env.reset()
        self._prev_obs = ts.observation
        return ts.observation

    def render(self, mode="rgb_array"):
        """Return the most recent observation as the rendered frame.

        NOTE(review): this assumes the observation is itself an RGB image
        array; confirm for environments with non-image observations.

        Args:
            mode: Only "rgb_array" is supported.

        Raises:
            ValueError: If called before `reset()`.
            NotImplementedError: For any other render mode.
        """
        if self._prev_obs is None:
            raise ValueError(
                "Environment not started. Make sure to reset before rendering."
            )
        if mode == "rgb_array":
            return self._prev_obs
        else:
            raise NotImplementedError(
                "Render mode '{}' is not supported.".format(mode))

    @property
    def action_space(self):
        """Gym space converted from the wrapped env's `action_spec()`."""
        spec = self._env.action_spec()
        return _convert_spec_to_space(spec)

    @property
    def observation_space(self):
        """Gym space converted from the wrapped env's `observation_spec()`."""
        spec = self._env.observation_spec()
        return _convert_spec_to_space(spec)

    @property
    def reward_range(self):
        """`(min, max)` rewards taken from the env's `reward_spec()`.

        The spec's bounds are used when it is a `BoundedArray`; otherwise
        the range is `(-inf, inf)`.
        """
        spec = self._env.reward_spec()
        if isinstance(spec, specs.BoundedArray):
            return spec.minimum, spec.maximum
        return -float("inf"), float("inf")