[RLlib] MARWIL loss function test case and cleanup. (#13455)

Sven Mika 2021-01-19 09:51:05 +01:00 committed by GitHub
parent 2506a6cd0e
commit a65ee92b69
3 changed files with 139 additions and 72 deletions
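For orientation, the consolidated MARWILLoss introduced below (TF and Torch variants alike) boils down to the following objective, where R_t is the discounted cumulative reward produced by postprocess_advantages, V_theta(s_t) is the model's value estimate, and c^2 is the moving average of squared advantages kept on the policy. This is a summary of the diff, not an excerpt from it:

    L_V = \tfrac{1}{2}\,\mathbb{E}\big[(R_t - V_\theta(s_t))^2\big]
    L_\pi = -\,\mathbb{E}\big[\exp\!\big(\beta\,(R_t - V_\theta(s_t))/c\big)\,\log \pi_\theta(a_t \mid s_t)\big]
    L_\text{total} = L_\pi + \text{vf\_coeff} \cdot L_V

The exponential weights collapse to 1.0 when beta == 0.0, and c^2 is the moving-average variable initialized to 100.0 in setup_mixins.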

rllib/agents/marwil/marwil_tf_policy.py

@@ -1,3 +1,5 @@
+import logging
+
 import ray
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -8,6 +10,8 @@ from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
 tf1, tf, tfv = try_import_tf()
 
+logger = logging.getLogger(__name__)
+
 
 class ValueNetworkMixin:
     def __init__(self, obs_space, action_space, config):
@@ -43,47 +47,6 @@ class ValueNetworkMixin:
         self._value = value
 
 
-class ValueLoss:
-    def __init__(self, state_values, cumulative_rewards):
-        self.loss = 0.5 * tf.reduce_mean(
-            tf.math.square(state_values - cumulative_rewards))
-
-
-class ReweightedImitationLoss:
-    def __init__(self, policy, state_values, cumulative_rewards, actions,
-                 action_dist, beta):
-        if beta != 0.0:
-            # Advantage Estimation.
-            adv = cumulative_rewards - state_values
-            # Update averaged advantage norm.
-            # Eager.
-            if policy.config["framework"] in ["tf2", "tfe"]:
-                policy._ma_adv_norm.assign_add(1e-6 * (
-                    tf.reduce_mean(tf.math.square(adv)) - policy._ma_adv_norm))
-                # Exponentially weighted advantages.
-                exp_advs = tf.math.exp(beta * tf.math.divide(
-                    adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
-            # Static graph.
-            else:
-                update_adv_norm = tf1.assign_add(
-                    ref=policy._ma_adv_norm,
-                    value=1e-6 * (tf.reduce_mean(tf.math.square(adv)) -
-                                  policy._ma_adv_norm))
-                # Exponentially weighted advantages.
-                with tf1.control_dependencies([update_adv_norm]):
-                    exp_advs = tf.math.exp(beta * tf.math.divide(
-                        adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
-            exp_advs = tf.stop_gradient(exp_advs)
-        else:
-            exp_advs = 1.0
-
-        # L = - A * log\pi_\theta(a|s)
-        logprobs = action_dist.logp(actions)
-        self.loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
-
-
 def postprocess_advantages(policy,
                            sample_batch,
                            other_agent_batches=None,
@@ -135,43 +98,74 @@ def postprocess_advantages(policy,
             sample_batch[SampleBatch.REWARDS][-1],
             *next_state)
 
-    # Adds the policy logits, VF preds, and advantages to the batch,
-    # using GAE ("generalized advantage estimation") or not.
+    # Adds the "advantages" (which in the case of MARWIL are simply the
+    # discounted cummulative rewards) to the SampleBatch.
     return compute_advantages(
         sample_batch,
         last_r,
         policy.config["gamma"],
+        # We just want the discounted cummulative rewards, so we won't need
+        # GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
        use_gae=False,
        use_critic=False)
 
 
 class MARWILLoss:
-    def __init__(self, policy, state_values, action_dist, actions, advantages,
-                 vf_loss_coeff, beta):
-        self.v_loss = self._build_value_loss(state_values, advantages)
-        self.p_loss = self._build_policy_loss(policy, state_values, advantages,
-                                              actions, action_dist, beta)
-        self.total_loss = self.p_loss.loss + vf_loss_coeff * self.v_loss.loss
-        explained_var = explained_variance(advantages, state_values)
-        self.explained_variance = tf.reduce_mean(explained_var)
-
-    def _build_value_loss(self, state_values, cum_rwds):
-        return ValueLoss(state_values, cum_rwds)
-
-    def _build_policy_loss(self, policy, state_values, cum_rwds, actions,
-                           action_dist, beta):
-        return ReweightedImitationLoss(policy, state_values, cum_rwds, actions,
-                                       action_dist, beta)
+    def __init__(self, policy, value_estimates, action_dist, actions,
+                 cumulative_rewards, vf_loss_coeff, beta):
+        # Advantage Estimation.
+        adv = cumulative_rewards - value_estimates
+        adv_squared = tf.reduce_mean(tf.math.square(adv))
+
+        # Value function's loss term (MSE).
+        self.v_loss = 0.5 * adv_squared
+
+        if beta != 0.0:
+            # Perform moving averaging of advantage^2.
+            # Update averaged advantage norm.
+            # Eager.
+            if policy.config["framework"] in ["tf2", "tfe"]:
+                update_term = adv_squared - policy._moving_average_sqd_adv_norm
+                policy._moving_average_sqd_adv_norm.assign_add(
+                    1e-8 * update_term)
+
+                # Exponentially weighted advantages.
+                c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
+                exp_advs = tf.math.exp(beta * (adv / c))
+            # Static graph.
+            else:
+                update_adv_norm = tf1.assign_add(
+                    ref=policy._moving_average_sqd_adv_norm,
+                    value=1e-6 *
+                    (adv_squared - policy._moving_average_sqd_adv_norm))
+
+                # Exponentially weighted advantages.
+                with tf1.control_dependencies([update_adv_norm]):
+                    exp_advs = tf.math.exp(beta * tf.math.divide(
+                        adv, 1e-8 + tf.math.sqrt(
+                            policy._moving_average_sqd_adv_norm)))
+            exp_advs = tf.stop_gradient(exp_advs)
+        else:
+            exp_advs = 1.0
+
+        # L = - A * log\pi_\theta(a|s)
+        logprobs = action_dist.logp(actions)
+        self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
+
+        self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
+
+        self.explained_variance = tf.reduce_mean(
+            explained_variance(cumulative_rewards, value_estimates))
 
 
 def marwil_loss(policy, model, dist_class, train_batch):
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
-    state_values = model.value_function()
-    policy.loss = MARWILLoss(policy, state_values, action_dist,
+    value_estimates = model.value_function()
+
+    policy.loss = MARWILLoss(policy, value_estimates, action_dist,
                              train_batch[SampleBatch.ACTIONS],
                              train_batch[Postprocessing.ADVANTAGES],
                              policy.config["vf_coeff"], policy.config["beta"])
@@ -181,8 +175,8 @@ def marwil_loss(policy, model, dist_class, train_batch):
 
 def stats(policy, train_batch):
     return {
-        "policy_loss": policy.loss.p_loss.loss,
-        "vf_loss": policy.loss.v_loss.loss,
+        "policy_loss": policy.loss.p_loss,
+        "vf_loss": policy.loss.v_loss,
         "total_loss": policy.loss.total_loss,
         "vf_explained_var": policy.loss.explained_variance,
     }
@@ -191,8 +185,8 @@ def stats(policy, train_batch):
 def setup_mixins(policy, obs_space, action_space, config):
     ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
     # Set up a tf-var for the moving avg (do this here to make it work with
-    # eager mode).
-    policy._ma_adv_norm = get_variable(
+    # eager mode); "c^2" in the paper.
+    policy._moving_average_sqd_adv_norm = get_variable(
        100.0,
        framework="tf",
        tf_name="moving_average_of_advantage_norm",
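As the updated comments in postprocess_advantages state, the "advantages" MARWIL trains on are plain discounted cumulative rewards (use_gae=False, use_critic=False). A minimal, self-contained sketch of that quantity; discounted_returns is an illustrative helper, not the RLlib implementation:

import numpy as np


def discounted_returns(rewards, gamma, last_r=0.0):
    # R_t = r_t + gamma * R_{t+1}, rolled up backwards from the episode end.
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# A completed 3-step episode with reward 1.0 per step and gamma=0.99 yields
# [2.9701, 1.99, 1.0] -- the targets MARWIL's value branch regresses toward.
print(discounted_returns([1.0, 1.0, 1.0], gamma=0.99))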

rllib/agents/marwil/marwil_torch_policy.py

@@ -48,16 +48,17 @@ def marwil_loss(policy, model, dist_class, train_batch):
     advantages = train_batch[Postprocessing.ADVANTAGES]
     actions = train_batch[SampleBatch.ACTIONS]
 
-    # Value loss.
-    policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))
-
-    # Policy loss.
     # Advantage estimation.
     adv = advantages - state_values
+    adv_squared = torch.mean(torch.pow(adv, 2.0))
+
+    # Value loss.
+    policy.v_loss = 0.5 * adv_squared
+
+    # Policy loss.
     # Update averaged advantage norm.
-    policy.ma_adv_norm.add_(
-        1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
-    # #xponentially weighted advantages.
+    policy.ma_adv_norm.add_(1e-6 * (adv_squared - policy.ma_adv_norm))
+    # Exponentially weighted advantages.
     exp_advs = torch.exp(policy.config["beta"] *
                          (adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
     # log\pi_\theta(a|s)
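Up to the moving-average step size (1e-8 in the eager-TF branch and in the new test, 1e-6 in the static-graph TF and Torch code), both backends reduce to the same reference computation, which is exactly what the test case in the next file checks against. A condensed NumPy sketch, assuming returns, vf_estimates, and logp are already plain arrays; the function name and keyword defaults are illustrative, not RLlib API:

import numpy as np


def expected_marwil_loss(returns, vf_estimates, logp, beta, vf_coeff,
                         sqd_adv_norm=100.0, ma_rate=1e-8):
    adv = returns - vf_estimates
    adv_squared = np.mean(np.square(adv))
    # One moving-average step for c^2 (initialized to 100.0 in setup_mixins),
    # then c = sqrt(c^2) scales the advantages inside the exponential.
    c2 = sqd_adv_norm + ma_rate * (adv_squared - sqd_adv_norm)
    exp_advs = np.exp(beta * (adv / np.sqrt(c2)))
    # The beta == 0.0 shortcut (exp_advs = 1.0) is omitted for brevity.
    vf_loss = 0.5 * adv_squared
    pol_loss = -1.0 * np.mean(exp_advs * logp)
    return pol_loss + vf_coeff * vf_loss, pol_loss, vf_loss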

rllib/agents/marwil/tests/test_marwil.py

@@ -1,14 +1,18 @@
+import numpy as np
 import os
 from pathlib import Path
 import unittest
 
 import ray
 import ray.rllib.agents.marwil as marwil
-from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check_compute_single_action, \
+from ray.rllib.evaluation.postprocessing import compute_advantages
+from ray.rllib.offline import JsonReader
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.test_utils import check, check_compute_single_action, \
     framework_iterator
 
 tf1, tf, tfv = try_import_tf()
+torch, _ = try_import_torch()
 
 
 class TestMARWIL(unittest.TestCase):
@@ -70,6 +74,74 @@ class TestMARWIL(unittest.TestCase):
             trainer.stop()
 
+    def test_marwil_loss_function(self):
+        """
+        To generate the historic data used in this test case, first run:
+        $ ./train.py --run=PPO --env=CartPole-v0 \
+          --stop='{"timesteps_total": 50000}' \
+          --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
+        """
+        rllib_dir = Path(__file__).parent.parent.parent.parent
+        print("rllib dir={}".format(rllib_dir))
+        data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
+        print("data_file={} exists={}".format(data_file,
+                                              os.path.isfile(data_file)))
+
+        config = marwil.DEFAULT_CONFIG.copy()
+        config["num_workers"] = 0  # Run locally.
+        # Learn from offline data.
+        config["input"] = [data_file]
+
+        for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
+            reader = JsonReader(inputs=[data_file])
+            batch = reader.next()
+
+            trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
+            policy = trainer.get_policy()
+            model = policy.model
+
+            # Calculate our own expected values (to then compare against the
+            # agent's loss output).
+            cummulative_rewards = compute_advantages(
+                batch, 0.0, config["gamma"], 1.0, False, False)["advantages"]
+            if fw == "torch":
+                cummulative_rewards = torch.tensor(cummulative_rewards)
+            tensor_batch = policy._lazy_tensor_dict(batch)
+            model_out, _ = model.from_batch(tensor_batch)
+            vf_estimates = model.value_function()
+            adv = cummulative_rewards - vf_estimates
+            if fw == "torch":
+                adv = adv.detach().cpu().numpy()
+            adv_squared = np.mean(np.square(adv))
+            c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
+            c = np.sqrt(c_2)
+            exp_advs = np.exp(config["beta"] * (adv / c))
+            logp = policy.dist_class(model_out,
+                                     model).logp(tensor_batch["actions"])
+            if fw == "torch":
+                logp = logp.detach().cpu().numpy()
+
+            # Calculate all expected loss components.
+            expected_vf_loss = 0.5 * adv_squared
+            expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
+            expected_loss = \
+                expected_pol_loss + config["vf_coeff"] * expected_vf_loss
+
+            # Calculate the algorithm's loss (to check against our own
+            # calculation above).
+            postprocessed_batch = policy.postprocess_trajectory(batch)
+            loss_func = marwil.marwil_tf_policy.marwil_loss if fw != "torch" \
+                else marwil.marwil_torch_policy.marwil_loss
+            loss_out = loss_func(policy, model, policy.dist_class,
+                                 policy._lazy_tensor_dict(postprocessed_batch))
+
+            # Check all components.
+            if fw == "torch":
+                check(policy.v_loss, expected_vf_loss, decimals=4)
+                check(policy.p_loss, expected_pol_loss, decimals=4)
+            else:
+                check(policy.loss.v_loss, expected_vf_loss, decimals=4)
+                check(policy.loss.p_loss, expected_pol_loss, decimals=4)
+            check(loss_out, expected_loss, decimals=3)
+
 
 if __name__ == "__main__":
     import pytest