ray/rllib/agents/dyna/dyna_torch_model.py

59 lines
2.1 KiB
Python
Raw Normal View History

import gym
from gym.spaces import Discrete
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
class DYNATorchModel(TorchModelV2, nn.Module):
    """Extension of standard TorchModelV2 for Env dynamics learning.

    Data flow:
        obs.cat(action) -> forward() -> next_obs|next_obs_delta
        get_next_observation(obs, action) -> next_obs|next_obs_delta

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initializes a DYNATorchModel object.

        The model's input space is the observation concatenated with the
        one-hot encoded action; this concatenated space (not `obs_space`)
        is what gets handed to the TorchModelV2 constructor.

        Arguments:
            obs_space: The environment's observation space. Only
                `low[0]`, `high[0]` and `shape[0]` are read, so a 1-D
                Box with uniform bounds is presumably expected (this is
                not checked here).
            action_space (Discrete): The environment's action space.
                Must be Discrete.
            num_outputs: Passed through to TorchModelV2.
            model_config: Passed through to TorchModelV2.
            name: Passed through to TorchModelV2.

        Raises:
            AssertionError: If `action_space` is not a Discrete space.
        """
        nn.Module.__init__(self)
        # Construct the wrapped model handing it a concat'd observation and
        # action space as "input_space" and our obs_space as "output_space".
        # TODO: (sven) get rid of these restrictions on obs/action spaces.
        assert isinstance(action_space, Discrete), (
            "DYNATorchModel only supports Discrete action spaces, got "
            "{}".format(action_space))
        # Width of the input = obs dims + one slot per discrete action
        # (the action is fed in one-hot encoded).
        input_space = gym.spaces.Box(
            obs_space.low[0],
            obs_space.high[0],
            shape=(obs_space.shape[0] + action_space.n, ))
        super(DYNATorchModel, self).__init__(input_space, action_space,
                                             num_outputs, model_config, name)

    def get_next_observation(self, observations, actions):
        """Returns a next obs prediction given current observation and action.

        This implements p^(s'|s, a). With p being the environment dynamics.

        Arguments:
            observations (Tensor): The current observation Tensor.
            actions (Tensor): The actions taken in `observations`, given
                as discrete action indices. Non-int64 tensors are cast to
                int64 before one-hot encoding.

        Returns:
            TensorType: The predicted next observations, i.e. the first
                element of the tuple returned by `forward()`.
        """
        # One-hot the actions. `one_hot` only accepts int64 index tensors,
        # so cast first (a no-op if `actions` already is int64).
        actions_flat = nn.functional.one_hot(
            actions.long(), num_classes=self.action_space.n).float()
        # Push through our underlying Model: no RNN state ([]) and no
        # sequence lengths (None) are passed along.
        next_obs, _ = self.forward({
            "obs_flat": torch.cat([observations, actions_flat], -1)
        }, [], None)
        return next_obs