mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
from ray.rllib.models.utils import get_initializer
from ray.rllib.policy import Policy
from typing import List, Union
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()


@DeveloperAPI
class FQETorchModel:
    """PyTorch implementation of the Fitted Q-Evaluation (FQE) model from
    https://arxiv.org/pdf/1911.06854.pdf
    """

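    # FQE regresses a Q-network toward the bootstrapped target
    #     r + gamma * (1 - done) * sum_a pi(a | s') * Q_target(s', a),
    # where pi is the evaluation policy and Q_target is a Polyak-averaged
    # copy of the Q-network (see ``train_q`` and ``update_target``).
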
    def __init__(
        self,
        policy: Policy,
        gamma: float,
        model: ModelConfigDict = None,
        n_iters: int = 160,
        lr: float = 1e-3,
        delta: float = 1e-4,
        clip_grad_norm: float = 100.0,
        batch_size: int = 32,
        tau: float = 0.05,
    ) -> None:
        """Initializes an FQE model for the given policy.

        Args:
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            model: The ModelConfigDict for self.q_model. Defaults to
                {"fcnet_hiddens": [8, 8], "fcnet_activation": "relu",
                "vf_share_layers": True}.
            n_iters: Maximum number of training iterations to run on the
                batch.
            lr: Learning rate for the Q-function optimizer.
            delta: Early-stopping threshold; training stops once the mean
                loss of an iteration drops below this value.
            clip_grad_norm: Clip gradients to this maximum norm.
            batch_size: Minibatch size for training the Q-function.
            tau: Polyak averaging factor for the target Q-function.
        """
        self.policy = policy
        self.gamma = gamma
        self.observation_space = policy.observation_space
        self.action_space = policy.action_space

        if model is None:
            model = {
                "fcnet_hiddens": [8, 8],
                "fcnet_activation": "relu",
                "vf_share_layers": True,
            }

        self.device = self.policy.device
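        # Build the Q-network and an identically configured target network.
        # Both output one Q-value per discrete action, so a discrete action
        # space is assumed (``self.action_space.n`` is indexed below).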
        self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model,
            framework="torch",
            name="TorchQModel",
        ).to(self.device)
        self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model,
            framework="torch",
            name="TargetTorchQModel",
        ).to(self.device)
        self.n_iters = n_iters
        self.lr = lr
        self.delta = delta
        self.clip_grad_norm = clip_grad_norm
        self.batch_size = batch_size
        self.tau = tau
        self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
        initializer = get_initializer("xavier_uniform", framework="torch")
        # Hard update target
        self.update_target(tau=1.0)

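        # Weight re-initialization hook used by ``reset()``: re-draws the
        # weights of every nn.Linear layer with the Xavier-uniform initializer.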
        def f(m):
            if isinstance(m, nn.Linear):
                initializer(m.weight)

        self.initializer = f

    def reset(self) -> None:
        """Resets/Reinitializes the model weights."""
        self.q_model.apply(self.initializer)

    def train_q(self, batch: SampleBatch) -> TensorType:
        """Trains self.q_model using the FQE loss on the given batch.

        Args:
            batch: A SampleBatch of episodes to train on.

        Returns:
            A list of mean losses, one per training iteration.
        """
        losses = []
        for _ in range(self.n_iters):
            minibatch_losses = []
            batch.shuffle()
            for idx in range(0, batch.count, self.batch_size):
                minibatch = batch[idx : idx + self.batch_size]
                obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
                actions = torch.tensor(
                    minibatch[SampleBatch.ACTIONS], device=self.device
                )
                rewards = torch.tensor(
                    minibatch[SampleBatch.REWARDS], device=self.device
                )
                next_obs = torch.tensor(
                    minibatch[SampleBatch.NEXT_OBS], device=self.device
                )
                dones = torch.tensor(minibatch[SampleBatch.DONES], device=self.device)

                # Necessary if the policy uses a recurrent/attention model.
                num_state_inputs = 0
                for k in batch.keys():
                    if k.startswith("state_in_"):
                        num_state_inputs += 1
                state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

                # Compute the evaluation policy's action probabilities
                # pi(a | s') for all actions under next_obs, as in FQE.
                all_actions = torch.zeros([minibatch.count, self.policy.action_space.n])
                all_actions[:] = torch.arange(self.policy.action_space.n)
                next_action_prob = self.policy.compute_log_likelihoods(
                    actions=all_actions.T,
                    obs_batch=next_obs,
                    state_batches=[minibatch[k] for k in state_keys],
                    prev_action_batch=minibatch[SampleBatch.ACTIONS],
                    prev_reward_batch=minibatch[SampleBatch.REWARDS],
                    actions_normalized=False,
                )
                next_action_prob = (
                    torch.exp(next_action_prob.T).to(self.device).detach()
                )

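                # FQE regression step: fit Q(s, a) for the logged actions to
                # the bootstrapped target
                #     r + gamma * (1 - done) * sum_a pi(a | s') * Q_target(s', a),
                # where Q_target is the target network (updated only via
                # Polyak averaging).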
                q_values, _ = self.q_model({"obs": obs}, [], None)
                q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
                with torch.no_grad():
                    next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
                next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
                targets = rewards + ~dones * self.gamma * next_v
                loss = (targets - q_acts) ** 2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(
                    self.q_model.variables(), self.clip_grad_norm
                )
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.delta:
                break
            self.update_target()
        return losses

    def estimate_q(
        self,
        obs: Union[TensorType, List[TensorType]],
        actions: Union[TensorType, List[TensorType]] = None,
    ) -> TensorType:
        """Computes Q-values for the given observations.

        If `actions` is None, returns the Q-values of all actions in the
        action space for each observation in `obs`; otherwise returns only
        the Q-values of the provided actions.
        """
        obs = torch.tensor(obs, device=self.device)
        q_values, _ = self.q_model({"obs": obs}, [], None)
        if actions is not None:
            actions = torch.tensor(actions, device=self.device, dtype=int)
            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
        return q_values.detach()

    def estimate_v(
        self,
        obs: Union[TensorType, List[TensorType]],
        action_probs: Union[TensorType, List[TensorType]],
    ) -> TensorType:
        """Computes state values V(s) for the given observations.

        For each observation s in `obs`, computes Q-values for all actions
        in the action space and weights them by `action_probs`, the policy's
        action distribution at s, giving V(s) = sum_a pi(a|s) Q(s, a).
        """
        q_values = self.estimate_q(obs)
        action_probs = torch.tensor(action_probs, device=self.device)
        v_values = torch.sum(q_values * action_probs, axis=-1)
        return v_values.detach()

    def update_target(self, tau=None):
        # update_target is called periodically to copy the Q-network weights
        # into the target Q-network, using (soft) tau-syncing.
        tau = self.tau if tau is None else tau
        model_state_dict = self.q_model.state_dict()
        # Support partial (soft) syncing.
        # If tau == 1.0: Full sync from Q-model to target Q-model.
        target_state_dict = self.target_q_model.state_dict()
        model_state_dict = {
            k: tau * model_state_dict[k] + (1 - tau) * v
            for k, v in target_state_dict.items()
        }

        self.target_q_model.load_state_dict(model_state_dict)