import gym
from typing import Dict

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import ValueNetworkMixin
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping, explained_variance
from ray.rllib.utils.typing import TrainerConfigDict, TensorType
from ray.rllib.policy.policy import Policy
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2

torch, _ = try_import_torch()


def marwil_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: ActionDistribution,
    train_batch: SampleBatch,
) -> TensorType:
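    """Compute the MARWIL loss for a train batch.

    The policy term is an exponentially advantage-weighted log-likelihood
    and the value term is half the mean squared advantage:

        L = -E[exp(beta * A / c) * log pi(a|s)] + vf_coeff * 0.5 * E[A^2]

    where c is the square root of the moving average of A^2 stored in
    `policy._moving_average_sqd_adv_norm` (an optional `bc_logstd_coeff`
    term additionally rewards larger log-stds). With beta == 0.0 the loss
    reduces to plain behavioral cloning.
    """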
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    actions = train_batch[SampleBatch.ACTIONS]
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)

    # Advantage estimation.
    if policy.config["beta"] != 0.0:
        cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
        state_values = model.value_function()
        adv = cumulative_rewards - state_values
        adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

        explained_var = explained_variance(cumulative_rewards, state_values)
        policy.explained_variance = torch.mean(explained_var)

        # Policy loss.
        # Update averaged advantage norm.
        rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
        policy._moving_average_sqd_adv_norm.add_(
            rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm)
        )
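        # The in-place add_() above is an exponential moving average:
        #   m <- m + rate * (mean(A^2) - m)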
        # Exponentially weighted advantages.
        exp_advs = torch.exp(
            policy.config["beta"]
            * (adv / (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5)))
        ).detach()
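        # detach() treats the exponential weights as constants, so the
        # policy term does not backpropagate into the value branch.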
        # Value loss.
        policy.v_loss = 0.5 * adv_squared_mean
    else:
        # Policy loss (simple BC loss term).
        exp_advs = 1.0
        # Value loss.
        policy.v_loss = 0.0

    # The logprob loss alone tends to push action distributions toward
    # very low entropy, resulting in worse performance in unfamiliar
    # situations. A scaled logstd loss term encourages stochasticity and
    # thus alleviates the problem to some extent.
    logstd_coeff = policy.config["bc_logstd_coeff"]
    if logstd_coeff > 0.0:
        logstds = torch.mean(action_dist.log_std, dim=1)
    else:
        logstds = 0.0

    policy.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * policy.v_loss

    return policy.total_loss


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
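    """Return loss terms and (if beta != 0.0) value-function stats."""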
    stats = {
        "policy_loss": policy.p_loss,
        "total_loss": policy.total_loss,
    }
    if policy.config["beta"] != 0.0:
        stats["moving_average_sqd_adv_norm"] = policy._moving_average_sqd_adv_norm
        stats["vf_explained_var"] = policy.explained_variance
        stats["vf_loss"] = policy.v_loss

    return stats


def setup_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
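    """Initialize mixins and per-policy state before loss initialization."""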
    # Setup Value branch of our NN.
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)

    # Not needed for pure BC.
    if policy.config["beta"] != 0.0:
        # Set up a torch-var for the squared moving avg. advantage norm.
        policy._moving_average_sqd_adv_norm = torch.tensor(
            [policy.config["moving_average_sqd_adv_norm_start"]],
            dtype=torch.float32,
            requires_grad=False,
        ).to(policy.device)
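        # requires_grad=False: this running statistic is updated in-place
        # in marwil_loss(), not by the optimizer.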


MARWILTorchPolicy = build_policy_class(
    name="MARWILTorchPolicy",
    framework="torch",
    loss_fn=marwil_loss,
    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
    stats_fn=stats,
    postprocess_fn=postprocess_advantages,
    extra_grad_process_fn=apply_grad_clipping,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin],
)
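

# Minimal usage sketch (assuming the Ray 1.x `agents` API): this policy is
# normally consumed through MARWILTrainer together with an offline dataset;
# the "input" path below is a placeholder.
if __name__ == "__main__":
    from ray.rllib.agents.marwil import MARWILTrainer

    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={
            "framework": "torch",
            "beta": 1.0,  # 0.0 would fall back to plain behavioral cloning.
            "input": "/tmp/cartpole-out",  # placeholder: path to offline JSON data
            "num_workers": 0,
        },
    )
    print(trainer.train())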