from gym.spaces import Discrete

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_initializer
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI, is_overridden
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()

# TODO: Create a config object for FQE and unify it with the RLModule API


@DeveloperAPI
class FQETorchModel:
    """PyTorch implementation of the Fitted Q-Evaluation (FQE) model from
    https://arxiv.org/abs/1911.06854

    (See the minimal usage sketch at the bottom of this module.)
    """

    def __init__(
        self,
        policy: Policy,
        gamma: float,
        model: ModelConfigDict = None,
        n_iters: int = 1,
        lr: float = 1e-3,
        min_loss_threshold: float = 1e-4,
        clip_grad_norm: float = 100.0,
        minibatch_size: int = None,
        polyak_coef: float = 1.0,
    ) -> None:
        """
        Args:
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            model: The ModelConfigDict for self.q_model, defaults to:
                {
                    "fcnet_hiddens": [32, 32, 32],
                    "fcnet_activation": "relu",
                    "vf_share_layers": True,
                },
            n_iters: Number of training iterations (passes over the batch),
                defaults to 1.
            lr: Learning rate for the Adam optimizer.
            min_loss_threshold: Stop training early once the mean loss of an
                iteration falls below this threshold.
            clip_grad_norm: Clip loss gradients to this maximum norm.
            minibatch_size: Minibatch size for training the Q-function;
                if None, train on the whole batch.
            polyak_coef: Polyak averaging factor for the target Q-function.
        """
        self.policy = policy
        assert isinstance(
            policy.action_space, Discrete
        ), f"{self.__class__.__name__} only supports discrete action spaces!"
        self.gamma = gamma
        self.observation_space = policy.observation_space
        self.action_space = policy.action_space

        if model is None:
            model = {
                "fcnet_hiddens": [32, 32, 32],
                "fcnet_activation": "relu",
                "vf_share_layers": True,
            }

        self.device = self.policy.device
        # Q-network that is trained to regress onto the FQE targets.
        self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model,
            framework="torch",
            name="TorchQModel",
        ).to(self.device)
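
        # Target Q-network used to compute the bootstrapped FQE targets; it is
        # (soft-)synced from self.q_model via update_target().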
        self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model,
            framework="torch",
            name="TargetTorchQModel",
        ).to(self.device)

        self.n_iters = n_iters
        self.lr = lr
        self.min_loss_threshold = min_loss_threshold
        self.clip_grad_norm = clip_grad_norm
        self.minibatch_size = minibatch_size
        self.polyak_coef = polyak_coef
        self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
        initializer = get_initializer("xavier_uniform", framework="torch")
        # Hard-update the target Q-network once at construction time.
        self.update_target(polyak_coef=1.0)

        # Xavier-uniform initializer for nn.Linear layers, stored as
        # `self.initializer` (it can be applied with
        # `self.q_model.apply(self.initializer)` to re-initialize the Q-net).
        def f(m):
            if isinstance(m, nn.Linear):
                initializer(m.weight)

        self.initializer = f

    def train(self, batch: SampleBatch) -> TensorType:
        """Trains self.q_model using the FQE loss on the given batch.

        Args:
            batch: A SampleBatch of episodes to train on.

        Returns:
            A list of mean minibatch losses, one per training iteration.
        """
        losses = []
        minibatch_size = self.minibatch_size or batch.count
        # Copy batch for shuffling
        batch = batch.copy(shallow=True)
        for _ in range(self.n_iters):
            minibatch_losses = []
            batch.shuffle()
            for idx in range(0, batch.count, minibatch_size):
                minibatch = batch[idx : idx + minibatch_size]
                obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
                actions = torch.tensor(
                    minibatch[SampleBatch.ACTIONS],
                    device=self.device,
                    dtype=int,
                )
                rewards = torch.tensor(
                    minibatch[SampleBatch.REWARDS], device=self.device
                )
                next_obs = torch.tensor(
                    minibatch[SampleBatch.NEXT_OBS], device=self.device
                )
                dones = torch.tensor(
                    minibatch[SampleBatch.DONES], device=self.device, dtype=float
                )

                # Compute Q-values for the current obs.
                q_values, _ = self.q_model({"obs": obs}, [], None)
                # Q(s, a) for the actions actually taken in the batch.
                q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)

                # pi(a | s') for all actions, under the evaluation policy.
                next_action_probs = self._compute_action_probs(next_obs)

                # Compute target Q-values for the next obs (no gradient).
                with torch.no_grad():
                    next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
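
                # FQE regression target (an expected-SARSA-style backup under
                # the evaluation policy pi):
                #   target = r + gamma * (1 - done) * E_{a' ~ pi(.|s')}[Q_target(s', a')]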
                # Compute estimated state value next_v = E_{a ~ pi(.|s')}[Q_target(s', a)]
                next_v = torch.sum(next_q_values * next_action_probs, axis=-1)
                targets = rewards + (1 - dones) * self.gamma * next_v
                # Mean squared TD error between the targets and Q(s, a).
                loss = (targets - q_acts) ** 2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(
                    self.q_model.variables(), self.clip_grad_norm
                )
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.min_loss_threshold:
                break
            self.update_target()
        return losses

    def estimate_q(self, batch: SampleBatch) -> TensorType:
        """Returns Q(s, a) for the observation-action pairs in the batch."""
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        with torch.no_grad():
            q_values, _ = self.q_model({"obs": obs}, [], None)
        actions = torch.tensor(
            batch[SampleBatch.ACTIONS], device=self.device, dtype=int
        )
        q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
        return q_values

    def estimate_v(self, batch: SampleBatch) -> TensorType:
        """Returns V^pi(s) = E_{a ~ pi(.|s)}[Q(s, a)] for the observations in the batch."""
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        with torch.no_grad():
            q_values, _ = self.q_model({"obs": obs}, [], None)
        # Compute pi(a | s) for each action a in policy.action_space
        action_probs = self._compute_action_probs(obs)
        v_values = torch.sum(q_values * action_probs, axis=-1)
        return v_values

    def update_target(self, polyak_coef=None):
        # update_target() is called periodically to copy the Q-network's
        # weights to the target Q-network, using (soft) polyak_coef syncing.
        polyak_coef = self.polyak_coef if polyak_coef is None else polyak_coef
        model_state_dict = self.q_model.state_dict()
        # Support partial (soft) syncing:
        # if polyak_coef == 1.0, fully sync the Q-network to the target Q-network.
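        # The soft update applied below is:
        #   theta_target <- polyak_coef * theta_Q + (1 - polyak_coef) * theta_target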
        target_state_dict = self.target_q_model.state_dict()
        model_state_dict = {
            k: polyak_coef * model_state_dict[k] + (1 - polyak_coef) * v
            for k, v in target_state_dict.items()
        }

        self.target_q_model.load_state_dict(model_state_dict)

    def _compute_action_probs(self, obs: TensorType) -> TensorType:
        """Compute the policy's action distribution over the action space.

        Args:
            obs: A tensor of observations of shape (batch_size, obs_dim).

        Returns:
            action_probs: A tensor of action probabilities
                of shape (batch_size, action_dim).
        """
        input_dict = {SampleBatch.OBS: obs}
        seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
        state_batches = []
        if is_overridden(self.policy.action_distribution_fn):
            try:
                # TorchPolicyV2 function signature
                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
                    self.policy.model,
                    obs_batch=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            except TypeError:
                # TorchPolicyV1 function signature, for compatibility with DQN
                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
                    self.policy,
                    self.policy.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
        else:
            dist_class = self.policy.dist_class
            dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
        action_dist = dist_class(dist_inputs, self.policy.model)
        assert isinstance(
            action_dist.dist, torch.distributions.categorical.Categorical
        ), "FQE only supports Categorical or MultiCategorical distributions!"
        action_probs = action_dist.dist.probs
        return action_probs
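

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original RLlib module). It assumes a
# torch PPO policy on CartPole-v1 as the evaluation policy and a small,
# synthetic offline SampleBatch; in practice you would plug in your own trained
# policy and logged episodes instead.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    import ray
    from ray.rllib.algorithms.ppo import PPOConfig

    ray.init(ignore_reinit_error=True)

    # Build a small torch PPO algorithm just to obtain a discrete-action policy.
    algo = (
        PPOConfig()
        .environment(env="CartPole-v1")
        .framework("torch")
        .rollouts(num_rollout_workers=0)
        .build()
    )
    policy = algo.get_policy()

    # Synthetic batch with CartPole-shaped observations (4-dim) and binary
    # actions; a real use case would load logged episodes here.
    rng = np.random.default_rng(0)
    n = 256
    batch = SampleBatch(
        {
            SampleBatch.OBS: rng.normal(size=(n, 4)).astype(np.float32),
            SampleBatch.ACTIONS: rng.integers(0, 2, size=(n,)),
            SampleBatch.REWARDS: np.ones(n, dtype=np.float32),
            SampleBatch.NEXT_OBS: rng.normal(size=(n, 4)).astype(np.float32),
            SampleBatch.DONES: np.zeros(n, dtype=np.float32),
        }
    )

    # Fit the Q-function to the batch and report the FQE value estimate.
    fqe = FQETorchModel(policy=policy, gamma=0.99, n_iters=10, minibatch_size=64)
    losses = fqe.train(batch)
    print("Mean loss per iteration:", losses)
    print("Mean V^pi estimate:", fqe.estimate_v(batch).mean().item())

    algo.stop()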