2020-05-08 08:20:18 +02:00
|
|
|
from gym.spaces import Box
|
|
|
|
|
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
|
|
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
2020-05-18 17:26:40 +02:00
|
|
|
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
2020-05-08 08:20:18 +02:00
|
|
|
from ray.rllib.models.torch.misc import SlimFC
|
|
|
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|
|
|
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
|
|
|
|
from ray.rllib.utils.annotations import override
|
|
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
# TensorFlow handles obtained through RLlib's import helper rather than a
# plain `import tensorflow`.
# NOTE(review): presumably these are placeholders (e.g. None) when TensorFlow
# is not installed -- confirm against ray.rllib.utils.framework.
tf1, tf, tfv = try_import_tf()
|
2020-05-08 08:20:18 +02:00
|
|
|
# PyTorch handles (`torch` and `nn`) obtained through RLlib's import helper
# rather than a plain `import torch`.
# NOTE(review): presumably these are stand-ins when PyTorch is not installed
# -- confirm against ray.rllib.utils.framework.
torch, nn = try_import_torch()
|
|
|
|
|
|
|
|
|
|
|
|
class CentralizedCriticModel(TFModelV2):
    """TF multi-agent model with an additional, centralized value branch.

    Policy outputs are produced by a wrapped fully connected network. A
    separate Keras model (`self.central_vf`) maps the agent's own observation,
    the opponent's observation and the opponent's one-hot encoded action to a
    single value prediction.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Build the wrapped policy network and the centralized critic.

        All five arguments are forwarded unchanged to both the `TFModelV2`
        base class and the wrapped `FullyConnectedNetwork`.
        """
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # Wrapped network: the base of the model.
        self.model = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name
        )

        # Centralized critic: (obs, opp_obs, opp_act) -> vf_pred.
        # The input widths (6, 6, 2) are fixed here.
        # NOTE(review): presumably one-hot Discrete(6) observations and a
        # 2-way discrete opponent action -- confirm against the environment.
        keras_layers = tf.keras.layers
        own_obs_in = keras_layers.Input(shape=(6,), name="obs")
        opp_obs_in = keras_layers.Input(shape=(6,), name="opp_obs")
        opp_act_in = keras_layers.Input(shape=(2,), name="opp_act")
        critic_inputs = [own_obs_in, opp_obs_in, opp_act_in]

        joined = keras_layers.Concatenate(axis=1)(critic_inputs)
        hidden = keras_layers.Dense(16, activation=tf.nn.tanh, name="c_vf_dense")(
            joined
        )
        vf_pred = keras_layers.Dense(1, activation=None, name="c_vf_out")(hidden)

        self.central_vf = tf.keras.Model(inputs=critic_inputs, outputs=vf_pred)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Delegate the forward pass to the wrapped network."""
        wrapped = self.model
        return wrapped.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Evaluate the centralized critic and flatten its output to 1-D.

        `opponent_actions` is cast to int32 and one-hot encoded with depth 2
        before being passed to the critic alongside both observations.
        """
        opp_act_one_hot = tf.one_hot(tf.cast(opponent_actions, tf.int32), 2)
        vf_pred = self.central_vf([obs, opponent_obs, opp_act_one_hot])
        return tf.reshape(vf_pred, [-1])

    @override(ModelV2)
    def value_function(self):
        """Return the wrapped network's own value output (not used)."""
        return self.model.value_function()  # not used
|
|
|
|
|
|
|
|
|
|
|
|
class YetAnotherCentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized value function.

    It assumes the observation is a dict with 'own_obs' and 'opponent_obs', the
    former of which can be used for computing actions (i.e., decentralized
    execution), and the latter for optimization (i.e., centralized learning).

    This model has two parts:
    - An action model that looks at just 'own_obs' to compute actions
    - A value model that also looks at the 'opponent_obs' / 'opponent_action'
      to compute the value (it does this by using the 'obs_flat' tensor).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Create the action sub-model and the value sub-model.

        Args:
            obs_space: Full observation space; passed to the base class and
                used as the input space of the value model.
            action_space: Action space; passed to the base class and to both
                sub-models.
            num_outputs: Output size of the action model (the value model
                always has a single output).
            model_config: Model config; passed unchanged to the base class
                and to both sub-models.
            name: Model name; the sub-models are named `name + "_action"`
                and `name + "_vf"`.
        """
        super(YetAnotherCentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        # The action model only ever sees the agent's own observation, so it
        # gets its own (hard-coded) input space rather than `obs_space`.
        self.action_model = FullyConnectedNetwork(
            Box(low=0, high=1, shape=(6,)),  # one-hot encoded Discrete(6)
            action_space,
            num_outputs,
            model_config,
            name + "_action",
        )

        # The value model consumes the full observation and emits one value.
        self.value_model = FullyConnectedNetwork(
            obs_space, action_space, 1, model_config, name + "_vf"
        )

    # Marked as an override for consistency with the other models in this file.
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute action-model outputs and cache the value-model output.

        The value model is run on `input_dict["obs_flat"]` and its output is
        stored in `self._value_out` for a later `value_function()` call. The
        return value is whatever the action model returns for
        `input_dict["obs"]["own_obs"]`.
        """
        self._value_out, _ = self.value_model(
            {"obs": input_dict["obs_flat"]}, state, seq_lens
        )
        return self.action_model({"obs": input_dict["obs"]["own_obs"]}, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        """Return the value output cached by the last `forward()`, as 1-D.

        `forward()` must have been called before, since it is the only place
        where `self._value_out` is set.
        """
        return tf.reshape(self._value_out, [-1])
|
|
|
|
|
|
|
|
|
|
|
|
class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Torch multi-agent model with an additional, centralized value branch.

    Policy outputs come from a wrapped fully connected network. A separate
    two-layer network (`self.central_vf`) scores the concatenation of the
    agent's observation, the opponent's observation and the opponent's
    one-hot encoded action.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Build the wrapped policy network and the centralized critic.

        All five arguments are forwarded unchanged to `TorchModelV2.__init__`
        and to the wrapped `TorchFC`.
        """
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        # Wrapped network: the base of the model.
        self.model = TorchFC(obs_space, action_space, num_outputs, model_config, name)

        # Centralized critic: (obs, opp_obs, opp_act) -> vf_pred.
        own_obs_dim, opp_obs_dim, opp_act_dim = 6, 6, 2
        critic_in_dim = own_obs_dim + opp_obs_dim + opp_act_dim
        hidden_layer = SlimFC(critic_in_dim, 16, activation_fn=nn.Tanh)
        output_layer = SlimFC(16, 1)
        self.central_vf = nn.Sequential(hidden_layer, output_layer)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Run the wrapped network; return its output with an empty state list."""
        policy_out, _ = self.model(input_dict, state, seq_lens)
        return policy_out, []

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Evaluate the centralized critic and flatten its output to 1-D.

        `opponent_actions` is converted to long, one-hot encoded with two
        classes and cast to float before being concatenated (along dim 1)
        with both observations.
        """
        opp_act_one_hot = torch.nn.functional.one_hot(
            opponent_actions.long(), 2
        ).float()
        critic_in = torch.cat([obs, opponent_obs, opp_act_one_hot], 1)
        vf_pred = self.central_vf(critic_in)
        return torch.reshape(vf_pred, [-1])

    @override(ModelV2)
    def value_function(self):
        """Return the wrapped network's own value output (not used)."""
        return self.model.value_function()  # not used
|
|
|
|
|
|
|
|
|
|
|
|
class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized value function.

    It assumes the observation is a dict with 'own_obs' and 'opponent_obs', the
    former of which can be used for computing actions (i.e., decentralized
    execution), and the latter for optimization (i.e., centralized learning).

    This model has two parts:
    - An action model that looks at just 'own_obs' to compute actions
    - A value model that also looks at the 'opponent_obs' / 'opponent_action'
      to compute the value (it does this by using the 'obs_flat' tensor).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        """Create the action sub-model and the value sub-model.

        Args:
            obs_space: Full observation space; passed to the base class and
                used as the input space of the value model.
            action_space: Action space; passed to the base class and to both
                sub-models.
            num_outputs: Output size of the action model (the value model
                always has a single output).
            model_config: Model config; passed unchanged to the base class
                and to both sub-models.
            name: Model name; the sub-models are named `name + "_action"`
                and `name + "_vf"`.
        """
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        # The action model only ever sees the agent's own observation, so it
        # gets its own (hard-coded) input space rather than `obs_space`.
        self.action_model = TorchFC(
            Box(low=0, high=1, shape=(6,)),  # one-hot encoded Discrete(6)
            action_space,
            num_outputs,
            model_config,
            name + "_action",
        )

        # The value model consumes the full observation and emits one value.
        self.value_model = TorchFC(
            obs_space, action_space, 1, model_config, name + "_vf"
        )
        # Inputs of the most recent `forward()` call; None until then.
        self._model_in = None

    # Marked as an override for consistency with the other models in this file.
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute action-model outputs from the agent's own observation.

        The flattened full observation, `state` and `seq_lens` are stored so
        that `value_function()` can run the value model lazily afterwards.
        """
        # Store model-input for possible `value_function()` call.
        self._model_in = [input_dict["obs_flat"], state, seq_lens]
        return self.action_model({"obs": input_dict["obs"]["own_obs"]}, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        """Run the value model on the inputs stored by `forward()`.

        Returns:
            The value model's output reshaped to 1-D.

        Raises:
            AssertionError: If `forward()` has not been called yet, i.e. there
                are no stored inputs to evaluate.
        """
        # Without this check the failure would be an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        assert self._model_in is not None, "must call forward() first"
        obs_flat, state, seq_lens = self._model_in
        value_out, _ = self.value_model({"obs": obs_flat}, state, seq_lens)
        return torch.reshape(value_out, [-1])
|