Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
85 lines
3 KiB
Python
import ray
from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance

torch, _ = try_import_torch()


class ValueNetworkMixin:
    """Assigns a `_value()` method to the policy.

    `_value()` feeds a single (ob, prev_action, prev_reward, *state) input
    through the model and returns the value-branch output for it.
    """

    def __init__(self):
        def value(ob, prev_action, prev_reward, *state):
            model_out, _ = self.model({
                SampleBatch.CUR_OBS: torch.Tensor([ob]).to(self.device),
                SampleBatch.PREV_ACTIONS: torch.Tensor([prev_action]).to(
                    self.device),
                SampleBatch.PREV_REWARDS: torch.Tensor([prev_reward]).to(
                    self.device),
                "is_training": False,
            }, [torch.Tensor([s]).to(self.device) for s in state],
                torch.Tensor([1]).to(self.device))
            return self.model.value_function()[0]

        self._value = value

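# Usage sketch (illustrative): once the mixin has been applied to a built
# policy (see `setup_mixins` below), the value branch can be queried for a
# single transition, e.g.
#
#     v = policy._value(obs, prev_action, prev_reward, *rnn_state)
#
# `obs`, `prev_action`, `prev_reward`, and `rnn_state` are placeholder names.
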
def marwil_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)
    state_values = model.value_function()
    advantages = train_batch[Postprocessing.ADVANTAGES]
    actions = train_batch[SampleBatch.ACTIONS]

    # Value loss.
    policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))

    # Policy loss.
    # Advantage estimation.
    adv = advantages - state_values
    # Update averaged advantage norm.
    policy.ma_adv_norm.add_(
        1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
    # Exponentially weighted advantages.
    exp_advs = torch.exp(policy.config["beta"] *
                         (adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)
    policy.p_loss = -1.0 * torch.mean(exp_advs.detach() * logprobs)

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \
        policy.v_loss
    explained_var = explained_variance(advantages, state_values)
    policy.explained_variance = torch.mean(explained_var)

    return policy.total_loss

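# For reference, the loss assembled above corresponds to a MARWIL-style
# objective (written here informally from the code, not quoted from the paper):
#
#     L(theta) = -E[ exp(beta * (R - V_theta(s)) / c) * log pi_theta(a|s) ]
#                + vf_coeff * 0.5 * E[ (V_theta(s) - R)^2 ]
#
# where R is the return stored under Postprocessing.ADVANTAGES by
# `postprocess_advantages` and c = sqrt(policy.ma_adv_norm) is the running
# estimate of the advantage scale.
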
def stats(policy, train_batch):
    return {
        "policy_loss": policy.p_loss,
        "vf_loss": policy.v_loss,
        "total_loss": policy.total_loss,
        "vf_explained_var": policy.explained_variance,
    }

def setup_mixins(policy, obs_space, action_space, config):
    # Create the moving-average variable for the advantage norm.
    policy.ma_adv_norm = torch.tensor(
        [100.0], dtype=torch.float32, requires_grad=False).to(policy.device)
    # Set up the value branch of the model.
    ValueNetworkMixin.__init__(policy)

MARWILTorchPolicy = build_torch_policy(
    name="MARWILTorchPolicy",
    loss_fn=marwil_loss,
    get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
    stats_fn=stats,
    postprocess_fn=postprocess_advantages,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin])
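
if __name__ == "__main__":
    # Minimal usage sketch, assuming an RLlib version in which `MARWILTrainer`
    # and the `"framework": "torch"` config key are available; the offline
    # input path below is a placeholder, not a real dataset.
    from ray.rllib.agents.marwil import MARWILTrainer

    ray.init()
    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={
            "framework": "torch",  # selects MARWILTorchPolicy
            "beta": 1.0,  # beta=0.0 reduces MARWIL to plain behavior cloning
            "input": "/tmp/cartpole-out",  # placeholder offline-data directory
            "input_evaluation": [],
        })
    print(trainer.train())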