ray/rllib/contrib/alpha_zero/models/custom_torch_models.py

from abc import ABC

import numpy as np

from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


def convert_to_tensor(arr):
    """Convert an array-like to a torch tensor, downcasting float64 to float32."""
    tensor = torch.from_numpy(np.asarray(arr))
    if tensor.dtype == torch.double:
        tensor = tensor.float()
    return tensor
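
# A minimal usage sketch (torch layers default to float32 weights, which is
# why the float64 -> float32 downcast above matters):
#
#     >>> convert_to_tensor([[0.0, 1.0]]).dtype
#     torch.float32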


class ActorCriticModel(TorchModelV2, nn.Module, ABC):
    """Base class for AlphaZero actor-critic models; subclasses populate
    ``shared_layers``, ``actor_layers`` and ``critic_layers``."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Flattens raw observations the way RLlib does internally; used by
        # compute_priors_and_value().
        self.preprocessor = get_preprocessor(obs_space.original_space)(
            obs_space.original_space)

        self.shared_layers = None
        self.actor_layers = None
        self.critic_layers = None

        self._value_out = None
    def forward(self, input_dict, state, seq_lens):
        x = input_dict["obs"]
        x = self.shared_layers(x)
        # Actor head: unnormalized action logits.
        logits = self.actor_layers(x)
        # Critic head: cache the value estimate for value_function().
        self._value_out = self.critic_layers(x)
        # No recurrent state is used; pass any state through unchanged.
        return logits, state
    def value_function(self):
        return self._value_out
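
    # ModelV2 contract (for reference, not specific to this file):
    # value_function() returns the critic output cached by the most recent
    # forward() call, e.g.:
    #
    #     logits, _ = model.forward(input_dict, None, None)
    #     value = model.value_function()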
    def compute_priors_and_value(self, obs):
        """Return (priors, value) as numpy arrays for one raw observation."""
        obs = convert_to_tensor([self.preprocessor.transform(obs)])
        input_dict = restore_original_dimensions(obs, self.obs_space, "torch")

        with torch.no_grad():
            model_out = self.forward(input_dict, None, [1])
            logits, _ = model_out
            value = self.value_function()
            logits, value = torch.squeeze(logits), torch.squeeze(value)
            priors = nn.Softmax(dim=-1)(logits)

            priors = priors.cpu().numpy()
            value = value.cpu().numpy()

            return priors, value
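
# compute_priors_and_value() is the hook the AlphaZero tree search calls when
# it expands a leaf: one raw environment observation in, numpy (priors, value)
# out. A minimal call sketch:
#
#     priors, value = model.compute_priors_and_value(env_obs)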


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
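
# Note: this is equivalent to ``nn.Flatten()`` (flatten everything after the
# batch dimension), which newer PyTorch releases provide out of the box; the
# hand-rolled module keeps compatibility with older torch versions.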


class ConvNetModel(ActorCriticModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        ActorCriticModel.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)

        in_channels = model_config["custom_model_config"]["in_channels"]
        feature_dim = model_config["custom_model_config"]["feature_dim"]

        # ReLU nonlinearities are inserted between the layers; without them
        # the stacked convolutions would compose into a single linear map.
        self.shared_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1),
            nn.ReLU(),
            Flatten(),
            # 1024 = 64 channels * 4 * 4 spatial positions; this hard-coded
            # size only matches inputs whose conv output is 4x4.
            nn.Linear(1024, feature_dim),
        )
        self.actor_layers = nn.Sequential(
            nn.Linear(in_features=feature_dim, out_features=action_space.n))
        self.critic_layers = nn.Sequential(
            nn.Linear(in_features=feature_dim, out_features=1))

        self._value_out = None
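
# Shape check for the hard-coded flatten size (a sketch; the 14x14 input is an
# illustrative resolution, not one this file prescribes):
#   conv(k=4, s=2): 14 -> 6,  conv(k=2, s=1): 6 -> 5,  conv(k=2, s=1): 5 -> 4
#   => 64 * 4 * 4 = 1024 features into the final Linear layer.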


class DenseModel(ActorCriticModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        ActorCriticModel.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)

        # The wrapped observation space is a Dict whose "obs" entry holds the
        # flat feature vector.
        obs_len = obs_space.original_space["obs"].shape[0]

        # ReLU nonlinearities follow each hidden layer; without them the two
        # stacked Linear layers would collapse into one linear map.
        self.shared_layers = nn.Sequential(
            nn.Linear(in_features=obs_len, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=256),
            nn.ReLU(),
        )
        self.actor_layers = nn.Sequential(
            nn.Linear(in_features=256, out_features=action_space.n))
        self.critic_layers = nn.Sequential(
            nn.Linear(in_features=256, out_features=1))

        self._value_out = None
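
# Example wiring (a sketch, not part of this file): RLlib looks custom models
# up by a registered name, and the trainer config selects one. The registered
# names and the custom_model_config values below are illustrative only.
#
#     from ray.rllib.models.catalog import ModelCatalog
#
#     ModelCatalog.register_custom_model("dense_model", DenseModel)
#     ModelCatalog.register_custom_model("conv_model", ConvNetModel)
#
#     config = {
#         "model": {
#             "custom_model": "conv_model",
#             "custom_model_config": {"in_channels": 4, "feature_dim": 256},
#         },
#     }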