mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[rllib] ModelV2 support for pytorch (#5249)
This commit is contained in:
parent
40395acadf
commit
bf9199ad77
20 changed files with 222 additions and 220 deletions
|
@ -191,35 +191,44 @@ Similarly, you can create and register custom PyTorch models for use with PyTorc
|
|||
import ray
|
||||
from ray.rllib.agents import a3c
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.torch.model import TorchModel
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
|
||||
class CustomTorchModel(TorchModel):
|
||||
class CustomTorchModel(TorchModelV2):
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super(CustomTorchModel, self).__init__(
|
||||
obs_space, action_space, num_outputs, model_config, name)
|
||||
... # setup hidden layers
|
||||
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
"""Forward pass for the model.
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
"""Call the model with the given input tensors and state.
|
||||
|
||||
Prefer implementing this instead of forward() directly for proper
|
||||
handling of Dict and Tuple observations.
|
||||
Any complex observations (dicts, tuples, etc.) will be unpacked by
|
||||
__call__ before being passed to forward(). To access the flattened
|
||||
observation tensor, refer to input_dict["obs_flat"].
|
||||
|
||||
This method can be called any number of times. In eager execution,
|
||||
each call to forward() will eagerly evaluate the model. In symbolic
|
||||
execution, each call to forward creates a computation graph that
|
||||
operates over the variables of this model (i.e., shares weights).
|
||||
|
||||
Custom models should override this instead of __call__.
|
||||
|
||||
Arguments:
|
||||
input_dict (dict): Dictionary of tensor inputs, commonly
|
||||
including "obs", "prev_action", "prev_reward", each of shape
|
||||
[BATCH_SIZE, ...].
|
||||
hidden_state (list): List of hidden state tensors, each of shape
|
||||
[BATCH_SIZE, h_size].
|
||||
input_dict (dict): dictionary of input tensors, including "obs",
|
||||
"obs_flat", "prev_action", "prev_reward", "is_training"
|
||||
state (list): list of state tensors with sizes matching those
|
||||
returned by get_initial_state + the batch dimension
|
||||
seq_lens (Tensor): 1d tensor holding input sequence lengths
|
||||
|
||||
Returns:
|
||||
(outputs, feature_layer, values, state): Tensors of size
|
||||
[BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size],
|
||||
[BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size].
|
||||
(outputs, state): The model output tensor of size
|
||||
[BATCH, num_outputs]
|
||||
"""
|
||||
obs = input_dict["obs"]
|
||||
...
|
||||
return logits, features, value, hidden_state
|
||||
return logits, state
|
||||
|
||||
ModelCatalog.register_custom_model("my_model", CustomTorchModel)
|
||||
|
||||
|
|
|
@ -14,9 +14,10 @@ from ray.rllib.policy.torch_policy_template import build_torch_policy
|
|||
|
||||
|
||||
def actor_critic_loss(policy, batch_tensors):
|
||||
logits, _, values, _ = policy.model({
|
||||
logits, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
}, [])
|
||||
}) # TODO(ekl) seq lens shouldn't be None
|
||||
values = policy.model.value_function()
|
||||
dist = policy.dist_class(logits)
|
||||
log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
policy.entropy = dist.entropy().mean()
|
||||
|
@ -53,8 +54,8 @@ def add_advantages(policy,
|
|||
policy.config["lambda"])
|
||||
|
||||
|
||||
def model_value_predictions(policy, input_dict, state_batches, model_out):
|
||||
return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()}
|
||||
def model_value_predictions(policy, input_dict, state_batches, model):
|
||||
return {SampleBatch.VF_PREDS: model.value_function().cpu().numpy()}
|
||||
|
||||
|
||||
def apply_grad_clipping(policy):
|
||||
|
@ -74,8 +75,8 @@ class ValueNetworkMixin(object):
|
|||
def _value(self, obs):
|
||||
with self.lock:
|
||||
obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
|
||||
_, _, vf, _ = self.model({"obs": obs}, [])
|
||||
return vf.detach().cpu().numpy().squeeze()
|
||||
_ = self.model({"obs": obs}, [], [1])
|
||||
return self.model.value_function().detach().cpu().numpy().squeeze()
|
||||
|
||||
|
||||
A3CTorchPolicy = build_torch_policy(
|
||||
|
|
|
@ -10,9 +10,9 @@ from ray.rllib.policy.torch_policy_template import build_torch_policy
|
|||
|
||||
|
||||
def pg_torch_loss(policy, batch_tensors):
|
||||
logits, _, values, _ = policy.model({
|
||||
logits, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
}, [])
|
||||
})
|
||||
action_dist = policy.dist_class(logits)
|
||||
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
# save the error in the policy object
|
||||
|
|
|
@ -6,33 +6,35 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.models.torch.model import TorchModel
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class RNNModel(TorchModel):
|
||||
class RNNModel(TorchModelV2):
|
||||
"""The default RNN model for QMIX."""
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
self.obs_size = _get_size(obs_space)
|
||||
self.rnn_hidden_dim = options["lstm_cell_size"]
|
||||
self.rnn_hidden_dim = model_config["lstm_cell_size"]
|
||||
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
|
||||
self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
|
||||
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
|
||||
|
||||
@override(TorchModel)
|
||||
def state_init(self):
|
||||
@override(TorchModelV2)
|
||||
def get_initial_state(self):
|
||||
# make hidden states on same device as model
|
||||
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
|
||||
|
||||
@override(TorchModel)
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
x = F.relu(self.fc1(input_dict["obs"]))
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict, hidden_state, seq_lens):
|
||||
x = F.relu(self.fc1(input_dict["obs_flat"].float()))
|
||||
h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
|
||||
h = self.rnn(x, h_in)
|
||||
q = self.fc2(h)
|
||||
return q, h, None, [h]
|
||||
return q, [h]
|
||||
|
||||
|
||||
def _get_size(obs_space):
|
||||
|
|
|
@ -65,7 +65,10 @@ class QMixLoss(nn.Module):
|
|||
|
||||
# Calculate estimated Q-Values
|
||||
mac_out = []
|
||||
h = [s.expand([B, self.n_agents, -1]) for s in self.model.state_init()]
|
||||
h = [
|
||||
s.expand([B, self.n_agents, -1])
|
||||
for s in self.model.get_initial_state()
|
||||
]
|
||||
for t in range(T):
|
||||
q, h = _mac(self.model, obs[:, t], h)
|
||||
mac_out.append(q)
|
||||
|
@ -79,7 +82,7 @@ class QMixLoss(nn.Module):
|
|||
target_mac_out = []
|
||||
target_h = [
|
||||
s.expand([B, self.n_agents, -1])
|
||||
for s in self.target_model.state_init()
|
||||
for s in self.target_model.get_initial_state()
|
||||
]
|
||||
for t in range(T):
|
||||
target_q, target_h = _mac(self.target_model, next_obs[:, t],
|
||||
|
@ -171,16 +174,23 @@ class QMixTorchPolicy(Policy):
|
|||
self.has_action_mask = False
|
||||
self.obs_size = _get_size(agent_obs_space)
|
||||
|
||||
self.model = ModelCatalog.get_torch_model(
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
agent_obs_space,
|
||||
action_space.spaces[0],
|
||||
self.n_actions,
|
||||
config["model"],
|
||||
default_model_cls=RNNModel)
|
||||
self.target_model = ModelCatalog.get_torch_model(
|
||||
framework="torch",
|
||||
name="model",
|
||||
default_model=RNNModel)
|
||||
|
||||
self.target_model = ModelCatalog.get_model_v2(
|
||||
agent_obs_space,
|
||||
action_space.spaces[0],
|
||||
self.n_actions,
|
||||
config["model"],
|
||||
default_model_cls=RNNModel)
|
||||
framework="torch",
|
||||
name="target_model",
|
||||
default_model=RNNModel)
|
||||
|
||||
# Setup the mixer network.
|
||||
# The global state is just the stacked agent observations for now.
|
||||
|
@ -320,7 +330,7 @@ class QMixTorchPolicy(Policy):
|
|||
def get_initial_state(self):
|
||||
return [
|
||||
s.expand([self.n_agents, -1]).numpy()
|
||||
for s in self.model.state_init()
|
||||
for s in self.model.get_initial_state()
|
||||
]
|
||||
|
||||
@override(Policy)
|
||||
|
@ -425,7 +435,7 @@ def _mac(model, obs, h):
|
|||
"""Forward pass of the multi-agent controller.
|
||||
|
||||
Arguments:
|
||||
model: TorchModel class
|
||||
model: TorchModelV2 class
|
||||
obs: Tensor of shape [B, n_agents, obs_size]
|
||||
h: List of tensors of shape [B, n_agents, h_size]
|
||||
|
||||
|
@ -436,6 +446,6 @@ def _mac(model, obs, h):
|
|||
B, n_agents = obs.size(0), obs.size(1)
|
||||
obs_flat = obs.reshape([B * n_agents, -1])
|
||||
h_flat = [s.reshape([B * n_agents, -1]) for s in h]
|
||||
q_flat, _, _, h_flat = model.forward({"obs": obs_flat}, h_flat)
|
||||
q_flat, h_flat = model({"obs": obs_flat}, h_flat, None)
|
||||
return q_flat.reshape(
|
||||
[B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
|
||||
|
|
|
@ -15,9 +15,9 @@ parser.add_argument("--iters", type=int, default=200)
|
|||
|
||||
|
||||
def policy_gradient_loss(policy, batch_tensors):
|
||||
logits, _, values, _ = policy.model({
|
||||
logits, _ = policy.model({
|
||||
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
|
||||
}, [])
|
||||
})
|
||||
action_dist = policy.dist_class(logits)
|
||||
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
|
||||
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)
|
||||
|
|
|
@ -208,7 +208,7 @@ class ModelCatalog(object):
|
|||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
framework="tf",
|
||||
framework,
|
||||
name=None,
|
||||
model_interface=None,
|
||||
default_model=None,
|
||||
|
@ -240,29 +240,37 @@ class ModelCatalog(object):
|
|||
model_interface):
|
||||
raise ValueError("The given model must subclass",
|
||||
model_interface)
|
||||
created = set()
|
||||
|
||||
# Track and warn if variables were created but no registered
|
||||
def track_var_creation(next_creator, **kw):
|
||||
v = next_creator(**kw)
|
||||
created.add(v)
|
||||
return v
|
||||
if framework == "tf":
|
||||
created = set()
|
||||
|
||||
with tf.variable_creator_scope(track_var_creation):
|
||||
# Track and warn if vars were created but not registered
|
||||
def track_var_creation(next_creator, **kw):
|
||||
v = next_creator(**kw)
|
||||
created.add(v)
|
||||
return v
|
||||
|
||||
with tf.variable_creator_scope(track_var_creation):
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config, name,
|
||||
**model_kwargs)
|
||||
registered = set(instance.variables())
|
||||
not_registered = set()
|
||||
for var in created:
|
||||
if var not in registered:
|
||||
not_registered.add(var)
|
||||
if not_registered:
|
||||
raise ValueError(
|
||||
"It looks like variables {} were created as part "
|
||||
"of {} but does not appear in model.variables() "
|
||||
"({}). Did you forget to call "
|
||||
"model.register_variables() on the variables in "
|
||||
"question?".format(not_registered, instance,
|
||||
registered))
|
||||
else:
|
||||
# no variable tracking
|
||||
instance = model_cls(obs_space, action_space, num_outputs,
|
||||
model_config, name, **model_kwargs)
|
||||
registered = set(instance.variables())
|
||||
not_registered = set()
|
||||
for var in created:
|
||||
if var not in registered:
|
||||
not_registered.add(var)
|
||||
if not_registered:
|
||||
raise ValueError(
|
||||
"It looks like variables {} were created as part of "
|
||||
"{} but does not appear in model.variables() ({}). "
|
||||
"Did you forget to call model.register_variables() "
|
||||
"on the variables in question?".format(
|
||||
not_registered, instance, registered))
|
||||
return instance
|
||||
|
||||
if framework == "tf":
|
||||
|
@ -271,8 +279,15 @@ class ModelCatalog(object):
|
|||
make_v1_wrapper(legacy_model_cls), model_interface)
|
||||
return wrapper(obs_space, action_space, num_outputs, model_config,
|
||||
name, **model_kwargs)
|
||||
|
||||
raise NotImplementedError("TODO: support {} models".format(framework))
|
||||
elif framework == "torch":
|
||||
if default_model:
|
||||
return default_model(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
return ModelCatalog._get_default_torch_model_v2(
|
||||
obs_space, action_space, num_outputs, model_config, name)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Framework must be 'tf' or 'torch': {}".format(framework))
|
||||
|
||||
@staticmethod
|
||||
def _wrap_if_needed(model_cls, model_interface):
|
||||
|
@ -367,46 +382,32 @@ class ModelCatalog(object):
|
|||
num_outputs,
|
||||
options=None,
|
||||
default_model_cls=None):
|
||||
"""Returns a custom model for PyTorch algorithms.
|
||||
raise DeprecationWarning("Please use get_model_v2() instead.")
|
||||
|
||||
Args:
|
||||
obs_space (Space): The input observation space.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
default_model_cls (cls): Optional class to use if no custom model.
|
||||
|
||||
Returns:
|
||||
model (models.Model): Neural network model.
|
||||
"""
|
||||
def _get_default_torch_model_v2(obs_space, action_space, num_outputs,
|
||||
model_config, name):
|
||||
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
|
||||
PyTorchFCNet)
|
||||
from ray.rllib.models.torch.visionnet import (VisionNetwork as
|
||||
PyTorchVisionNet)
|
||||
|
||||
options = options or MODEL_DEFAULTS
|
||||
model_config = model_config or MODEL_DEFAULTS
|
||||
|
||||
if options.get("custom_model"):
|
||||
model = options["custom_model"]
|
||||
logger.debug("Using custom torch model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL,
|
||||
model)(obs_space, num_outputs, options)
|
||||
|
||||
if options.get("use_lstm"):
|
||||
if model_config.get("use_lstm"):
|
||||
raise NotImplementedError(
|
||||
"LSTM auto-wrapping not implemented for torch")
|
||||
|
||||
if default_model_cls:
|
||||
return default_model_cls(obs_space, num_outputs, options)
|
||||
|
||||
if isinstance(obs_space, gym.spaces.Discrete):
|
||||
obs_rank = 1
|
||||
else:
|
||||
obs_rank = len(obs_space.shape)
|
||||
|
||||
if obs_rank > 1:
|
||||
return PyTorchVisionNet(obs_space, num_outputs, options)
|
||||
return PyTorchVisionNet(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
return PyTorchFCNet(obs_space, num_outputs, options)
|
||||
return PyTorchFCNet(obs_space, action_space, num_outputs, model_config,
|
||||
name)
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
|
|
|
@ -10,6 +10,7 @@ from ray.rllib.utils import try_import_tf
|
|||
tf = try_import_tf()
|
||||
|
||||
|
||||
# TODO(ekl) rewrite this using ModelV2
|
||||
class FullyConnectedNetwork(Model):
|
||||
"""Generic fully connected network."""
|
||||
|
||||
|
|
|
@ -14,10 +14,12 @@ from ray.rllib.utils import try_import_tf
|
|||
tf = try_import_tf()
|
||||
|
||||
|
||||
@PublicAPI
|
||||
# Deprecated: use TFModelV2 instead
|
||||
class Model(object):
|
||||
"""Defines an abstract network model for use with RLlib.
|
||||
|
||||
This class is deprecated: please use TFModelV2 instead.
|
||||
|
||||
Models convert input tensors to a number of output features. These features
|
||||
can then be interpreted by ActionDistribution classes to determine
|
||||
e.g. agent action values.
|
||||
|
|
|
@ -3,8 +3,10 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class ModelV2(object):
|
||||
"""Defines a Keras-style abstract network model for use with RLlib.
|
||||
|
||||
|
@ -39,7 +41,6 @@ class ModelV2(object):
|
|||
self.model_config = model_config
|
||||
self.name = name or "default_model"
|
||||
self.framework = framework
|
||||
self.var_list = []
|
||||
|
||||
def get_initial_state(self):
|
||||
"""Get the initial recurrent state values for the model.
|
||||
|
@ -118,19 +119,7 @@ class ModelV2(object):
|
|||
"""
|
||||
return {}
|
||||
|
||||
def register_variables(self, variables):
|
||||
"""Register the given list of variables with this model."""
|
||||
self.var_list.extend(variables)
|
||||
|
||||
def variables(self):
|
||||
"""Returns the list of variables for this model."""
|
||||
return list(self.var_list)
|
||||
|
||||
def trainable_variables(self):
|
||||
"""Returns the list of trainable variables for this model."""
|
||||
return [v for v in self.variables() if v.trainable]
|
||||
|
||||
def __call__(self, input_dict, state, seq_lens):
|
||||
def __call__(self, input_dict, state=None, seq_lens=None):
|
||||
"""Call the model with the given input tensors and state.
|
||||
|
||||
This is the method used by RLlib to execute the forward pass. It calls
|
||||
|
@ -156,7 +145,7 @@ class ModelV2(object):
|
|||
restored["obs"] = restore_original_dimensions(
|
||||
input_dict["obs"], self.obs_space, self.framework)
|
||||
restored["obs_flat"] = input_dict["obs"]
|
||||
outputs, state = self.forward(restored, state, seq_lens)
|
||||
outputs, state = self.forward(restored, state or [], seq_lens)
|
||||
|
||||
try:
|
||||
shape = outputs.shape
|
||||
|
|
|
@ -98,7 +98,7 @@ def make_v1_wrapper(legacy_model_cls):
|
|||
"Cannot get update ops before wrapped v1 model init")
|
||||
return list(self._update_ops)
|
||||
|
||||
@override(ModelV2)
|
||||
@override(TFModelV2)
|
||||
def variables(self):
|
||||
var_list = super(ModelV1Wrapper, self).variables()
|
||||
for v in scope_vars(self.variable_scope):
|
||||
|
|
|
@ -3,13 +3,18 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class TFModelV2(ModelV2):
|
||||
"""TF version of ModelV2."""
|
||||
"""TF version of ModelV2.
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
|
@ -21,9 +26,22 @@ class TFModelV2(ModelV2):
|
|||
model_config,
|
||||
name,
|
||||
framework="tf")
|
||||
self.var_list = []
|
||||
|
||||
def update_ops(self):
|
||||
"""Return the list of update ops for this model.
|
||||
|
||||
For example, this should include any BatchNorm update ops."""
|
||||
return []
|
||||
|
||||
def register_variables(self, variables):
|
||||
"""Register the given list of variables with this model."""
|
||||
self.var_list.extend(variables)
|
||||
|
||||
def variables(self):
|
||||
"""Returns the list of variables for this model."""
|
||||
return list(self.var_list)
|
||||
|
||||
def trainable_variables(self):
|
||||
"""Returns the list of trainable variables for this model."""
|
||||
return [v for v in self.variables() if v.trainable]
|
||||
|
|
|
@ -6,7 +6,7 @@ import logging
|
|||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.torch.model import TorchModel
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.misc import normc_initializer, SlimFC, \
|
||||
_get_activation_fn
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@ -14,13 +14,16 @@ from ray.rllib.utils.annotations import override
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FullyConnectedNetwork(TorchModel):
|
||||
class FullyConnectedNetwork(TorchModelV2):
|
||||
"""Generic fully connected network."""
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
hiddens = options.get("fcnet_hiddens")
|
||||
activation = _get_activation_fn(options.get("fcnet_activation"))
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super(FullyConnectedNetwork, self).__init__(
|
||||
obs_space, action_space, num_outputs, model_config, name)
|
||||
|
||||
hiddens = model_config.get("fcnet_hiddens")
|
||||
activation = _get_activation_fn(model_config.get("fcnet_activation"))
|
||||
logger.debug("Constructing fcnet {} {}".format(hiddens, activation))
|
||||
layers = []
|
||||
last_layer_size = np.product(obs_space.shape)
|
||||
|
@ -45,13 +48,17 @@ class FullyConnectedNetwork(TorchModel):
|
|||
out_size=1,
|
||||
initializer=normc_initializer(1.0),
|
||||
activation_fn=None)
|
||||
self._cur_value = None
|
||||
|
||||
@override(nn.Module)
|
||||
def forward(self, input_dict, hidden_state):
|
||||
# Note that we override forward() and not _forward() to get the
|
||||
# flattened obs here.
|
||||
obs = input_dict["obs"]
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
obs = input_dict["obs_flat"]
|
||||
features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
|
||||
logits = self._logits(features)
|
||||
value = self._value_branch(features).squeeze(1)
|
||||
return logits, features, value, hidden_state
|
||||
self._cur_value = self._value_branch(features).squeeze(1)
|
||||
return logits, state
|
||||
|
||||
@override(TorchModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
# TODO(ekl) rewrite using modelv2
|
||||
@PublicAPI
|
||||
class TorchModel(nn.Module):
|
||||
"""Defines an abstract network model for use with RLlib / PyTorch."""
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
"""All custom RLlib torch models must support this constructor.
|
||||
|
||||
Arguments:
|
||||
obs_space (gym.Space): Input observation space.
|
||||
num_outputs (int): Output tensor must be of size
|
||||
[BATCH_SIZE, num_outputs].
|
||||
options (dict): Dictionary of model options.
|
||||
"""
|
||||
nn.Module.__init__(self)
|
||||
self.obs_space = obs_space
|
||||
self.num_outputs = num_outputs
|
||||
self.options = options
|
||||
|
||||
@PublicAPI
|
||||
def forward(self, input_dict, hidden_state):
|
||||
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
|
||||
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
|
||||
input_dict["obs_flat"] = input_dict["obs"]
|
||||
input_dict["obs"] = restore_original_dimensions(
|
||||
input_dict["obs"], self.obs_space, tensorlib=torch)
|
||||
outputs, features, vf, h = self._forward(input_dict, hidden_state)
|
||||
return outputs, features, vf, h
|
||||
|
||||
@PublicAPI
|
||||
def state_init(self):
|
||||
"""Returns a list of initial hidden state tensors, if any."""
|
||||
return []
|
||||
|
||||
@PublicAPI
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
"""Forward pass for the model.
|
||||
|
||||
Prefer implementing this instead of forward() directly for proper
|
||||
handling of Dict and Tuple observations.
|
||||
|
||||
Arguments:
|
||||
input_dict (dict): Dictionary of tensor inputs, commonly
|
||||
including "obs", "prev_action", "prev_reward", each of shape
|
||||
[BATCH_SIZE, ...].
|
||||
hidden_state (list): List of hidden state tensors, each of shape
|
||||
[BATCH_SIZE, h_size].
|
||||
|
||||
Returns:
|
||||
(outputs, feature_layer, values, state): Tensors of size
|
||||
[BATCH_SIZE, num_outputs], [BATCH_SIZE, desired_feature_size],
|
||||
[BATCH_SIZE], and [len(hidden_state), BATCH_SIZE, h_size].
|
||||
"""
|
||||
raise NotImplementedError
|
|
@ -2,19 +2,27 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
class TorchModelV2(ModelV2):
|
||||
"""Torch version of ModelV2."""
|
||||
@PublicAPI
|
||||
class TorchModelV2(ModelV2, nn.Module):
|
||||
"""Torch version of ModelV2.
|
||||
|
||||
def __init__(self, obs_space, action_space, output_spec, model_config,
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
ModelV2.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
output_spec,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
framework="torch")
|
||||
nn.Module.__init__(self)
|
||||
|
|
|
@ -4,19 +4,22 @@ from __future__ import print_function
|
|||
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.torch.model import TorchModel
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.models.torch.misc import normc_initializer, valid_padding, \
|
||||
SlimConv2d, SlimFC
|
||||
from ray.rllib.models.visionnet import _get_filter_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class VisionNetwork(TorchModel):
|
||||
class VisionNetwork(TorchModelV2):
|
||||
"""Generic vision network."""
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
filters = options.get("conv_filters")
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super(VisionNetwork, self).__init__(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
|
||||
filters = model_config.get("conv_filters")
|
||||
if not filters:
|
||||
filters = _get_filter_config(obs_space.shape)
|
||||
layers = []
|
||||
|
@ -40,13 +43,19 @@ class VisionNetwork(TorchModel):
|
|||
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
|
||||
self._value_branch = SlimFC(
|
||||
out_channels, 1, initializer=normc_initializer())
|
||||
self._cur_value = None
|
||||
|
||||
@override(TorchModel)
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
features = self._hidden_layers(input_dict["obs"])
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
features = self._hidden_layers(input_dict["obs"].float())
|
||||
logits = self._logits(features)
|
||||
value = self._value_branch(features).squeeze(1)
|
||||
return logits, features, value, hidden_state
|
||||
self._cur_value = self._value_branch(features).squeeze(1)
|
||||
return logits, state
|
||||
|
||||
@override(TorchModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
|
||||
def _hidden_layers(self, obs):
|
||||
res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major
|
||||
|
|
|
@ -10,6 +10,7 @@ from ray.rllib.utils import try_import_tf
|
|||
tf = try_import_tf()
|
||||
|
||||
|
||||
# TODO(ekl) rewrite this using ModelV2
|
||||
class VisionNetwork(Model):
|
||||
"""Generic vision network."""
|
||||
|
||||
|
|
|
@ -76,14 +76,14 @@ class TorchPolicy(Policy):
|
|||
input_dict["prev_actions"] = prev_action_batch
|
||||
if prev_reward_batch:
|
||||
input_dict["prev_rewards"] = prev_reward_batch
|
||||
model_out = self._model(input_dict, state_batches)
|
||||
logits, _, vf, state = model_out
|
||||
model_out = self._model(input_dict, state_batches, [1])
|
||||
logits, state = model_out
|
||||
action_dist = self._action_dist_cls(logits)
|
||||
actions = action_dist.sample()
|
||||
return (actions.cpu().numpy(),
|
||||
[h.cpu().numpy() for h in state],
|
||||
self.extra_action_out(input_dict, state_batches,
|
||||
model_out))
|
||||
self._model))
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
|
@ -145,20 +145,20 @@ class TorchPolicy(Policy):
|
|||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return [s.numpy() for s in self._model.state_init()]
|
||||
return [s.numpy() for s in self._model.get_initial_state()]
|
||||
|
||||
def extra_grad_process(self):
|
||||
"""Allow subclass to do extra processing on gradients and
|
||||
return processing info."""
|
||||
return {}
|
||||
|
||||
def extra_action_out(self, input_dict, state_batches, model_out):
|
||||
def extra_action_out(self, input_dict, state_batches, model):
|
||||
"""Returns dict of extra info to include in experience batch.
|
||||
|
||||
Arguments:
|
||||
input_dict (dict): Dict of model input tensors.
|
||||
state_batches (list): List of state tensors.
|
||||
model_out (list): Outputs of the policy model module."""
|
||||
model (TorchModelV2): Reference to the model."""
|
||||
return {}
|
||||
|
||||
def extra_grad_info(self, batch_tensors):
|
||||
|
|
|
@ -74,8 +74,12 @@ def build_torch_policy(name,
|
|||
else:
|
||||
self.dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"], torch=True)
|
||||
self.model = ModelCatalog.get_torch_model(
|
||||
obs_space, logit_dim, self.config["model"])
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
obs_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
framework="torch")
|
||||
|
||||
TorchPolicy.__init__(self, obs_space, action_space, self.model,
|
||||
loss_fn, self.dist_class)
|
||||
|
@ -101,13 +105,13 @@ def build_torch_policy(name,
|
|||
return TorchPolicy.extra_grad_process(self)
|
||||
|
||||
@override(TorchPolicy)
|
||||
def extra_action_out(self, input_dict, state_batches, model_out):
|
||||
def extra_action_out(self, input_dict, state_batches, model):
|
||||
if extra_action_out_fn:
|
||||
return extra_action_out_fn(self, input_dict, state_batches,
|
||||
model_out)
|
||||
model)
|
||||
else:
|
||||
return TorchPolicy.extra_action_out(self, input_dict,
|
||||
state_batches, model_out)
|
||||
state_batches, model)
|
||||
|
||||
@override(TorchPolicy)
|
||||
def optimizer(self):
|
||||
|
|
|
@ -19,7 +19,7 @@ from ray.rllib.env.vector_env import VectorEnv
|
|||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.torch.model import TorchModel
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.rollout import rollout
|
||||
from ray.rllib.tests.test_external_env import SimpleServing
|
||||
from ray.tune.registry import register_env
|
||||
|
@ -133,16 +133,18 @@ class InvalidModel2(Model):
|
|||
return tf.constant(0), tf.constant(0)
|
||||
|
||||
|
||||
class TorchSpyModel(TorchModel):
|
||||
class TorchSpyModel(TorchModelV2):
|
||||
capture_index = 0
|
||||
|
||||
def __init__(self, obs_space, num_outputs, options):
|
||||
TorchModel.__init__(self, obs_space, num_outputs, options)
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super(TorchSpyModel, self).__init__(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
self.fc = FullyConnectedNetwork(
|
||||
obs_space.original_space.spaces["sensors"].spaces["position"],
|
||||
num_outputs, options)
|
||||
action_space, num_outputs, model_config, name)
|
||||
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
pos = input_dict["obs"]["sensors"]["position"].numpy()
|
||||
front_cam = input_dict["obs"]["sensors"]["front_cam"][0].numpy()
|
||||
task = input_dict["obs"]["inner_state"]["job_status"]["task"].numpy()
|
||||
|
@ -153,7 +155,10 @@ class TorchSpyModel(TorchModel):
|
|||
TorchSpyModel.capture_index += 1
|
||||
return self.fc({
|
||||
"obs": input_dict["obs"]["sensors"]["position"]
|
||||
}, hidden_state)
|
||||
}, state, seq_lens)
|
||||
|
||||
def value_function(self):
|
||||
return self.fc.value_function()
|
||||
|
||||
|
||||
class DictSpyModel(Model):
|
||||
|
|
Loading…
Add table
Reference in a new issue