ray/rllib/offline/estimators/fqe_torch_model.py

from typing import Optional

from gym.spaces import Discrete

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_initializer
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI, is_overridden
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()


# TODO: Create a config object for FQE and unify it with the RLModule API
@DeveloperAPI
class FQETorchModel:
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
https://arxiv.org/abs/1911.06854
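
    Example (a minimal sketch, assuming an existing RLlib torch `policy` and a
    SampleBatch `batch` of offline episodes):

        fqe = FQETorchModel(policy=policy, gamma=0.99)
        losses = fqe.train(batch)
        v_values = fqe.estimate_v(batch)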
"""
def __init__(
self,
policy: Policy,
gamma: float,
        model: Optional[ModelConfigDict] = None,
n_iters: int = 1,
lr: float = 1e-3,
min_loss_threshold: float = 1e-4,
clip_grad_norm: float = 100.0,
        minibatch_size: Optional[int] = None,
polyak_coef: float = 1.0,
) -> None:
"""
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
model: The ModelConfigDict for self.q_model, defaults to:
{
"fcnet_hiddens": [8, 8],
"fcnet_activation": "relu",
"vf_share_layers": True,
},
            n_iters: Number of gradient steps to run on the batch, defaults to 1.
            lr: Learning rate for the Adam optimizer.
            min_loss_threshold: Stop training early once the mean loss over an
                iteration drops below this threshold.
            clip_grad_norm: Clip the loss gradients to this maximum norm.
            minibatch_size: Minibatch size for training the Q-function;
                if None, train on the whole batch.
            polyak_coef: Polyak averaging factor for the target Q-function.
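
        Example (sketch; values are illustrative only):
            FQETorchModel(
                policy=policy,
                gamma=0.99,
                model={"fcnet_hiddens": [64, 64], "fcnet_activation": "relu"},
                n_iters=160,
                minibatch_size=128,
            )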
"""
self.policy = policy
assert isinstance(
policy.action_space, Discrete
), f"{self.__class__.__name__} only supports discrete action spaces!"
self.gamma = gamma
self.observation_space = policy.observation_space
self.action_space = policy.action_space
if model is None:
model = {
"fcnet_hiddens": [32, 32, 32],
"fcnet_activation": "relu",
"vf_share_layers": True,
}
self.device = self.policy.device
self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
self.action_space.n,
model,
framework="torch",
name="TorchQModel",
).to(self.device)
self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
self.action_space.n,
model,
framework="torch",
name="TargetTorchQModel",
).to(self.device)
self.n_iters = n_iters
self.lr = lr
self.min_loss_threshold = min_loss_threshold
self.clip_grad_norm = clip_grad_norm
self.minibatch_size = minibatch_size
self.polyak_coef = polyak_coef
self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
initializer = get_initializer("xavier_uniform", framework="torch")
# Hard update target
        self.update_target(polyak_coef=1.0)

        # Xavier-uniform (re-)initializer for nn.Linear layers; stored on the
        # instance as self.initializer for callers that want to reset weights.
        def f(m):
            if isinstance(m, nn.Linear):
                initializer(m.weight)

        self.initializer = f

    def train(self, batch: SampleBatch) -> TensorType:
"""Trains self.q_model using FQE loss on given batch.
Args:
batch: A SampleBatch of episodes to train on
Returns:
A list of losses for each training iteration
"""
losses = []
minibatch_size = self.minibatch_size or batch.count
# Copy batch for shuffling
batch = batch.copy(shallow=True)
for _ in range(self.n_iters):
minibatch_losses = []
batch.shuffle()
for idx in range(0, batch.count, minibatch_size):
minibatch = batch[idx : idx + minibatch_size]
obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
actions = torch.tensor(
minibatch[SampleBatch.ACTIONS],
device=self.device,
dtype=int,
)
rewards = torch.tensor(
minibatch[SampleBatch.REWARDS], device=self.device
)
next_obs = torch.tensor(
minibatch[SampleBatch.NEXT_OBS], device=self.device
)
dones = torch.tensor(
minibatch[SampleBatch.DONES], device=self.device, dtype=float
)
# Compute Q-values for current obs
q_values, _ = self.q_model({"obs": obs}, [], None)
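                # Gather Q(s, a) for the actions actually taken in the batch.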
q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
next_action_probs = self._compute_action_probs(next_obs)
# Compute Q-values for next obs
with torch.no_grad():
next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
# Compute estimated state value next_v = E_{a ~ pi(s)} [Q(next_obs,a)]
next_v = torch.sum(next_q_values * next_action_probs, axis=-1)
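                # FQE Bellman target: r + gamma * (1 - done) * E_{a ~ pi(s')}[Q_target(s', a)]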
targets = rewards + (1 - dones) * self.gamma * next_v
loss = (targets - q_acts) ** 2
loss = torch.mean(loss)
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad.clip_grad_norm_(
self.q_model.variables(), self.clip_grad_norm
)
self.optimizer.step()
minibatch_losses.append(loss.item())
iter_loss = sum(minibatch_losses) / len(minibatch_losses)
losses.append(iter_loss)
if iter_loss < self.min_loss_threshold:
break
self.update_target()
        return losses

    def estimate_q(self, batch: SampleBatch) -> TensorType:
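        """Computes Q(obs, action) for each timestep in the given batch.

        Args:
            batch: A SampleBatch with OBS and ACTIONS columns.

        Returns:
            A tensor of Q-value estimates of shape (batch_size,).
        """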
obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
with torch.no_grad():
q_values, _ = self.q_model({"obs": obs}, [], None)
actions = torch.tensor(
batch[SampleBatch.ACTIONS], device=self.device, dtype=int
)
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
        return q_values

    def estimate_v(self, batch: SampleBatch) -> TensorType:
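        """Computes V(obs) = E_{a ~ pi(obs)}[Q(obs, a)] for each timestep.

        Args:
            batch: A SampleBatch with an OBS column.

        Returns:
            A tensor of state-value estimates of shape (batch_size,).
        """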
obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
with torch.no_grad():
q_values, _ = self.q_model({"obs": obs}, [], None)
# Compute pi(a | s) for each action a in policy.action_space
action_probs = self._compute_action_probs(obs)
v_values = torch.sum(q_values * action_probs, axis=-1)
        return v_values

    def update_target(self, polyak_coef: Optional[float] = None):
        # Called after each training pass to sync the Q-network into the
        # target Q-network, using (soft) polyak_coef-weighted averaging.
        polyak_coef = self.polyak_coef if polyak_coef is None else polyak_coef
model_state_dict = self.q_model.state_dict()
# Support partial (soft) synching.
# If polyak_coef == 1.0: Full sync from Q-model to target Q-model.
target_state_dict = self.target_q_model.state_dict()
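        # new_target_param = polyak_coef * q_param + (1 - polyak_coef) * target_param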
model_state_dict = {
k: polyak_coef * model_state_dict[k] + (1 - polyak_coef) * v
for k, v in target_state_dict.items()
}
        self.target_q_model.load_state_dict(model_state_dict)

    def _compute_action_probs(self, obs: TensorType) -> TensorType:
"""Compute action distribution over the action space.
Args:
obs: A tensor of observations of shape (batch_size * obs_dim)
Returns:
action_probs: A tensor of action probabilities
of shape (batch_size * action_dim)
"""
input_dict = {SampleBatch.OBS: obs}
seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
state_batches = []
if is_overridden(self.policy.action_distribution_fn):
try:
# TorchPolicyV2 function signature
dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
self.policy.model,
obs_batch=input_dict,
state_batches=state_batches,
seq_lens=seq_lens,
explore=False,
is_training=False,
)
except TypeError:
# TorchPolicyV1 function signature for compatibility with DQN
# TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
self.policy,
self.policy.model,
input_dict=input_dict,
state_batches=state_batches,
seq_lens=seq_lens,
explore=False,
is_training=False,
)
else:
dist_class = self.policy.dist_class
dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
action_dist = dist_class(dist_inputs, self.policy.model)
assert isinstance(
action_dist.dist, torch.distributions.categorical.Categorical
), "FQE only supports Categorical or MultiCategorical distributions!"
action_probs = action_dist.dist.probs
return action_probs