mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] MARWIL loss function test case and cleanup. (#13455)
This commit is contained in:
parent
2506a6cd0e
commit
a65ee92b69
3 changed files with 139 additions and 72 deletions
|
@ -1,3 +1,5 @@
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
|
@ -8,6 +10,8 @@ from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
|
|||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValueNetworkMixin:
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
|
@ -43,47 +47,6 @@ class ValueNetworkMixin:
|
|||
self._value = value
|
||||
|
||||
|
||||
class ValueLoss:
|
||||
def __init__(self, state_values, cumulative_rewards):
|
||||
self.loss = 0.5 * tf.reduce_mean(
|
||||
tf.math.square(state_values - cumulative_rewards))
|
||||
|
||||
|
||||
class ReweightedImitationLoss:
|
||||
def __init__(self, policy, state_values, cumulative_rewards, actions,
|
||||
action_dist, beta):
|
||||
if beta != 0.0:
|
||||
# Advantage Estimation.
|
||||
adv = cumulative_rewards - state_values
|
||||
|
||||
# Update averaged advantage norm.
|
||||
# Eager.
|
||||
if policy.config["framework"] in ["tf2", "tfe"]:
|
||||
policy._ma_adv_norm.assign_add(1e-6 * (
|
||||
tf.reduce_mean(tf.math.square(adv)) - policy._ma_adv_norm))
|
||||
# Exponentially weighted advantages.
|
||||
exp_advs = tf.math.exp(beta * tf.math.divide(
|
||||
adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
|
||||
# Static graph.
|
||||
else:
|
||||
update_adv_norm = tf1.assign_add(
|
||||
ref=policy._ma_adv_norm,
|
||||
value=1e-6 * (tf.reduce_mean(tf.math.square(adv)) -
|
||||
policy._ma_adv_norm))
|
||||
|
||||
# Exponentially weighted advantages.
|
||||
with tf1.control_dependencies([update_adv_norm]):
|
||||
exp_advs = tf.math.exp(beta * tf.math.divide(
|
||||
adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
|
||||
exp_advs = tf.stop_gradient(exp_advs)
|
||||
else:
|
||||
exp_advs = 1.0
|
||||
|
||||
# L = - A * log\pi_\theta(a|s)
|
||||
logprobs = action_dist.logp(actions)
|
||||
self.loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
|
||||
|
||||
|
||||
def postprocess_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
|
@ -135,43 +98,74 @@ def postprocess_advantages(policy,
|
|||
sample_batch[SampleBatch.REWARDS][-1],
|
||||
*next_state)
|
||||
|
||||
# Adds the policy logits, VF preds, and advantages to the batch,
|
||||
# using GAE ("generalized advantage estimation") or not.
|
||||
# Adds the "advantages" (which in the case of MARWIL are simply the
|
||||
# discounted cummulative rewards) to the SampleBatch.
|
||||
return compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
# We just want the discounted cummulative rewards, so we won't need
|
||||
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
|
||||
use_gae=False,
|
||||
use_critic=False)
|
||||
|
||||
|
||||
class MARWILLoss:
|
||||
def __init__(self, policy, state_values, action_dist, actions, advantages,
|
||||
vf_loss_coeff, beta):
|
||||
def __init__(self, policy, value_estimates, action_dist, actions,
|
||||
cumulative_rewards, vf_loss_coeff, beta):
|
||||
|
||||
self.v_loss = self._build_value_loss(state_values, advantages)
|
||||
self.p_loss = self._build_policy_loss(policy, state_values, advantages,
|
||||
actions, action_dist, beta)
|
||||
# Advantage Estimation.
|
||||
adv = cumulative_rewards - value_estimates
|
||||
adv_squared = tf.reduce_mean(tf.math.square(adv))
|
||||
|
||||
self.total_loss = self.p_loss.loss + vf_loss_coeff * self.v_loss.loss
|
||||
explained_var = explained_variance(advantages, state_values)
|
||||
self.explained_variance = tf.reduce_mean(explained_var)
|
||||
# Value function's loss term (MSE).
|
||||
self.v_loss = 0.5 * adv_squared
|
||||
|
||||
def _build_value_loss(self, state_values, cum_rwds):
|
||||
return ValueLoss(state_values, cum_rwds)
|
||||
if beta != 0.0:
|
||||
# Perform moving averaging of advantage^2.
|
||||
|
||||
def _build_policy_loss(self, policy, state_values, cum_rwds, actions,
|
||||
action_dist, beta):
|
||||
return ReweightedImitationLoss(policy, state_values, cum_rwds, actions,
|
||||
action_dist, beta)
|
||||
# Update averaged advantage norm.
|
||||
# Eager.
|
||||
if policy.config["framework"] in ["tf2", "tfe"]:
|
||||
update_term = adv_squared - policy._moving_average_sqd_adv_norm
|
||||
policy._moving_average_sqd_adv_norm.assign_add(
|
||||
1e-8 * update_term)
|
||||
|
||||
# Exponentially weighted advantages.
|
||||
c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
|
||||
exp_advs = tf.math.exp(beta * (adv / c))
|
||||
# Static graph.
|
||||
else:
|
||||
update_adv_norm = tf1.assign_add(
|
||||
ref=policy._moving_average_sqd_adv_norm,
|
||||
value=1e-6 *
|
||||
(adv_squared - policy._moving_average_sqd_adv_norm))
|
||||
|
||||
# Exponentially weighted advantages.
|
||||
with tf1.control_dependencies([update_adv_norm]):
|
||||
exp_advs = tf.math.exp(beta * tf.math.divide(
|
||||
adv, 1e-8 + tf.math.sqrt(
|
||||
policy._moving_average_sqd_adv_norm)))
|
||||
exp_advs = tf.stop_gradient(exp_advs)
|
||||
else:
|
||||
exp_advs = 1.0
|
||||
|
||||
# L = - A * log\pi_\theta(a|s)
|
||||
logprobs = action_dist.logp(actions)
|
||||
self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
|
||||
|
||||
self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
|
||||
|
||||
self.explained_variance = tf.reduce_mean(
|
||||
explained_variance(cumulative_rewards, value_estimates))
|
||||
|
||||
|
||||
def marwil_loss(policy, model, dist_class, train_batch):
|
||||
model_out, _ = model.from_batch(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
state_values = model.value_function()
|
||||
value_estimates = model.value_function()
|
||||
|
||||
policy.loss = MARWILLoss(policy, state_values, action_dist,
|
||||
policy.loss = MARWILLoss(policy, value_estimates, action_dist,
|
||||
train_batch[SampleBatch.ACTIONS],
|
||||
train_batch[Postprocessing.ADVANTAGES],
|
||||
policy.config["vf_coeff"], policy.config["beta"])
|
||||
|
@ -181,8 +175,8 @@ def marwil_loss(policy, model, dist_class, train_batch):
|
|||
|
||||
def stats(policy, train_batch):
|
||||
return {
|
||||
"policy_loss": policy.loss.p_loss.loss,
|
||||
"vf_loss": policy.loss.v_loss.loss,
|
||||
"policy_loss": policy.loss.p_loss,
|
||||
"vf_loss": policy.loss.v_loss,
|
||||
"total_loss": policy.loss.total_loss,
|
||||
"vf_explained_var": policy.loss.explained_variance,
|
||||
}
|
||||
|
@ -191,8 +185,8 @@ def stats(policy, train_batch):
|
|||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
# Set up a tf-var for the moving avg (do this here to make it work with
|
||||
# eager mode).
|
||||
policy._ma_adv_norm = get_variable(
|
||||
# eager mode); "c^2" in the paper.
|
||||
policy._moving_average_sqd_adv_norm = get_variable(
|
||||
100.0,
|
||||
framework="tf",
|
||||
tf_name="moving_average_of_advantage_norm",
|
||||
|
|
|
@ -48,16 +48,17 @@ def marwil_loss(policy, model, dist_class, train_batch):
|
|||
advantages = train_batch[Postprocessing.ADVANTAGES]
|
||||
actions = train_batch[SampleBatch.ACTIONS]
|
||||
|
||||
# Value loss.
|
||||
policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))
|
||||
|
||||
# Policy loss.
|
||||
# Advantage estimation.
|
||||
adv = advantages - state_values
|
||||
adv_squared = torch.mean(torch.pow(adv, 2.0))
|
||||
|
||||
# Value loss.
|
||||
policy.v_loss = 0.5 * adv_squared
|
||||
|
||||
# Policy loss.
|
||||
# Update averaged advantage norm.
|
||||
policy.ma_adv_norm.add_(
|
||||
1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
|
||||
# #xponentially weighted advantages.
|
||||
policy.ma_adv_norm.add_(1e-6 * (adv_squared - policy.ma_adv_norm))
|
||||
# Exponentially weighted advantages.
|
||||
exp_advs = torch.exp(policy.config["beta"] *
|
||||
(adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
|
||||
# log\pi_\theta(a|s)
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.marwil as marwil
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_compute_single_action, \
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.offline import JsonReader
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
|
||||
framework_iterator
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
class TestMARWIL(unittest.TestCase):
|
||||
|
@ -70,6 +74,74 @@ class TestMARWIL(unittest.TestCase):
|
|||
|
||||
trainer.stop()
|
||||
|
||||
def test_marwil_loss_function(self):
|
||||
"""
|
||||
To generate the historic data used in this test case, first run:
|
||||
$ ./train.py --run=PPO --env=CartPole-v0 \
|
||||
--stop='{"timesteps_total": 50000}' \
|
||||
--config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
|
||||
"""
|
||||
rllib_dir = Path(__file__).parent.parent.parent.parent
|
||||
print("rllib dir={}".format(rllib_dir))
|
||||
data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
|
||||
print("data_file={} exists={}".format(data_file,
|
||||
os.path.isfile(data_file)))
|
||||
config = marwil.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
# Learn from offline data.
|
||||
config["input"] = [data_file]
|
||||
|
||||
for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
|
||||
reader = JsonReader(inputs=[data_file])
|
||||
batch = reader.next()
|
||||
|
||||
trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
|
||||
policy = trainer.get_policy()
|
||||
model = policy.model
|
||||
|
||||
# Calculate our own expected values (to then compare against the
|
||||
# agent's loss output).
|
||||
cummulative_rewards = compute_advantages(
|
||||
batch, 0.0, config["gamma"], 1.0, False, False)["advantages"]
|
||||
if fw == "torch":
|
||||
cummulative_rewards = torch.tensor(cummulative_rewards)
|
||||
tensor_batch = policy._lazy_tensor_dict(batch)
|
||||
model_out, _ = model.from_batch(tensor_batch)
|
||||
vf_estimates = model.value_function()
|
||||
adv = cummulative_rewards - vf_estimates
|
||||
if fw == "torch":
|
||||
adv = adv.detach().cpu().numpy()
|
||||
adv_squared = np.mean(np.square(adv))
|
||||
c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
|
||||
c = np.sqrt(c_2)
|
||||
exp_advs = np.exp(config["beta"] * (adv / c))
|
||||
logp = policy.dist_class(model_out,
|
||||
model).logp(tensor_batch["actions"])
|
||||
if fw == "torch":
|
||||
logp = logp.detach().cpu().numpy()
|
||||
# Calculate all expected loss components.
|
||||
expected_vf_loss = 0.5 * adv_squared
|
||||
expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
|
||||
expected_loss = \
|
||||
expected_pol_loss + config["vf_coeff"] * expected_vf_loss
|
||||
|
||||
# Calculate the algorithm's loss (to check against our own
|
||||
# calculation above).
|
||||
postprocessed_batch = policy.postprocess_trajectory(batch)
|
||||
loss_func = marwil.marwil_tf_policy.marwil_loss if fw != "torch" \
|
||||
else marwil.marwil_torch_policy.marwil_loss
|
||||
loss_out = loss_func(policy, model, policy.dist_class,
|
||||
policy._lazy_tensor_dict(postprocessed_batch))
|
||||
|
||||
# Check all components.
|
||||
if fw == "torch":
|
||||
check(policy.v_loss, expected_vf_loss, decimals=4)
|
||||
check(policy.p_loss, expected_pol_loss, decimals=4)
|
||||
else:
|
||||
check(policy.loss.v_loss, expected_vf_loss, decimals=4)
|
||||
check(policy.loss.p_loss, expected_pol_loss, decimals=4)
|
||||
check(loss_out, expected_loss, decimals=3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
|
Loading…
Add table
Reference in a new issue