[RLlib] MARWIL loss function test case and cleanup. (#13455)

This commit is contained in:
Sven Mika 2021-01-19 09:51:05 +01:00 committed by GitHub
parent 2506a6cd0e
commit a65ee92b69
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 139 additions and 72 deletions

View file

@ -1,3 +1,5 @@
import logging
import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages, \
@ -8,6 +10,8 @@ from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
class ValueNetworkMixin:
def __init__(self, obs_space, action_space, config):
@ -43,47 +47,6 @@ class ValueNetworkMixin:
self._value = value
class ValueLoss:
def __init__(self, state_values, cumulative_rewards):
self.loss = 0.5 * tf.reduce_mean(
tf.math.square(state_values - cumulative_rewards))
class ReweightedImitationLoss:
def __init__(self, policy, state_values, cumulative_rewards, actions,
action_dist, beta):
if beta != 0.0:
# Advantage Estimation.
adv = cumulative_rewards - state_values
# Update averaged advantage norm.
# Eager.
if policy.config["framework"] in ["tf2", "tfe"]:
policy._ma_adv_norm.assign_add(1e-6 * (
tf.reduce_mean(tf.math.square(adv)) - policy._ma_adv_norm))
# Exponentially weighted advantages.
exp_advs = tf.math.exp(beta * tf.math.divide(
adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
# Static graph.
else:
update_adv_norm = tf1.assign_add(
ref=policy._ma_adv_norm,
value=1e-6 * (tf.reduce_mean(tf.math.square(adv)) -
policy._ma_adv_norm))
# Exponentially weighted advantages.
with tf1.control_dependencies([update_adv_norm]):
exp_advs = tf.math.exp(beta * tf.math.divide(
adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
exp_advs = tf.stop_gradient(exp_advs)
else:
exp_advs = 1.0
# L = - A * log\pi_\theta(a|s)
logprobs = action_dist.logp(actions)
self.loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
@ -135,43 +98,74 @@ def postprocess_advantages(policy,
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
# Adds the "advantages" (which in the case of MARWIL are simply the
# discounted cummulative rewards) to the SampleBatch.
return compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
# We just want the discounted cummulative rewards, so we won't need
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
use_gae=False,
use_critic=False)
class MARWILLoss:
def __init__(self, policy, state_values, action_dist, actions, advantages,
vf_loss_coeff, beta):
def __init__(self, policy, value_estimates, action_dist, actions,
cumulative_rewards, vf_loss_coeff, beta):
self.v_loss = self._build_value_loss(state_values, advantages)
self.p_loss = self._build_policy_loss(policy, state_values, advantages,
actions, action_dist, beta)
# Advantage Estimation.
adv = cumulative_rewards - value_estimates
adv_squared = tf.reduce_mean(tf.math.square(adv))
self.total_loss = self.p_loss.loss + vf_loss_coeff * self.v_loss.loss
explained_var = explained_variance(advantages, state_values)
self.explained_variance = tf.reduce_mean(explained_var)
# Value function's loss term (MSE).
self.v_loss = 0.5 * adv_squared
def _build_value_loss(self, state_values, cum_rwds):
return ValueLoss(state_values, cum_rwds)
if beta != 0.0:
# Perform moving averaging of advantage^2.
def _build_policy_loss(self, policy, state_values, cum_rwds, actions,
action_dist, beta):
return ReweightedImitationLoss(policy, state_values, cum_rwds, actions,
action_dist, beta)
# Update averaged advantage norm.
# Eager.
if policy.config["framework"] in ["tf2", "tfe"]:
update_term = adv_squared - policy._moving_average_sqd_adv_norm
policy._moving_average_sqd_adv_norm.assign_add(
1e-8 * update_term)
# Exponentially weighted advantages.
c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
exp_advs = tf.math.exp(beta * (adv / c))
# Static graph.
else:
update_adv_norm = tf1.assign_add(
ref=policy._moving_average_sqd_adv_norm,
value=1e-6 *
(adv_squared - policy._moving_average_sqd_adv_norm))
# Exponentially weighted advantages.
with tf1.control_dependencies([update_adv_norm]):
exp_advs = tf.math.exp(beta * tf.math.divide(
adv, 1e-8 + tf.math.sqrt(
policy._moving_average_sqd_adv_norm)))
exp_advs = tf.stop_gradient(exp_advs)
else:
exp_advs = 1.0
# L = - A * log\pi_\theta(a|s)
logprobs = action_dist.logp(actions)
self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
self.explained_variance = tf.reduce_mean(
explained_variance(cumulative_rewards, value_estimates))
def marwil_loss(policy, model, dist_class, train_batch):
model_out, _ = model.from_batch(train_batch)
action_dist = dist_class(model_out, model)
state_values = model.value_function()
value_estimates = model.value_function()
policy.loss = MARWILLoss(policy, state_values, action_dist,
policy.loss = MARWILLoss(policy, value_estimates, action_dist,
train_batch[SampleBatch.ACTIONS],
train_batch[Postprocessing.ADVANTAGES],
policy.config["vf_coeff"], policy.config["beta"])
@ -181,8 +175,8 @@ def marwil_loss(policy, model, dist_class, train_batch):
def stats(policy, train_batch):
return {
"policy_loss": policy.loss.p_loss.loss,
"vf_loss": policy.loss.v_loss.loss,
"policy_loss": policy.loss.p_loss,
"vf_loss": policy.loss.v_loss,
"total_loss": policy.loss.total_loss,
"vf_explained_var": policy.loss.explained_variance,
}
@ -191,8 +185,8 @@ def stats(policy, train_batch):
def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
# Set up a tf-var for the moving avg (do this here to make it work with
# eager mode).
policy._ma_adv_norm = get_variable(
# eager mode); "c^2" in the paper.
policy._moving_average_sqd_adv_norm = get_variable(
100.0,
framework="tf",
tf_name="moving_average_of_advantage_norm",

View file

@ -48,16 +48,17 @@ def marwil_loss(policy, model, dist_class, train_batch):
advantages = train_batch[Postprocessing.ADVANTAGES]
actions = train_batch[SampleBatch.ACTIONS]
# Value loss.
policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))
# Policy loss.
# Advantage estimation.
adv = advantages - state_values
adv_squared = torch.mean(torch.pow(adv, 2.0))
# Value loss.
policy.v_loss = 0.5 * adv_squared
# Policy loss.
# Update averaged advantage norm.
policy.ma_adv_norm.add_(
1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
# #xponentially weighted advantages.
policy.ma_adv_norm.add_(1e-6 * (adv_squared - policy.ma_adv_norm))
# Exponentially weighted advantages.
exp_advs = torch.exp(policy.config["beta"] *
(adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
# log\pi_\theta(a|s)

View file

@ -1,14 +1,18 @@
import numpy as np
import os
from pathlib import Path
import unittest
import ray
import ray.rllib.agents.marwil as marwil
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_compute_single_action, \
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.offline import JsonReader
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
class TestMARWIL(unittest.TestCase):
@ -70,6 +74,74 @@ class TestMARWIL(unittest.TestCase):
trainer.stop()
def test_marwil_loss_function(self):
"""
To generate the historic data used in this test case, first run:
$ ./train.py --run=PPO --env=CartPole-v0 \
--stop='{"timesteps_total": 50000}' \
--config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
"""
rllib_dir = Path(__file__).parent.parent.parent.parent
print("rllib dir={}".format(rllib_dir))
data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
print("data_file={} exists={}".format(data_file,
os.path.isfile(data_file)))
config = marwil.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
# Learn from offline data.
config["input"] = [data_file]
for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
reader = JsonReader(inputs=[data_file])
batch = reader.next()
trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
model = policy.model
# Calculate our own expected values (to then compare against the
# agent's loss output).
cummulative_rewards = compute_advantages(
batch, 0.0, config["gamma"], 1.0, False, False)["advantages"]
if fw == "torch":
cummulative_rewards = torch.tensor(cummulative_rewards)
tensor_batch = policy._lazy_tensor_dict(batch)
model_out, _ = model.from_batch(tensor_batch)
vf_estimates = model.value_function()
adv = cummulative_rewards - vf_estimates
if fw == "torch":
adv = adv.detach().cpu().numpy()
adv_squared = np.mean(np.square(adv))
c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
c = np.sqrt(c_2)
exp_advs = np.exp(config["beta"] * (adv / c))
logp = policy.dist_class(model_out,
model).logp(tensor_batch["actions"])
if fw == "torch":
logp = logp.detach().cpu().numpy()
# Calculate all expected loss components.
expected_vf_loss = 0.5 * adv_squared
expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
expected_loss = \
expected_pol_loss + config["vf_coeff"] * expected_vf_loss
# Calculate the algorithm's loss (to check against our own
# calculation above).
postprocessed_batch = policy.postprocess_trajectory(batch)
loss_func = marwil.marwil_tf_policy.marwil_loss if fw != "torch" \
else marwil.marwil_torch_policy.marwil_loss
loss_out = loss_func(policy, model, policy.dist_class,
policy._lazy_tensor_dict(postprocessed_batch))
# Check all components.
if fw == "torch":
check(policy.v_loss, expected_vf_loss, decimals=4)
check(policy.p_loss, expected_pol_loss, decimals=4)
else:
check(policy.loss.v_loss, expected_vf_loss, decimals=4)
check(policy.loss.p_loss, expected_pol_loss, decimals=4)
check(loss_out, expected_loss, decimals=3)
if __name__ == "__main__":
import pytest