[rllib] ModelV2 support for pytorch (#5249)

Eric Liang 2019-07-25 11:02:53 -07:00 committed by GitHub
parent 40395acadf
commit bf9199ad77
20 changed files with 222 additions and 220 deletions

View file

@@ -191,35 +191,44 @@ Similarly, you can create and register custom PyTorch models for use with PyTorch
import ray
from ray.rllib.agents import a3c
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.model import TorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
class CustomTorchModel(TorchModel):
class CustomTorchModel(TorchModelV2):
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(CustomTorchModel, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
... # setup hidden layers
def _forward(self, input_dict, hidden_state):
"""Forward pass for the model.
def forward(self, input_dict, state, seq_lens):
"""Call the model with the given input tensors and state.
Prefer implementing this instead of forward() directly for proper
handling of Dict and Tuple observations.
Any complex observations (dicts, tuples, etc.) will be unpacked by
__call__ before being passed to forward(). To access the flattened
observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution,
each call to forward() will eagerly evaluate the model. In symbolic
execution, each call to forward creates a computation graph that
operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Arguments:
input_dict (dict): Dictionary of tensor inputs, commonly
including "obs", "prev_action", "prev_reward", each of shape
[BATCH_SIZE, ...].
hidden_state (list): List of hidden state tensors, each of shape
[BATCH_SIZE, h_size].
input_dict (dict): dictionary of input tensors, including "obs",
"obs_flat", "prev_action", "prev_reward", "is_training"
state (list): list of state tensors with sizes matching those
returned by get_initial_state + the batch dimension
seq_lens (Tensor): 1d tensor holding input sequence lengths
Returns:
(outputs, feature_layer, values, state): Tensors of size
[BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size],
[BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size].
(outputs, state): The model output tensor of size
[BATCH, num_outputs], and a list of the new RNN state tensors
"""
obs = input_dict["obs"]
...
return logits, features, value, hidden_state
return logits, state
ModelCatalog.register_custom_model("my_model", CustomTorchModel)
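For reference, a minimal end-to-end sketch of the API described above. The hidden width, the registered name, and the trainer config keys (including the use_pytorch flag) are illustrative assumptions, not taken from this commit:

import numpy as np
import torch
import torch.nn as nn

class MinimalTorchModel(TorchModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(MinimalTorchModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        in_size = int(np.product(obs_space.shape))  # flattened obs size
        self.fc1 = nn.Linear(in_size, 64)
        self.fc2 = nn.Linear(64, num_outputs)

    def forward(self, input_dict, state, seq_lens):
        # "obs_flat" is prepared by ModelV2.__call__ before forward() runs
        x = torch.relu(self.fc1(input_dict["obs_flat"].float()))
        return self.fc2(x), state

ModelCatalog.register_custom_model("minimal_torch_model", MinimalTorchModel)
# e.g. trainer config (keys illustrative):
#   {"model": {"custom_model": "minimal_torch_model"}, "use_pytorch": True}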

View file

@@ -14,9 +14,10 @@ from ray.rllib.policy.torch_policy_template import build_torch_policy
def actor_critic_loss(policy, batch_tensors):
logits, _, values, _ = policy.model({
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
}, [])
}) # TODO(ekl) seq lens shouldn't be None
values = policy.model.value_function()
dist = policy.dist_class(logits)
log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
policy.entropy = dist.entropy().mean()
@@ -53,8 +54,8 @@ def add_advantages(policy,
policy.config["lambda"])
def model_value_predictions(policy, input_dict, state_batches, model_out):
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
def model_value_predictions(policy, input_dict, state_batches, model):
return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()}
def apply_grad_clipping(policy):
@@ -74,8 +75,8 @@ class ValueNetworkMixin(object):
def _value(self, obs):
with self.lock:
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
_, _, vf, _ = self.model({"obs": obs}, [])
return vf.detach().cpu().numpy().squeeze()
_ = self.model({"obs": obs}, [], [1])
return self.model.value_function().detach().cpu().numpy().squeeze()
A3CTorchPolicy = build_torch_policy(
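Note the pattern these hunks introduce throughout the torch policies: the model call now returns only (logits, state), and value estimates are read back afterwards from activations cached during the forward pass. A sketch of the convention (obs_batch is a placeholder tensor):

# the forward pass populates the value branch as a side effect
logits, _ = policy.model({SampleBatch.CUR_OBS: obs_batch}, [], [1])
values = policy.model.value_function()  # only valid after a forward pass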

View file

@@ -10,9 +10,9 @@ from ray.rllib.policy.torch_policy_template import build_torch_policy
def pg_torch_loss(policy, batch_tensors):
logits, _, values, _ = policy.model({
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
}, [])
})
action_dist = policy.dist_class(logits)
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
# save the error in the policy object

View file

@@ -6,33 +6,35 @@ from torch import nn
import torch.nn.functional as F
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.model import TorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
class RNNModel(TorchModel):
class RNNModel(TorchModelV2):
"""The default RNN model for QMIX."""
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
self.obs_size = _get_size(obs_space)
self.rnn_hidden_dim = options["lstm_cell_size"]
self.rnn_hidden_dim = model_config["lstm_cell_size"]
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
@override(TorchModel)
def state_init(self):
@override(TorchModelV2)
def get_initial_state(self):
# make hidden states on same device as model
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
@override(TorchModel)
def _forward(self, input_dict, hidden_state):
x = F.relu(self.fc1(input_dict["obs"]))
@override(TorchModelV2)
def forward(self, input_dict, hidden_state, seq_lens):
x = F.relu(self.fc1(input_dict["obs_flat"].float()))
h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
return q, h, None, [h]
return q, [h]
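For a recurrent model like this, get_initial_state() and forward() pair up per timestep. A usage sketch (batch_size, horizon, obs_batch, and the config values are illustrative placeholders):

model = RNNModel(obs_space, action_space, n_actions,
                 {"lstm_cell_size": 64}, "rnn")
# expand the per-unit initial state to the batch dimension
state = [s.expand([batch_size, -1]) for s in model.get_initial_state()]
for t in range(horizon):
    # __call__ flattens the obs and delegates to forward()
    q, state = model({"obs": obs_batch[:, t]}, state, None)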
def _get_size(obs_space):

View file

@@ -65,7 +65,10 @@ class QMixLoss(nn.Module):
# Calculate estimated Q-Values
mac_out = []
h = [s.expand([B, self.n_agents, -1]) for s in self.model.state_init()]
h = [
s.expand([B, self.n_agents, -1])
for s in self.model.get_initial_state()
]
for t in range(T):
q, h = _mac(self.model, obs[:, t], h)
mac_out.append(q)
@@ -79,7 +82,7 @@ class QMixLoss(nn.Module):
target_mac_out = []
target_h = [
s.expand([B, self.n_agents, -1])
for s in self.target_model.state_init()
for s in self.target_model.get_initial_state()
]
for t in range(T):
target_q, target_h = _mac(self.target_model, next_obs[:, t],
@@ -171,16 +174,23 @@ class QMixTorchPolicy(Policy):
self.has_action_mask = False
self.obs_size = _get_size(agent_obs_space)
self.model = ModelCatalog.get_torch_model(
self.model = ModelCatalog.get_model_v2(
agent_obs_space,
action_space.spaces[0],
self.n_actions,
config["model"],
default_model_cls=RNNModel)
self.target_model = ModelCatalog.get_torch_model(
framework="torch",
name="model",
default_model=RNNModel)
self.target_model = ModelCatalog.get_model_v2(
agent_obs_space,
action_space.spaces[0],
self.n_actions,
config["model"],
default_model_cls=RNNModel)
framework="torch",
name="target_model",
default_model=RNNModel)
# Setup the mixer network.
# The global state is just the stacked agent observations for now.
@@ -320,7 +330,7 @@ class QMixTorchPolicy(Policy):
def get_initial_state(self):
return [
s.expand([self.n_agents, -1]).numpy()
for s in self.model.state_init()
for s in self.model.get_initial_state()
]
@override(Policy)
@@ -425,7 +435,7 @@ def _mac(model, obs, h):
"""Forward pass of the multi-agent controller.
Arguments:
model: TorchModel class
model: TorchModelV2 class
obs: Tensor of shape [B, n_agents, obs_size]
h: List of tensors of shape [B, n_agents, h_size]
@@ -436,6 +446,6 @@ def _mac(model, obs, h):
B, n_agents = obs.size(0), obs.size(1)
obs_flat = obs.reshape([B * n_agents, -1])
h_flat = [s.reshape([B * n_agents, -1]) for s in h]
q_flat, _, _, h_flat = model.forward({"obs": obs_flat}, h_flat)
q_flat, h_flat = model({"obs": obs_flat}, h_flat, None)
return q_flat.reshape(
[B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]

View file

@@ -15,9 +15,9 @@ parser.add_argument("--iters", type=int, default=200)
def policy_gradient_loss(policy, batch_tensors):
logits, _, values, _ = policy.model({
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
}, [])
})
action_dist = policy.dist_class(logits)
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)

View file

@@ -208,7 +208,7 @@ class ModelCatalog(object):
action_space,
num_outputs,
model_config,
framework="tf",
framework,
name=None,
model_interface=None,
default_model=None,
@@ -240,29 +240,37 @@ class ModelCatalog(object):
model_interface):
raise ValueError("The given model must subclass",
model_interface)
created = set()
# Track and warn if variables were created but not registered
def track_var_creation(next_creator, **kw):
v = next_creator(**kw)
created.add(v)
return v
if framework == "tf":
created = set()
with tf.variable_creator_scope(track_var_creation):
# Track and warn if vars were created but not registered
def track_var_creation(next_creator, **kw):
v = next_creator(**kw)
created.add(v)
return v
with tf.variable_creator_scope(track_var_creation):
instance = model_cls(obs_space, action_space,
num_outputs, model_config, name,
**model_kwargs)
registered = set(instance.variables())
not_registered = set()
for var in created:
if var not in registered:
not_registered.add(var)
if not_registered:
raise ValueError(
"It looks like variables {} were created as part "
"of {} but does not appear in model.variables() "
"({}). Did you forget to call "
"model.register_variables() on the variables in "
"question?".format(not_registered, instance,
registered))
else:
# no variable tracking
instance = model_cls(obs_space, action_space, num_outputs,
model_config, name, **model_kwargs)
registered = set(instance.variables())
not_registered = set()
for var in created:
if var not in registered:
not_registered.add(var)
if not_registered:
raise ValueError(
"It looks like variables {} were created as part of "
"{} but does not appear in model.variables() ({}). "
"Did you forget to call model.register_variables() "
"on the variables in question?".format(
not_registered, instance, registered))
return instance
if framework == "tf":
@@ -271,8 +279,15 @@ class ModelCatalog(object):
make_v1_wrapper(legacy_model_cls), model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
raise NotImplementedError("TODO: support {} models".format(framework))
elif framework == "torch":
if default_model:
return default_model(obs_space, action_space, num_outputs,
model_config, name)
return ModelCatalog._get_default_torch_model_v2(
obs_space, action_space, num_outputs, model_config, name)
else:
raise NotImplementedError(
"Framework must be 'tf' or 'torch': {}".format(framework))
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
@@ -367,46 +382,32 @@ class ModelCatalog(object):
num_outputs,
options=None,
default_model_cls=None):
"""Returns a custom model for PyTorch algorithms.
raise DeprecationWarning("Please use get_model_v2() instead.")
Args:
obs_space (Space): The input observation space.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
default_model_cls (cls): Optional class to use if no custom model.
Returns:
model (models.Model): Neural network model.
"""
def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
model_config, name):
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
PyTorchFCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
PyTorchVisionNet)
options = options or MODEL_DEFAULTS
model_config = model_config or MODEL_DEFAULTS
if options.get("custom_model"):
model = options["custom_model"]
logger.debug("Using custom torch model {}".format(model))
return _global_registry.get(RLLIB_MODEL,
model)(obs_space, num_outputs, options)
if options.get("use_lstm"):
if model_config.get("use_lstm"):
raise NotImplementedError(
"LSTM auto-wrapping not implemented for torch")
if default_model_cls:
return default_model_cls(obs_space, num_outputs, options)
if isinstance(obs_space, gym.spaces.Discrete):
obs_rank = 1
else:
obs_rank = len(obs_space.shape)
if obs_rank > 1:
return PyTorchVisionNet(obs_space, num_outputs, options)
return PyTorchVisionNet(obs_space, action_space, num_outputs,
model_config, name)
return PyTorchFCNet(obs_space, num_outputs, options)
return PyTorchFCNet(obs_space, action_space, num_outputs, model_config,
name)
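In other words, the default torch model above is chosen purely by the rank of the observation space (the spaces below are illustrative):

# Box(0, 255, (84, 84, 3)) -> obs rank 3 -> PyTorchVisionNet
# Box(-1.0, 1.0, (4,))     -> obs rank 1 -> PyTorchFCNet
# Discrete(2)              -> treated as rank 1 -> PyTorchFCNet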
@staticmethod
@DeveloperAPI

View file

@@ -10,6 +10,7 @@ from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# TODO(ekl) rewrite this using ModelV2
class FullyConnectedNetwork(Model):
"""Generic fully connected network."""

View file

@@ -14,10 +14,12 @@ from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@PublicAPI
# Deprecated: use TFModelV2 instead
class Model(object):
"""Defines an abstract network model for use with RLlib.
This class is deprecated: please use TFModelV2 instead.
Models convert input tensors to a number of output features. These features
can then be interpreted by ActionDistribution classes to determine
e.g. agent action values.

View file

@@ -3,8 +3,10 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class ModelV2(object):
"""Defines a Keras-style abstract network model for use with RLlib.
@@ -39,7 +41,6 @@ class ModelV2(object):
self.model_config = model_config
self.name = name or "default_model"
self.framework = framework
self.var_list = []
def get_initial_state(self):
"""Get the initial recurrent state values for the model.
@@ -118,19 +119,7 @@ class ModelV2(object):
"""
return {}
def register_variables(self, variables):
"""Register the given list of variables with this model."""
self.var_list.extend(variables)
def variables(self):
"""Returns the list of variables for this model."""
return list(self.var_list)
def trainable_variables(self):
"""Returns the list of trainable variables for this model."""
return [v for v in self.variables() if v.trainable]
def __call__(self, input_dict, state, seq_lens):
def __call__(self, input_dict, state=None, seq_lens=None):
"""Call the model with the given input tensors and state.
This is the method used by RLlib to execute the forward pass. It calls
@@ -156,7 +145,7 @@ class ModelV2(object):
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
restored["obs_flat"] = input_dict["obs"]
outputs, state = self.forward(restored, state, seq_lens)
outputs, state = self.forward(restored, state or [], seq_lens)
try:
shape = outputs.shape
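Since state and seq_lens are now optional in __call__, a feed-forward model can be invoked with just its input dict; a usage sketch (model and obs_batch are placeholders):

logits, _ = model({"obs": obs_batch})  # state defaults to [], seq_lens to None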

View file

@@ -98,7 +98,7 @@ def make_v1_wrapper(legacy_model_cls):
"Cannot get update ops before wrapped v1 model init")
return list(self._update_ops)
@override(ModelV2)
@override(TFModelV2)
def variables(self):
var_list = super(ModelV1Wrapper, self).variables()
for v in scope_vars(self.variable_scope):

View file

@@ -3,13 +3,18 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@PublicAPI
class TFModelV2(ModelV2):
"""TF version of ModelV2."""
"""TF version of ModelV2.
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
@@ -21,9 +26,22 @@ class TFModelV2(ModelV2):
model_config,
name,
framework="tf")
self.var_list = []
def update_ops(self):
"""Return the list of update ops for this model.
For example, this should include any BatchNorm update ops."""
return []
def register_variables(self, variables):
"""Register the given list of variables with this model."""
self.var_list.extend(variables)
def variables(self):
"""Returns the list of variables for this model."""
return list(self.var_list)
def trainable_variables(self):
"""Returns the list of trainable variables for this model."""
return [v for v in self.variables() if v.trainable]
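Variable bookkeeping now lives on TFModelV2 rather than the base ModelV2, and the catalog's variable-creator scope earlier in this commit verifies that everything created was registered. A minimal sketch of a compliant subclass (the layer choice is illustrative; assumes numpy as np and the tf returned by try_import_tf above):

class MyTFModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(MyTFModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        self.dense = tf.keras.layers.Dense(num_outputs, name="out")
        self.dense.build((None, int(np.product(obs_space.shape))))
        # any variables created above must be registered, or
        # ModelCatalog.get_model_v2 raises a ValueError
        self.register_variables(self.dense.variables)

    def forward(self, input_dict, state, seq_lens):
        return self.dense(input_dict["obs_flat"]), state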

View file

@@ -6,7 +6,7 @@ import logging
import numpy as np
import torch.nn as nn
from ray.rllib.models.torch.model import TorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, SlimFC, \
_get_activation_fn
from ray.rllib.utils.annotations import override
@@ -14,13 +14,16 @@ from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
class FullyConnectedNetwork(TorchModel):
class FullyConnectedNetwork(TorchModelV2):
"""Generic fully connected network."""
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
hiddens = options.get("fcnet_hiddens")
activation = _get_activation_fn(options.get("fcnet_activation"))
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)
hiddens = model_config.get("fcnet_hiddens")
activation = _get_activation_fn(model_config.get("fcnet_activation"))
logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
layers = []
last_layer_size = np.product(obs_space.shape)
@@ -45,13 +48,17 @@ class FullyConnectedNetwork(TorchModel):
out_size=1,
initializer=normc_initializer(1.0),
activation_fn=None)
self._cur_value = None
@override(nn.Module)
def forward(self, input_dict, hidden_state):
# Note that we override forward() and not _forward() to get the
# flattened obs here.
obs = input_dict["obs"]
@override(TorchModelV2)
def forward(self, input_dict, state, seq_lens):
obs = input_dict["obs_flat"]
features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
logits = self._logits(features)
value = self._value_branch(features).squeeze(1)
return logits, features, value, hidden_state
self._cur_value = self._value_branch(features).squeeze(1)
return logits, state
@override(TorchModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value

View file

@@ -1,65 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.utils.annotations import PublicAPI
# TODO(ekl) rewrite using modelv2
@PublicAPI
class TorchModel(nn.Module):
"""Defines an abstract network model for use with RLlib / PyTorch."""
def __init__(self, obs_space, num_outputs, options):
"""All custom RLlib torch models must support this constructor.
Arguments:
obs_space (gym.Space): Input observation space.
num_outputs (int): Output tensor must be of size
[BATCH_SIZE, num_outputs].
options (dict): Dictionary of model options.
"""
nn.Module.__init__(self)
self.obs_space = obs_space
self.num_outputs = num_outputs
self.options = options
@PublicAPI
def forward(self, input_dict, hidden_state):
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
input_dict["obs_flat"] = input_dict["obs"]
input_dict["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, tensorlib=torch)
outputs, features, vf, h = self._forward(input_dict, hidden_state)
return outputs, features, vf, h
@PublicAPI
def state_init(self):
"""Returns a list of initial hidden state tensors, if any."""
return []
@PublicAPI
def _forward(self, input_dict, hidden_state):
"""Forward pass for the model.
Prefer implementing this instead of forward() directly for proper
handling of Dict and Tuple observations.
Arguments:
input_dict (dict): Dictionary of tensor inputs, commonly
including "obs", "prev_action", "prev_reward", each of shape
[BATCH_SIZE, ...].
hidden_state (list): List of hidden state tensors, each of shape
[BATCH_SIZE, h_size].
Returns:
(outputs, feature_layer, values, state): Tensors of size
[BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size],
[BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size].
"""
raise NotImplementedError

View file

@@ -2,19 +2,27 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn as nn
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import PublicAPI
class TorchModelV2(ModelV2):
"""Torch version of ModelV2."""
@PublicAPI
class TorchModelV2(ModelV2, nn.Module):
"""Torch version of ModelV2.
def __init__(self, obs_space, action_space, output_spec, model_config,
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
ModelV2.__init__(
self,
obs_space,
action_space,
output_spec,
num_outputs,
model_config,
name,
framework="torch")
nn.Module.__init__(self)
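Because TorchModelV2 now also subclasses nn.Module, parameters defined in subclasses are visible to the standard torch machinery; e.g. (a sketch, with the space and config arguments as placeholders, reusing the torch fcnet defined elsewhere in this commit):

model = FullyConnectedNetwork(obs_space, action_space, num_outputs,
                              model_config, "fcnet")
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)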

View file

@@ -4,19 +4,22 @@ from __future__ import print_function
import torch.nn as nn
from ray.rllib.models.torch.model import TorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, valid_padding, \
SlimConv2d, SlimFC
from ray.rllib.models.visionnet import _get_filter_config
from ray.rllib.utils.annotations import override
class VisionNetwork(TorchModel):
class VisionNetwork(TorchModelV2):
"""Generic vision network."""
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
filters = options.get("conv_filters")
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(VisionNetwork, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
filters = model_config.get("conv_filters")
if not filters:
filters = _get_filter_config(obs_space.shape)
layers = []
@@ -40,13 +43,19 @@ class VisionNetwork(TorchModel):
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
self._value_branch = SlimFC(
out_channels, 1, initializer=normc_initializer())
self._cur_value = None
@override(TorchModel)
def _forward(self, input_dict, hidden_state):
features = self._hidden_layers(input_dict["obs"])
@override(TorchModelV2)
def forward(self, input_dict, state, seq_lens):
features = self._hidden_layers(input_dict["obs"].float())
logits = self._logits(features)
value = self._value_branch(features).squeeze(1)
return logits, features, value, hidden_state
self._cur_value = self._value_branch(features).squeeze(1)
return logits, state
@override(TorchModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
def _hidden_layers(self, obs):
res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major

View file

@@ -10,6 +10,7 @@ from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# TODO(ekl) rewrite this using ModelV2
class VisionNetwork(Model):
"""Generic vision network."""

View file

@@ -76,14 +76,14 @@ class TorchPolicy(Policy):
input_dict["prev_actions"] = prev_action_batch
if prev_reward_batch:
input_dict["prev_rewards"] = prev_reward_batch
model_out = self._model(input_dict, state_batches)
logits, _, vf, state = model_out
model_out = self._model(input_dict, state_batches, [1])
logits, state = model_out
action_dist = self._action_dist_cls(logits)
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(input_dict, state_batches,
model_out))
self._model))
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
@@ -145,20 +145,20 @@ class TorchPolicy(Policy):
@override(Policy)
def get_initial_state(self):
return [s.numpy() for s in self._model.state_init()]
return [s.numpy() for s in self._model.get_initial_state()]
def extra_grad_process(self):
"""Allow subclass to do extra processing on gradients and
return processing info."""
return {}
def extra_action_out(self, input_dict, state_batches, model_out):
def extra_action_out(self, input_dict, state_batches, model):
"""Returns dict of extra info to include in experience batch.
Arguments:
input_dict (dict): Dict of model input tensors.
state_batches (list): List of state tensors.
model_out (list): Outputs of the policy model module."""
model (TorchModelV2): Reference to the model."""
return {}
def extra_grad_info(self, batch_tensors):

View file

@@ -74,8 +74,12 @@ def build_torch_policy(name,
else:
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"], torch=True)
self.model = ModelCatalog.get_torch_model(
obs_space, logit_dim, self.config["model"])
self.model = ModelCatalog.get_model_v2(
obs_space,
action_space,
logit_dim,
self.config["model"],
framework="torch")
TorchPolicy.__init__(self, obs_space, action_space, self.model,
loss_fn, self.dist_class)
@@ -101,13 +105,13 @@
return TorchPolicy.extra_grad_process(self)
@override(TorchPolicy)
def extra_action_out(self, input_dict, state_batches, model_out):
def extra_action_out(self, input_dict, state_batches, model):
if extra_action_out_fn:
return extra_action_out_fn(self, input_dict, state_batches,
model_out)
model)
else:
return TorchPolicy.extra_action_out(self, input_dict,
state_batches, model_out)
state_batches, model)
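A policy built with this template supplies the hook through the builder; a sketch (other builder arguments are omitted, and the two functions are reused from the pg and a3c changes above):

MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=policy_gradient_loss,
    extra_action_out_fn=model_value_predictions)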
@override(TorchPolicy)
def optimizer(self):

View file

@@ -19,7 +19,7 @@ from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.model import TorchModel
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.rollout import rollout
from ray.rllib.tests.test_external_env import SimpleServing
from ray.tune.registry import register_env
@@ -133,16 +133,18 @@ class InvalidModel2(Model):
return tf.constant(0), tf.constant(0)
class TorchSpyModel(TorchModel):
class TorchSpyModel(TorchModelV2):
capture_index = 0
def __init__(self, obs_space, num_outputs, options):
TorchModel.__init__(self, obs_space, num_outputs, options)
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(TorchSpyModel, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
self.fc = FullyConnectedNetwork(
obs_space.original_space.spaces["sensors"].spaces["position"],
num_outputs, options)
action_space, num_outputs, model_config, name)
def _forward(self, input_dict, hidden_state):
def forward(self, input_dict, state, seq_lens):
pos = input_dict["obs"]["sensors"]["position"].numpy()
front_cam = input_dict["obs"]["sensors"]["front_cam"][0].numpy()
task = input_dict["obs"]["inner_state"]["job_status"]["task"].numpy()
@@ -153,7 +155,10 @@ class TorchSpyModel(TorchModel):
TorchSpyModel.capture_index += 1
return self.fc({
"obs": input_dict["obs"]["sensors"]["position"]
}, hidden_state)
}, state, seq_lens)
def value_function(self):
return self.fc.value_function()
class DictSpyModel(Model):