2019-12-07 21:08:40 +01:00
|
|
|
from abc import ABC
|
|
|
|
import numpy as np
|
2019-12-30 15:27:32 -05:00
|
|
|
|
2020-09-06 10:58:00 +02:00
|
|
|
from ray.rllib.models.modelv2 import restore_original_dimensions
|
2019-12-07 21:08:40 +01:00
|
|
|
from ray.rllib.models.preprocessors import get_preprocessor
|
|
|
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
2020-06-16 08:52:20 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
torch, nn = try_import_torch()
|
2019-12-07 21:08:40 +01:00
|
|
|
|
|
|
|
|
|
|
|
def convert_to_tensor(arr):
    """Convert an array-like to a torch tensor, downcasting doubles.

    NumPy defaults to float64, but the models in this file work in
    single precision, so double tensors are cast to float32.
    """
    result = torch.from_numpy(np.asarray(arr))
    return result.float() if result.dtype == torch.double else result
|
|
|
|
|
|
|
|
|
|
|
|
class ActorCriticModel(TorchModelV2, nn.Module, ABC):
    """Abstract torch model with a shared trunk, a policy head and a value head.

    Concrete subclasses must populate ``shared_layers``, ``actor_layers``
    and ``critic_layers`` in their constructor.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Preprocessor matching the env's original observation space; used by
        # compute_priors_and_value() to flatten raw observations.
        self.preprocessor = get_preprocessor(obs_space.original_space)(
            obs_space.original_space)

        # Populated by concrete subclasses.
        self.shared_layers = None
        self.actor_layers = None
        self.critic_layers = None

        # Critic output cached by the most recent forward() call.
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        features = self.shared_layers(input_dict["obs"])

        # Policy head: action logits.
        action_logits = self.actor_layers(features)

        # Value head; cached so value_function() can return it afterwards.
        self._value_out = self.critic_layers(features)
        return action_logits, None

    def value_function(self):
        """Return the critic output cached by the last forward() call."""
        return self._value_out

    def compute_priors_and_value(self, obs):
        """Return softmax action priors and value estimate for a single obs."""
        obs = convert_to_tensor([self.preprocessor.transform(obs)])
        input_dict = restore_original_dimensions(obs, self.obs_space, "torch")

        with torch.no_grad():
            logits, _ = self.forward(input_dict, None, [1])
            value = self.value_function()
            logits, value = torch.squeeze(logits), torch.squeeze(value)
            priors = nn.Softmax(dim=-1)(logits)
            # Move to host memory as plain numpy arrays for the caller.
            return priors.cpu().numpy(), value.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
class Flatten(nn.Module):
    """Flatten every dimension of the input except the batch dimension.

    Equivalent to ``x.view(batch, -1)``; used to bridge conv feature maps
    into linear layers.
    """

    def forward(self, x):
        # Parameter renamed from ``input`` to avoid shadowing the builtin.
        return x.view(x.size(0), -1)
|
|
|
|
|
|
|
|
|
|
|
|
class ConvNetModel(ActorCriticModel):
    """Convolutional actor-critic network for image observations."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        ActorCriticModel.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)

        custom_config = model_config["custom_model_config"]
        in_channels = custom_config["in_channels"]
        feature_dim = custom_config["feature_dim"]

        # Conv trunk shared by both heads; Flatten + Linear projects the
        # final feature map to ``feature_dim``.
        # NOTE(review): Linear(1024, ...) assumes a fixed observation
        # resolution -- confirm against the env's obs shape.
        self.shared_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2),
            nn.Conv2d(32, 64, kernel_size=2, stride=1),
            nn.Conv2d(64, 64, kernel_size=2, stride=1),
            Flatten(),
            nn.Linear(1024, feature_dim),
        )

        # Policy head: one logit per discrete action.
        self.actor_layers = nn.Sequential(
            nn.Linear(in_features=feature_dim, out_features=action_space.n))

        # Value head: scalar state-value estimate.
        self.critic_layers = nn.Sequential(
            nn.Linear(in_features=feature_dim, out_features=1))

        self._value_out = None
|
|
|
|
|
|
|
|
|
|
|
|
class DenseModel(ActorCriticModel):
    """Fully-connected actor-critic network for flat observations."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        ActorCriticModel.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)

        obs_dim = obs_space.original_space["obs"].shape[0]

        # Two-layer MLP trunk shared by both heads.
        self.shared_layers = nn.Sequential(
            nn.Linear(in_features=obs_dim, out_features=256),
            nn.Linear(in_features=256, out_features=256),
        )

        # Policy head: one logit per discrete action.
        self.actor_layers = nn.Sequential(
            nn.Linear(in_features=256, out_features=action_space.n))

        # Value head: scalar state-value estimate.
        self.critic_layers = nn.Sequential(
            nn.Linear(in_features=256, out_features=1))

        self._value_out = None
|