from typing import Dict, List, Type, Union

import ray
from ray.rllib.algorithms.marwil.marwil_tf_policy import PostprocessAdvantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance
from ray.rllib.utils.typing import TensorType

torch, _ = try_import_torch()

class MARWILTorchPolicy(ValueNetworkMixin, PostprocessAdvantages, TorchPolicyV2):
    """PyTorch policy class used with MarwilTrainer."""

    def __init__(self, observation_space, action_space, config):
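        # Merge the user-provided config on top of MARWIL's default config.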
        config = dict(ray.rllib.algorithms.marwil.marwil.DEFAULT_CONFIG, **config)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        ValueNetworkMixin.__init__(self, config)
        PostprocessAdvantages.__init__(self)

        # Not needed for pure BC.
        if config["beta"] != 0.0:
            # Set up a torch-var for the squared moving avg. advantage norm.
            self._moving_average_sqd_adv_norm = torch.tensor(
                [config["moving_average_sqd_adv_norm_start"]],
                dtype=torch.float32,
                requires_grad=False,
            ).to(self.device)

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

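    # Loss overview (summarizing `loss()` below):
    #   policy loss: -mean(exp_advs * (log pi_theta(a|s) + bc_logstd_coeff * logstd)),
    #     where exp_advs = exp(beta * A / sqrt(moving_avg(A^2))) for beta != 0.0,
    #     or exp_advs = 1.0 for pure behavior cloning (beta == 0.0);
    #   value loss: 0.5 * mean(A^2) (0.0 for beta == 0.0);
    #   total loss: policy loss + vf_coeff * value loss.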
    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)
        actions = train_batch[SampleBatch.ACTIONS]
        # log\pi_\theta(a|s)
        logprobs = action_dist.logp(actions)

        # Advantage estimation.
        if self.config["beta"] != 0.0:
            cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
            state_values = model.value_function()
            adv = cumulative_rewards - state_values
            adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

            explained_var = explained_variance(cumulative_rewards, state_values)
            self.explained_variance = torch.mean(explained_var)

            # Policy loss.
            # Update averaged advantage norm.
            rate = self.config["moving_average_sqd_adv_norm_update_rate"]
            self._moving_average_sqd_adv_norm.add_(
                rate * (adv_squared_mean - self._moving_average_sqd_adv_norm)
            )
            # Exponentially weighted advantages.
            exp_advs = torch.exp(
                self.config["beta"]
                * (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5)))
            ).detach()
            # Value loss.
            self.v_loss = 0.5 * adv_squared_mean
        else:
            # Policy loss (simple BC loss term).
            exp_advs = 1.0
            # Value loss.
            self.v_loss = 0.0

        # A logprob loss alone tends to push action distributions toward
        # very low entropy, resulting in worse performance in
        # unfamiliar situations.
        # A scaled logstd loss term encourages stochasticity and thus
        # alleviates the problem to some extent.
        logstd_coeff = self.config["bc_logstd_coeff"]
        if logstd_coeff > 0.0:
            logstds = torch.mean(action_dist.log_std, dim=1)
        else:
            logstds = 0.0

        self.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))

        # Combine both losses.
        self.total_loss = self.p_loss + self.config["vf_coeff"] * self.v_loss

        return self.total_loss

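    # Loss-related stats reported back to the algorithm as numpy values.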
    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        stats = {
            "policy_loss": self.p_loss,
            "total_loss": self.total_loss,
        }
        if self.config["beta"] != 0.0:
            stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
            stats["vf_explained_var"] = self.explained_variance
            stats["vf_loss"] = self.v_loss
        return convert_to_numpy(stats)

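    # Applies gradient clipping after the backward pass according to
    # `config["grad_clip"]` (a no-op if `grad_clip` is not set).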
    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)
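

# Example usage (a minimal sketch, not part of this module, assuming the
# Ray 2.x `MARWILConfig` API): this policy is normally created indirectly by
# building the MARWIL algorithm with the torch framework. The offline data
# path below is a hypothetical placeholder.
#
#   from ray.rllib.algorithms.marwil import MARWILConfig
#
#   config = (
#       MARWILConfig()
#       .framework("torch")
#       .offline_data(input_="/path/to/offline_data.json")
#   )
#   algo = config.build(env="CartPole-v1")
#   print(algo.train())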