From a65ee92b69c9dfa05defbee17abd7af09103f88e Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Tue, 19 Jan 2021 09:51:05 +0100
Subject: [PATCH] [RLlib] MARWIL loss function test case and cleanup. (#13455)

---
 rllib/agents/marwil/marwil_tf_policy.py    | 120 ++++++++++-----------
 rllib/agents/marwil/marwil_torch_policy.py |  15 +--
 rllib/agents/marwil/tests/test_marwil.py   |  76 ++++++++++++-
 3 files changed, 139 insertions(+), 72 deletions(-)

diff --git a/rllib/agents/marwil/marwil_tf_policy.py b/rllib/agents/marwil/marwil_tf_policy.py
index 0be3149fa..44352be4f 100644
--- a/rllib/agents/marwil/marwil_tf_policy.py
+++ b/rllib/agents/marwil/marwil_tf_policy.py
@@ -1,3 +1,5 @@
+import logging
+
 import ray
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -8,6 +10,8 @@ from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
 
 tf1, tf, tfv = try_import_tf()
 
+logger = logging.getLogger(__name__)
+
 
 class ValueNetworkMixin:
     def __init__(self, obs_space, action_space, config):
@@ -43,47 +47,6 @@ class ValueNetworkMixin:
         self._value = value
 
 
-class ValueLoss:
-    def __init__(self, state_values, cumulative_rewards):
-        self.loss = 0.5 * tf.reduce_mean(
-            tf.math.square(state_values - cumulative_rewards))
-
-
-class ReweightedImitationLoss:
-    def __init__(self, policy, state_values, cumulative_rewards, actions,
-                 action_dist, beta):
-        if beta != 0.0:
-            # Advantage Estimation.
-            adv = cumulative_rewards - state_values
-
-            # Update averaged advantage norm.
-            # Eager.
-            if policy.config["framework"] in ["tf2", "tfe"]:
-                policy._ma_adv_norm.assign_add(1e-6 * (
-                    tf.reduce_mean(tf.math.square(adv)) - policy._ma_adv_norm))
-                # Exponentially weighted advantages.
-                exp_advs = tf.math.exp(beta * tf.math.divide(
-                    adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
-            # Static graph.
-            else:
-                update_adv_norm = tf1.assign_add(
-                    ref=policy._ma_adv_norm,
-                    value=1e-6 * (tf.reduce_mean(tf.math.square(adv)) -
-                                  policy._ma_adv_norm))
-
-                # Exponentially weighted advantages.
-                with tf1.control_dependencies([update_adv_norm]):
-                    exp_advs = tf.math.exp(beta * tf.math.divide(
-                        adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
-            exp_advs = tf.stop_gradient(exp_advs)
-        else:
-            exp_advs = 1.0
-
-        # L = - A * log\pi_\theta(a|s)
-        logprobs = action_dist.logp(actions)
-        self.loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
-
-
 def postprocess_advantages(policy,
                            sample_batch,
                            other_agent_batches=None,
@@ -135,43 +98,74 @@ def postprocess_advantages(policy,
                                sample_batch[SampleBatch.REWARDS][-1],
                                *next_state)
 
-    # Adds the policy logits, VF preds, and advantages to the batch,
-    # using GAE ("generalized advantage estimation") or not.
+    # Adds the "advantages" (which in the case of MARWIL are simply the
+    # discounted cumulative rewards) to the SampleBatch.
     return compute_advantages(
         sample_batch,
         last_r,
         policy.config["gamma"],
+        # We just want the discounted cumulative rewards, so we won't need
+        # GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
        use_gae=False,
        use_critic=False)
 
 
 class MARWILLoss:
-    def __init__(self, policy, state_values, action_dist, actions, advantages,
-                 vf_loss_coeff, beta):
+    def __init__(self, policy, value_estimates, action_dist, actions,
+                 cumulative_rewards, vf_loss_coeff, beta):
 
-        self.v_loss = self._build_value_loss(state_values, advantages)
-        self.p_loss = self._build_policy_loss(policy, state_values, advantages,
-                                              actions, action_dist, beta)
+        # Advantage Estimation.
+        adv = cumulative_rewards - value_estimates
+        adv_squared = tf.reduce_mean(tf.math.square(adv))
 
-        self.total_loss = self.p_loss.loss + vf_loss_coeff * self.v_loss.loss
-        explained_var = explained_variance(advantages, state_values)
-        self.explained_variance = tf.reduce_mean(explained_var)
+        # Value function's loss term (MSE).
+        self.v_loss = 0.5 * adv_squared
 
-    def _build_value_loss(self, state_values, cum_rwds):
-        return ValueLoss(state_values, cum_rwds)
+        if beta != 0.0:
+            # Perform moving averaging of advantage^2.
 
-    def _build_policy_loss(self, policy, state_values, cum_rwds, actions,
-                           action_dist, beta):
-        return ReweightedImitationLoss(policy, state_values, cum_rwds, actions,
-                                       action_dist, beta)
+            # Update averaged advantage norm.
+            # Eager.
+            if policy.config["framework"] in ["tf2", "tfe"]:
+                update_term = adv_squared - policy._moving_average_sqd_adv_norm
+                policy._moving_average_sqd_adv_norm.assign_add(
+                    1e-8 * update_term)
+
+                # Exponentially weighted advantages.
+                c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
+                exp_advs = tf.math.exp(beta * (adv / c))
+            # Static graph.
+            else:
+                update_adv_norm = tf1.assign_add(
+                    ref=policy._moving_average_sqd_adv_norm,
+                    value=1e-6 *
+                    (adv_squared - policy._moving_average_sqd_adv_norm))
+
+                # Exponentially weighted advantages.
+                with tf1.control_dependencies([update_adv_norm]):
+                    exp_advs = tf.math.exp(beta * tf.math.divide(
+                        adv, 1e-8 + tf.math.sqrt(
+                            policy._moving_average_sqd_adv_norm)))
+            exp_advs = tf.stop_gradient(exp_advs)
+        else:
+            exp_advs = 1.0
+
+        # L = - A * log\pi_\theta(a|s)
+        logprobs = action_dist.logp(actions)
+        self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
+
+        self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
+
+        self.explained_variance = tf.reduce_mean(
+            explained_variance(cumulative_rewards, value_estimates))
 
 
 def marwil_loss(policy, model, dist_class, train_batch):
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
-    state_values = model.value_function()
+    value_estimates = model.value_function()
 
-    policy.loss = MARWILLoss(policy, state_values, action_dist,
+    policy.loss = MARWILLoss(policy, value_estimates, action_dist,
                              train_batch[SampleBatch.ACTIONS],
                              train_batch[Postprocessing.ADVANTAGES],
                              policy.config["vf_coeff"], policy.config["beta"])
@@ -181,8 +175,8 @@ def marwil_loss(policy, model, dist_class, train_batch):
 
 def stats(policy, train_batch):
     return {
-        "policy_loss": policy.loss.p_loss.loss,
-        "vf_loss": policy.loss.v_loss.loss,
+        "policy_loss": policy.loss.p_loss,
+        "vf_loss": policy.loss.v_loss,
         "total_loss": policy.loss.total_loss,
         "vf_explained_var": policy.loss.explained_variance,
     }
@@ -191,8 +185,8 @@ def stats(policy, train_batch):
 def setup_mixins(policy, obs_space, action_space, config):
     ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
     # Set up a tf-var for the moving avg (do this here to make it work with
-    # eager mode).
-    policy._ma_adv_norm = get_variable(
+    # eager mode); "c^2" in the paper.
+    policy._moving_average_sqd_adv_norm = get_variable(
         100.0,
         framework="tf",
         tf_name="moving_average_of_advantage_norm",
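For orientation, the refactored MARWILLoss above computes: a value loss of 0.5 * mean(adv^2), a moving average of the squared advantages ("c^2" in the MARWIL paper), exponentially weighted advantages exp(beta * adv / c), and a policy loss of -mean(exp_advs * logp). The short NumPy sketch below restates that math outside of TensorFlow; the function name, argument names, and the single step_size constant are illustrative assumptions only (the patch itself uses 1e-8 in eager mode and 1e-6 in graph mode), not RLlib API.

import numpy as np

def marwil_loss_sketch(value_estimates, cumulative_rewards, logp, beta,
                       vf_coeff, ma_sqd_adv_norm=100.0, step_size=1e-8):
    # Advantage estimation: discounted returns minus the value estimates.
    adv = cumulative_rewards - value_estimates
    adv_squared = np.mean(np.square(adv))
    # Value function loss term (MSE against the discounted returns).
    v_loss = 0.5 * adv_squared
    # Moving average of the squared advantage norm ("c^2" in the paper).
    ma_sqd_adv_norm += step_size * (adv_squared - ma_sqd_adv_norm)
    # Exponentially weighted advantages scale the imitation (log-prob) term.
    if beta != 0.0:
        exp_advs = np.exp(beta * adv / np.sqrt(ma_sqd_adv_norm))
    else:
        exp_advs = 1.0
    p_loss = -1.0 * np.mean(exp_advs * logp)
    return p_loss + vf_coeff * v_loss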
diff --git a/rllib/agents/marwil/marwil_torch_policy.py b/rllib/agents/marwil/marwil_torch_policy.py
index 29ea822d2..ef3558378 100644
--- a/rllib/agents/marwil/marwil_torch_policy.py
+++ b/rllib/agents/marwil/marwil_torch_policy.py
@@ -48,16 +48,17 @@ def marwil_loss(policy, model, dist_class, train_batch):
     advantages = train_batch[Postprocessing.ADVANTAGES]
     actions = train_batch[SampleBatch.ACTIONS]
 
-    # Value loss.
-    policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))
-
-    # Policy loss.
     # Advantage estimation.
     adv = advantages - state_values
+    adv_squared = torch.mean(torch.pow(adv, 2.0))
+
+    # Value loss.
+    policy.v_loss = 0.5 * adv_squared
+
+    # Policy loss.
     # Update averaged advantage norm.
-    policy.ma_adv_norm.add_(
-        1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
-    # #xponentially weighted advantages.
+    policy.ma_adv_norm.add_(1e-6 * (adv_squared - policy.ma_adv_norm))
+    # Exponentially weighted advantages.
     exp_advs = torch.exp(policy.config["beta"] *
                          (adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
     # log\pi_\theta(a|s)
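The Torch loss above follows the same math; the main structural difference is the in-place moving-average update on policy.ma_adv_norm. Below is a minimal standalone sketch of just that reweighting step, assuming adv and ma_adv_norm are torch tensors (the names here are illustrative, not the policy's actual attributes).

import torch

def exp_weighted_advantages(adv, ma_adv_norm, beta, update_rate=1e-6):
    # In-place moving average of the mean squared advantage (as in the patch).
    ma_adv_norm.add_(
        update_rate * (torch.mean(torch.pow(adv, 2.0)) - ma_adv_norm))
    # Exponentially weighted advantages used to scale the imitation loss.
    return torch.exp(beta * (adv / (1e-8 + torch.pow(ma_adv_norm, 0.5))))

For example, exp_weighted_advantages(torch.randn(32), torch.tensor(100.0), beta=1.0) returns one weight per timestep.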
diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py
index acd0fe725..afb3ec9ee 100644
--- a/rllib/agents/marwil/tests/test_marwil.py
+++ b/rllib/agents/marwil/tests/test_marwil.py
@@ -1,14 +1,18 @@
+import numpy as np
 import os
 from pathlib import Path
 import unittest
 
 import ray
 import ray.rllib.agents.marwil as marwil
-from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check_compute_single_action, \
+from ray.rllib.evaluation.postprocessing import compute_advantages
+from ray.rllib.offline import JsonReader
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.test_utils import check, check_compute_single_action, \
     framework_iterator
 
 tf1, tf, tfv = try_import_tf()
+torch, _ = try_import_torch()
 
 
 class TestMARWIL(unittest.TestCase):
@@ -70,6 +74,74 @@ class TestMARWIL(unittest.TestCase):
 
         trainer.stop()
 
+    def test_marwil_loss_function(self):
+        """
+        To generate the historic data used in this test case, first run:
+        $ ./train.py --run=PPO --env=CartPole-v0 \
+          --stop='{"timesteps_total": 50000}' \
+          --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
+        """
+        rllib_dir = Path(__file__).parent.parent.parent.parent
+        print("rllib dir={}".format(rllib_dir))
+        data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
+        print("data_file={} exists={}".format(data_file,
+                                              os.path.isfile(data_file)))
+        config = marwil.DEFAULT_CONFIG.copy()
+        config["num_workers"] = 0  # Run locally.
+        # Learn from offline data.
+        config["input"] = [data_file]
+
+        for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
+            reader = JsonReader(inputs=[data_file])
+            batch = reader.next()
+
+            trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
+            policy = trainer.get_policy()
+            model = policy.model
+
+            # Calculate our own expected values (to then compare against the
+            # agent's loss output).
+            cummulative_rewards = compute_advantages(
+                batch, 0.0, config["gamma"], 1.0, False, False)["advantages"]
+            if fw == "torch":
+                cummulative_rewards = torch.tensor(cummulative_rewards)
+            tensor_batch = policy._lazy_tensor_dict(batch)
+            model_out, _ = model.from_batch(tensor_batch)
+            vf_estimates = model.value_function()
+            adv = cummulative_rewards - vf_estimates
+            if fw == "torch":
+                adv = adv.detach().cpu().numpy()
+            adv_squared = np.mean(np.square(adv))
+            c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
+            c = np.sqrt(c_2)
+            exp_advs = np.exp(config["beta"] * (adv / c))
+            logp = policy.dist_class(model_out,
+                                     model).logp(tensor_batch["actions"])
+            if fw == "torch":
+                logp = logp.detach().cpu().numpy()
+            # Calculate all expected loss components.
+            expected_vf_loss = 0.5 * adv_squared
+            expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
+            expected_loss = \
+                expected_pol_loss + config["vf_coeff"] * expected_vf_loss
+
+            # Calculate the algorithm's loss (to check against our own
+            # calculation above).
+            postprocessed_batch = policy.postprocess_trajectory(batch)
+            loss_func = marwil.marwil_tf_policy.marwil_loss if fw != "torch" \
+                else marwil.marwil_torch_policy.marwil_loss
+            loss_out = loss_func(policy, model, policy.dist_class,
+                                 policy._lazy_tensor_dict(postprocessed_batch))
+
+            # Check all components.
+            if fw == "torch":
+                check(policy.v_loss, expected_vf_loss, decimals=4)
+                check(policy.p_loss, expected_pol_loss, decimals=4)
+            else:
+                check(policy.loss.v_loss, expected_vf_loss, decimals=4)
+                check(policy.loss.p_loss, expected_pol_loss, decimals=4)
+            check(loss_out, expected_loss, decimals=3)
+
 
 if __name__ == "__main__":
     import pytest
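As a reference for the test above: with use_gae=False and use_critic=False, compute_advantages() reduces to plain discounted cumulative rewards, which is what the test feeds in as cummulative_rewards. A small sketch of that quantity follows; the helper name is made up for illustration and is not an RLlib function.

import numpy as np

def discounted_returns(rewards, gamma, last_r=0.0):
    # R_t = r_t + gamma * R_{t+1}, bootstrapped with last_r after the last step.
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

The test passes 0.0 as the bootstrap value, which matches the complete-episode offline data described in its docstring.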