2017-07-17 01:58:54 -07:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import gym
|
2018-10-21 23:43:57 -07:00
|
|
|
import logging
|
2018-01-18 19:51:31 -08:00
|
|
|
import numpy as np
|
|
|
|
import tensorflow as tf
|
|
|
|
from functools import partial
|
2017-07-17 01:58:54 -07:00
|
|
|
|
2017-12-28 13:19:04 -08:00
|
|
|
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
2018-06-19 22:47:00 -07:00
|
|
|
_global_registry
|
2017-12-28 13:19:04 -08:00
|
|
|
|
2018-11-12 16:31:27 -08:00
|
|
|
from ray.rllib.env.async_vector_env import _ExternalEnvToAsync
|
|
|
|
from ray.rllib.env.external_env import ExternalEnv
|
2018-10-20 15:21:22 -07:00
|
|
|
from ray.rllib.env.vector_env import VectorEnv
|
2017-07-17 01:58:54 -07:00
|
|
|
from ray.rllib.models.action_dist import (
|
2018-06-19 19:47:26 -07:00
|
|
|
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
|
|
|
|
squash_to_range)
|
2018-01-05 21:32:41 -08:00
|
|
|
from ray.rllib.models.preprocessors import get_preprocessor
|
2017-07-17 01:58:54 -07:00
|
|
|
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
|
|
|
from ray.rllib.models.visionnet import VisionNetwork
|
2018-06-26 13:17:15 -07:00
|
|
|
from ray.rllib.models.lstm import LSTM
|
2017-07-17 01:58:54 -07:00
|
|
|
|
2018-10-21 23:43:57 -07:00
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# yapf: disable
|
2018-10-16 15:55:11 -07:00
|
|
|
# __sphinx_doc_begin__
|
|
|
|
MODEL_DEFAULTS = {
    # === Built-in options ===
    # Filter config. List of [out_channels, kernel, stride] for each filter
    "conv_filters": None,
    # Nonlinearity for built-in convnet
    "conv_activation": "relu",
    # Nonlinearity for fully connected net (tanh, relu)
    "fcnet_activation": "tanh",
    # Hidden layers of the fully connected net, one list entry per layer
    # (presumably the layer width -- confirm against FullyConnectedNetwork)
    "fcnet_hiddens": [256, 256],
    # For control envs, documented in ray.rllib.models.Model
    "free_log_std": False,
    # Whether to squash the action output to space range
    "squash_to_range": False,

    # == LSTM ==
    # Whether to wrap the model with a LSTM
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20
    "max_seq_len": 20,
    # Size of the LSTM cell
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1}, r_{t-1} to LSTM
    "lstm_use_prev_action_reward": False,

    # == Atari ==
    # Whether to enable framestack for Atari envs
    "framestack": True,
    # Final resized frame dimension
    "dim": 84,
    # Pytorch conv requires images to be channel-major
    "channel_major": False,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom preprocessor to use
    "custom_preprocessor": None,
    # Name of a custom model to use
    "custom_model": None,
    # Extra options to pass to the custom classes
    "custom_options": {},
}
|
|
|
|
# __sphinx_doc_end__
|
2018-10-21 23:43:57 -07:00
|
|
|
# yapf: enable
|
2017-09-02 17:20:56 -07:00
|
|
|
|
|
|
|
|
2017-07-17 01:58:54 -07:00
|
|
|
class ModelCatalog(object):
    """Registry of models, preprocessors, and action distributions for envs.

    Examples:
        >>> prep = ModelCatalog.get_preprocessor(env)
        >>> observation = prep.transform(raw_observation)

        >>> dist_cls, dist_dim = ModelCatalog.get_action_dist(
                env.action_space, {})
        >>> model = ModelCatalog.get_model(
                {"obs": inputs}, obs_space, dist_dim, options)
        >>> dist = dist_cls(model.outputs)
        >>> action = dist.sample()
    """

    @staticmethod
    def get_action_dist(action_space, config, dist_type=None):
        """Returns action distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (dict): Optional model config. A falsy value (None or an
                empty dict) is replaced by MODEL_DEFAULTS.
            dist_type (str): Optional identifier of the action distribution.
                Only consulted for Box spaces, where None selects
                DiagGaussian and "deterministic" selects Deterministic.

        Returns:
            dist_class (ActionDistribution): Python class of the distribution.
            dist_dim (int): The size of the input vector to the distribution.

        Raises:
            ValueError: If a Box action space has more than one dimension.
            NotImplementedError: If the combination of action space and
                dist_type is not handled by any branch below.
        """

        config = config or MODEL_DEFAULTS
        if isinstance(action_space, gym.spaces.Box):
            if len(action_space.shape) > 1:
                raise ValueError(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a Tuple action space, or the multi-agent API.")
            if dist_type is None:
                dist = DiagGaussian
                if config.get("squash_to_range"):
                    dist = squash_to_range(dist, action_space.low,
                                           action_space.high)
                # The distribution input is twice the action dimension.
                return dist, action_space.shape[0] * 2
            elif dist_type == "deterministic":
                return Deterministic, action_space.shape[0]
            # Any other dist_type falls through to the error at the bottom.
        elif isinstance(action_space, gym.spaces.Discrete):
            return Categorical, action_space.n
        elif isinstance(action_space, gym.spaces.Tuple):
            child_dist = []
            input_lens = []
            # Each child space is resolved recursively with the same config;
            # dist_type is not forwarded, so children use their defaults.
            for action in action_space.spaces:
                dist, action_size = ModelCatalog.get_action_dist(
                    action, config)
                child_dist.append(dist)
                input_lens.append(action_size)
            return partial(
                MultiActionDistribution,
                child_distributions=child_dist,
                action_space=action_space,
                input_lens=input_lens), sum(input_lens)

        raise NotImplementedError("Unsupported args: {} {}".format(
            action_space, dist_type))

    @staticmethod
    def get_action_placeholder(action_space):
        """Returns an action placeholder that is consistent with the action space

        Args:
            action_space (Space): Action space of the target gym env.

        Returns:
            action_placeholder (Tensor): A placeholder for the actions

        Raises:
            NotImplementedError: If the action space is not a Box, Discrete
                or Tuple space.
        """

        if isinstance(action_space, gym.spaces.Box):
            return tf.placeholder(
                tf.float32, shape=(None, action_space.shape[0]), name="action")
        elif isinstance(action_space, gym.spaces.Discrete):
            return tf.placeholder(tf.int64, shape=(None, ), name="action")
        elif isinstance(action_space, gym.spaces.Tuple):
            size = 0
            all_discrete = True
            for space in action_space.spaces:
                if isinstance(space, gym.spaces.Discrete):
                    # A discrete child occupies a single slot.
                    size += 1
                else:
                    all_discrete = False
                    # np.prod replaces the deprecated np.product alias, which
                    # newer NumPy releases no longer provide. The cast keeps
                    # the placeholder dimension a plain Python int.
                    size += int(np.prod(space.shape))
            return tf.placeholder(
                tf.int64 if all_discrete else tf.float32,
                shape=(None, size),
                name="action")
        else:
            raise NotImplementedError("action space {}"
                                      " not supported".format(action_space))

    @staticmethod
    def get_model(input_dict,
                  obs_space,
                  num_outputs,
                  options,
                  state_in=None,
                  seq_lens=None):
        """Returns a suitable model conforming to given input and output specs.

        When options["use_lstm"] is set, the base model is wrapped in an LSTM
        that receives the base model's last layer as its "obs" input.

        Args:
            input_dict (dict): Dict of input tensors to the model, including
                the observation under the "obs" key.
            obs_space (Space): Observation space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                A falsy value is replaced by MODEL_DEFAULTS.
            state_in (list): Optional RNN state in tensors.
            seq_lens (Tensor): Optional RNN sequence length tensor.

        Returns:
            model (models.Model): Neural network model.
        """

        assert isinstance(input_dict, dict)
        options = options or MODEL_DEFAULTS
        model = ModelCatalog._get_model(input_dict, obs_space, num_outputs,
                                        options, state_in, seq_lens)

        if options.get("use_lstm"):
            # Work on a shallow copy so the caller's input_dict is untouched.
            copy = dict(input_dict)
            copy["obs"] = model.last_layer
            model = LSTM(copy, obs_space, num_outputs, options, state_in,
                         seq_lens)

        logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
            model, input_dict, obs_space, state_in, seq_lens, model.outputs,
            model.state_out))

        model._validate_output_shape()
        return model

    @staticmethod
    def _get_model(input_dict, obs_space, num_outputs, options, state_in,
                   seq_lens):
        """Builds the base model selected by the options and the obs rank.

        A registered custom model takes precedence. Otherwise a VisionNetwork
        is used when the "obs" tensor has more than one non-batch dimension,
        and a FullyConnectedNetwork in all other cases. Note that state_in
        and seq_lens are only forwarded to custom models.

        Args:
            input_dict (dict): Dict of input tensors; "obs" must be present.
            obs_space (Space): Observation space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Model config; must not be None.
            state_in (list): RNN state in tensors, or None.
            seq_lens (Tensor): RNN sequence length tensor, or None.

        Returns:
            model (models.Model): Neural network model.
        """
        if options.get("custom_model"):
            model = options["custom_model"]
            logger.info("Using custom model {}".format(model))
            return _global_registry.get(RLLIB_MODEL, model)(
                input_dict,
                obs_space,
                num_outputs,
                options,
                state_in=state_in,
                seq_lens=seq_lens)

        # Rank of a single observation, i.e. excluding the leading dimension.
        obs_rank = len(input_dict["obs"].shape) - 1

        if obs_rank > 1:
            return VisionNetwork(input_dict, obs_space, num_outputs, options)

        return FullyConnectedNetwork(input_dict, obs_space, num_outputs,
                                     options)

    @staticmethod
    def get_torch_model(input_shape, num_outputs, options=None):
        """Returns a PyTorch suitable model. This is currently only supported
        in A3C.

        Args:
            input_shape (tuple): The input shape to the model.
            num_outputs (int): The size of the output vector of the model.
            options (dict): Optional args to pass to the model constructor.
                A falsy value is replaced by MODEL_DEFAULTS.

        Returns:
            model (models.Model): Neural network model.
        """
        # Imported lazily so the PyTorch model modules are only loaded when a
        # torch model is actually requested.
        from ray.rllib.models.pytorch.fcnet import (FullyConnectedNetwork as
                                                    PyTorchFCNet)
        from ray.rllib.models.pytorch.visionnet import (VisionNetwork as
                                                        PyTorchVisionNet)

        options = options or MODEL_DEFAULTS
        if options.get("custom_model"):
            model = options["custom_model"]
            logger.info("Using custom torch model {}".format(model))
            return _global_registry.get(RLLIB_MODEL, model)(
                input_shape, num_outputs, options)

        # TODO(alok): fix to handle Discrete(n) state spaces
        obs_rank = len(input_shape) - 1

        if obs_rank > 1:
            return PyTorchVisionNet(input_shape, num_outputs, options)

        # TODO(alok): overhaul PyTorchFCNet so it can just
        #             take input shape directly
        return PyTorchFCNet(input_shape[0], num_outputs, options)

    @staticmethod
    def get_preprocessor(env, options=None):
        """Returns a suitable preprocessor for the given environment.

        Args:
            env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
            options (dict): Options to pass to the preprocessor. A falsy
                value is replaced by MODEL_DEFAULTS.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.

        Raises:
            Exception: If options contains a key that is not present in
                MODEL_DEFAULTS.
        """
        options = options or MODEL_DEFAULTS
        for k in options:
            if k not in MODEL_DEFAULTS:
                raise Exception("Unknown config key `{}`, all keys: {}".format(
                    k, list(MODEL_DEFAULTS)))

        if options.get("custom_preprocessor"):
            preprocessor = options["custom_preprocessor"]
            logger.info("Using custom preprocessor {}".format(preprocessor))
            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                env.observation_space, options)
        else:
            # This name resolves to the module-level get_preprocessor import,
            # not to this static method (class attributes are not in scope
            # inside a method body).
            cls = get_preprocessor(env.observation_space)
            prep = cls(env.observation_space, options)

        logger.debug("Created preprocessor {}: {} -> {}".format(
            prep, env.observation_space, prep.shape))
        return prep

    @staticmethod
    def get_preprocessor_as_wrapper(env, options=None):
        """Returns a preprocessor as a gym observation wrapper.

        Args:
            env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
            options (dict): Options to pass to the preprocessor.

        Returns:
            env (RLlib env): Wrapped environment

        Raises:
            ValueError: If env is none of the supported environment types.
        """

        options = options or MODEL_DEFAULTS
        preprocessor = ModelCatalog.get_preprocessor(env, options)
        if isinstance(env, gym.Env):
            return _RLlibPreprocessorWrapper(env, preprocessor)
        elif isinstance(env, VectorEnv):
            return _RLlibVectorPreprocessorWrapper(env, preprocessor)
        elif isinstance(env, ExternalEnv):
            return _ExternalEnvToAsync(env, preprocessor)
        else:
            raise ValueError("Don't know how to wrap {}".format(env))

    @staticmethod
    def register_custom_preprocessor(preprocessor_name, preprocessor_class):
        """Register a custom preprocessor class by name.

        The preprocessor can be later used by specifying
        {"custom_preprocessor": preprocessor_name} in the model config.

        Args:
            preprocessor_name (str): Name to register the preprocessor under.
            preprocessor_class (type): Python class of the preprocessor.
        """
        _global_registry.register(RLLIB_PREPROCESSOR, preprocessor_name,
                                  preprocessor_class)

    @staticmethod
    def register_custom_model(model_name, model_class):
        """Register a custom model class by name.

        The model can be later used by specifying {"custom_model": model_name}
        in the model config.

        Args:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
        """
        _global_registry.register(RLLIB_MODEL, model_name, model_class)
|
2017-10-14 20:16:36 -07:00
|
|
|
|
|
|
|
|
|
|
|
class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
    """Gym observation wrapper that runs observations through a preprocessor.

    The wrapper advertises the preprocessor's observation space instead of
    the wrapped env's own one.
    """

    def __init__(self, env, preprocessor):
        """Wraps `env` and stores `preprocessor` for later transforms."""
        super(_RLlibPreprocessorWrapper, self).__init__(env)
        self.preprocessor = preprocessor
        # Set after the base initializer has run, as in the original order.
        self.observation_space = preprocessor.observation_space

    def observation(self, observation):
        """Returns `observation` transformed by the stored preprocessor."""
        transformed = self.preprocessor.transform(observation)
        return transformed
|
2018-10-20 15:21:22 -07:00
|
|
|
|
|
|
|
|
|
|
|
class _RLlibVectorPreprocessorWrapper(VectorEnv):
    """VectorEnv wrapper that preprocesses every observation it hands out."""

    def __init__(self, env, preprocessor):
        """Stores the wrapped vector env and the preprocessor.

        The action space and env count are copied from `env`; the
        observation space is taken from `preprocessor`.
        """
        self.env = env
        self.prep = preprocessor
        self.action_space = env.action_space
        self.observation_space = preprocessor.observation_space
        self.num_envs = env.num_envs

    def vector_reset(self):
        """Resets the wrapped env and returns the transformed observations."""
        processed = []
        for raw_obs in self.env.vector_reset():
            processed.append(self.prep.transform(raw_obs))
        return processed

    def reset_at(self, index):
        """Resets the sub-env at `index` and returns its transformed obs."""
        raw_obs = self.env.reset_at(index)
        return self.prep.transform(raw_obs)

    def vector_step(self, actions):
        """Steps the wrapped env; only the observations are transformed."""
        raw_obs, rewards, dones, infos = self.env.vector_step(actions)
        processed = []
        for single_obs in raw_obs:
            processed.append(self.prep.transform(single_obs))
        return processed, rewards, dones, infos

    def get_unwrapped(self):
        """Delegates to the wrapped env's get_unwrapped()."""
        return self.env.get_unwrapped()
|