Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[RLlib] MARWIL loss function test case and cleanup. (#13455)
This commit is contained in:
parent 2506a6cd0e
commit a65ee92b69
3 changed files with 139 additions and 72 deletions

Changed file 1 of 3 (the TensorFlow MARWIL policy):
@@ -1,3 +1,5 @@
+import logging
+
 import ray
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -8,6 +10,8 @@ from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable
 
 tf1, tf, tfv = try_import_tf()
 
+logger = logging.getLogger(__name__)
+
 
 class ValueNetworkMixin:
     def __init__(self, obs_space, action_space, config):
@@ -43,47 +47,6 @@ class ValueNetworkMixin:
         self._value = value
 
 
-class ValueLoss:
-    def __init__(self, state_values, cumulative_rewards):
-        self.loss = 0.5 * tf.reduce_mean(
-            tf.math.square(state_values - cumulative_rewards))
-
-
-class ReweightedImitationLoss:
-    def __init__(self, policy, state_values, cumulative_rewards, actions,
-                 action_dist, beta):
-        if beta != 0.0:
-            # Advantage Estimation.
-            adv = cumulative_rewards - state_values
-
-            # Update averaged advantage norm.
-            # Eager.
-            if policy.config["framework"] in ["tf2", "tfe"]:
-                policy._ma_adv_norm.assign_add(1e-6 * (
-                    tf.reduce_mean(tf.math.square(adv)) - policy._ma_adv_norm))
-                # Exponentially weighted advantages.
-                exp_advs = tf.math.exp(beta * tf.math.divide(
-                    adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
-            # Static graph.
-            else:
-                update_adv_norm = tf1.assign_add(
-                    ref=policy._ma_adv_norm,
-                    value=1e-6 * (tf.reduce_mean(tf.math.square(adv)) -
-                                  policy._ma_adv_norm))
-
-                # Exponentially weighted advantages.
-                with tf1.control_dependencies([update_adv_norm]):
-                    exp_advs = tf.math.exp(beta * tf.math.divide(
-                        adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
-            exp_advs = tf.stop_gradient(exp_advs)
-        else:
-            exp_advs = 1.0
-
-        # L = - A * log\pi_\theta(a|s)
-        logprobs = action_dist.logp(actions)
-        self.loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
-
-
 def postprocess_advantages(policy,
                            sample_batch,
                            other_agent_batches=None,
@@ -135,43 +98,74 @@ def postprocess_advantages(policy,
                                sample_batch[SampleBatch.REWARDS][-1],
                                *next_state)
 
-    # Adds the policy logits, VF preds, and advantages to the batch,
-    # using GAE ("generalized advantage estimation") or not.
+    # Adds the "advantages" (which in the case of MARWIL are simply the
+    # discounted cummulative rewards) to the SampleBatch.
     return compute_advantages(
         sample_batch,
         last_r,
         policy.config["gamma"],
+        # We just want the discounted cummulative rewards, so we won't need
+        # GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
         use_gae=False,
         use_critic=False)
 
 
 class MARWILLoss:
-    def __init__(self, policy, state_values, action_dist, actions, advantages,
-                 vf_loss_coeff, beta):
-
-        self.v_loss = self._build_value_loss(state_values, advantages)
-        self.p_loss = self._build_policy_loss(policy, state_values, advantages,
-                                              actions, action_dist, beta)
-
-        self.total_loss = self.p_loss.loss + vf_loss_coeff * self.v_loss.loss
-        explained_var = explained_variance(advantages, state_values)
-        self.explained_variance = tf.reduce_mean(explained_var)
-
-    def _build_value_loss(self, state_values, cum_rwds):
-        return ValueLoss(state_values, cum_rwds)
-
-    def _build_policy_loss(self, policy, state_values, cum_rwds, actions,
-                           action_dist, beta):
-        return ReweightedImitationLoss(policy, state_values, cum_rwds, actions,
-                                       action_dist, beta)
+    def __init__(self, policy, value_estimates, action_dist, actions,
+                 cumulative_rewards, vf_loss_coeff, beta):
+
+        # Advantage Estimation.
+        adv = cumulative_rewards - value_estimates
+        adv_squared = tf.reduce_mean(tf.math.square(adv))
+
+        # Value function's loss term (MSE).
+        self.v_loss = 0.5 * adv_squared
+
+        if beta != 0.0:
+            # Perform moving averaging of advantage^2.
+
+            # Update averaged advantage norm.
+            # Eager.
+            if policy.config["framework"] in ["tf2", "tfe"]:
+                update_term = adv_squared - policy._moving_average_sqd_adv_norm
+                policy._moving_average_sqd_adv_norm.assign_add(
+                    1e-8 * update_term)
+
+                # Exponentially weighted advantages.
+                c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
+                exp_advs = tf.math.exp(beta * (adv / c))
+            # Static graph.
+            else:
+                update_adv_norm = tf1.assign_add(
+                    ref=policy._moving_average_sqd_adv_norm,
+                    value=1e-6 *
+                    (adv_squared - policy._moving_average_sqd_adv_norm))
+
+                # Exponentially weighted advantages.
+                with tf1.control_dependencies([update_adv_norm]):
+                    exp_advs = tf.math.exp(beta * tf.math.divide(
+                        adv, 1e-8 + tf.math.sqrt(
+                            policy._moving_average_sqd_adv_norm)))
+            exp_advs = tf.stop_gradient(exp_advs)
+        else:
+            exp_advs = 1.0
+
+        # L = - A * log\pi_\theta(a|s)
+        logprobs = action_dist.logp(actions)
+        self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)
+
+        self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss
+
+        self.explained_variance = tf.reduce_mean(
+            explained_variance(cumulative_rewards, value_estimates))
 
 
 def marwil_loss(policy, model, dist_class, train_batch):
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
-    state_values = model.value_function()
+    value_estimates = model.value_function()
 
-    policy.loss = MARWILLoss(policy, state_values, action_dist,
+    policy.loss = MARWILLoss(policy, value_estimates, action_dist,
                              train_batch[SampleBatch.ACTIONS],
                              train_batch[Postprocessing.ADVANTAGES],
                              policy.config["vf_coeff"], policy.config["beta"])
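
For reference, the consolidated MARWILLoss above reduces to a short hand computation. Below is a minimal NumPy sketch (illustrative names, not RLlib API), assuming returns holds the discounted cumulative rewards, values the model's value estimates, logp the log-likelihoods of the taken actions, and ma_sqd_adv_norm the moving average of the squared advantages ("c^2"):

import numpy as np

def marwil_loss_sketch(returns, values, logp, beta, vf_coeff, ma_sqd_adv_norm):
    # Advantage estimate A = R - V(s) and its mean square.
    adv = returns - values
    adv_squared = np.mean(np.square(adv))
    # Value loss: 0.5 * E[A^2].
    v_loss = 0.5 * adv_squared
    if beta != 0.0:
        # Exponentially weighted advantages, normalized by c = sqrt(c^2);
        # in the real loss this factor is wrapped in stop_gradient.
        exp_advs = np.exp(beta * adv / np.sqrt(ma_sqd_adv_norm))
    else:
        # beta == 0.0 reduces the policy term to plain behavior cloning.
        exp_advs = 1.0
    # Policy loss: L = -E[exp_advs * log pi(a|s)].
    p_loss = -1.0 * np.mean(exp_advs * logp)
    return p_loss + vf_coeff * v_loss

The new test case at the end of this diff performs essentially this computation in NumPy and checks it against the TF and Torch loss outputs.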
@@ -181,8 +175,8 @@ def marwil_loss(policy, model, dist_class, train_batch):
 
 def stats(policy, train_batch):
     return {
-        "policy_loss": policy.loss.p_loss.loss,
-        "vf_loss": policy.loss.v_loss.loss,
+        "policy_loss": policy.loss.p_loss,
+        "vf_loss": policy.loss.v_loss,
         "total_loss": policy.loss.total_loss,
         "vf_explained_var": policy.loss.explained_variance,
     }
@@ -191,8 +185,8 @@ def stats(policy, train_batch):
 def setup_mixins(policy, obs_space, action_space, config):
     ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
     # Set up a tf-var for the moving avg (do this here to make it work with
-    # eager mode).
-    policy._ma_adv_norm = get_variable(
+    # eager mode); "c^2" in the paper.
+    policy._moving_average_sqd_adv_norm = get_variable(
         100.0,
         framework="tf",
         tf_name="moving_average_of_advantage_norm",
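
The variable created here (initialized to 100.0) holds a running estimate of the mean squared advantage; its square root is the normalizer c applied to the advantages before exponentiation. A minimal sketch of that update rule, assuming plain Python floats (note that in the hunks above the eager branch steps with 1e-8 while the static-graph and Torch branches step with 1e-6):

def update_sqd_adv_norm(ma_sqd_adv_norm, adv_squared, step_size=1e-6):
    # Exponential moving average: ma <- ma + step_size * (E[A^2] - ma).
    # The loss then uses c = sqrt(ma) to scale the advantages.
    return ma_sqd_adv_norm + step_size * (adv_squared - ma_sqd_adv_norm)

# Example, starting from the initial value of 100.0:
print(update_sqd_adv_norm(100.0, adv_squared=25.0))  # 99.999925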

Changed file 2 of 3 (the Torch MARWIL policy):
@@ -48,16 +48,17 @@ def marwil_loss(policy, model, dist_class, train_batch):
     advantages = train_batch[Postprocessing.ADVANTAGES]
     actions = train_batch[SampleBatch.ACTIONS]
 
-    # Value loss.
-    policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))
-
-    # Policy loss.
     # Advantage estimation.
     adv = advantages - state_values
+    adv_squared = torch.mean(torch.pow(adv, 2.0))
+
+    # Value loss.
+    policy.v_loss = 0.5 * adv_squared
+
+    # Policy loss.
     # Update averaged advantage norm.
-    policy.ma_adv_norm.add_(
-        1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
-    # #xponentially weighted advantages.
+    policy.ma_adv_norm.add_(1e-6 * (adv_squared - policy.ma_adv_norm))
+    # Exponentially weighted advantages.
     exp_advs = torch.exp(policy.config["beta"] *
                          (adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
     # log\pi_\theta(a|s)

Changed file 3 of 3 (the MARWIL test case):
@@ -1,14 +1,18 @@
+import numpy as np
 import os
 from pathlib import Path
 import unittest
 
 import ray
 import ray.rllib.agents.marwil as marwil
-from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.test_utils import check_compute_single_action, \
+from ray.rllib.evaluation.postprocessing import compute_advantages
+from ray.rllib.offline import JsonReader
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.test_utils import check, check_compute_single_action, \
     framework_iterator
 
 tf1, tf, tfv = try_import_tf()
+torch, _ = try_import_torch()
 
 
 class TestMARWIL(unittest.TestCase):
|
@ -70,6 +74,74 @@ class TestMARWIL(unittest.TestCase):
|
||||||
|
|
||||||
trainer.stop()
|
trainer.stop()
|
||||||
|
|
||||||
|
def test_marwil_loss_function(self):
|
||||||
|
"""
|
||||||
|
To generate the historic data used in this test case, first run:
|
||||||
|
$ ./train.py --run=PPO --env=CartPole-v0 \
|
||||||
|
--stop='{"timesteps_total": 50000}' \
|
||||||
|
--config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
|
||||||
|
"""
|
||||||
|
rllib_dir = Path(__file__).parent.parent.parent.parent
|
||||||
|
print("rllib dir={}".format(rllib_dir))
|
||||||
|
data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
|
||||||
|
print("data_file={} exists={}".format(data_file,
|
||||||
|
os.path.isfile(data_file)))
|
||||||
|
config = marwil.DEFAULT_CONFIG.copy()
|
||||||
|
config["num_workers"] = 0 # Run locally.
|
||||||
|
# Learn from offline data.
|
||||||
|
config["input"] = [data_file]
|
||||||
|
|
||||||
|
for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
|
||||||
|
reader = JsonReader(inputs=[data_file])
|
||||||
|
batch = reader.next()
|
||||||
|
|
||||||
|
trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
|
||||||
|
policy = trainer.get_policy()
|
||||||
|
model = policy.model
|
||||||
|
|
||||||
|
# Calculate our own expected values (to then compare against the
|
||||||
|
# agent's loss output).
|
||||||
|
cummulative_rewards = compute_advantages(
|
||||||
|
batch, 0.0, config["gamma"], 1.0, False, False)["advantages"]
|
||||||
|
if fw == "torch":
|
||||||
|
cummulative_rewards = torch.tensor(cummulative_rewards)
|
||||||
|
tensor_batch = policy._lazy_tensor_dict(batch)
|
||||||
|
model_out, _ = model.from_batch(tensor_batch)
|
||||||
|
vf_estimates = model.value_function()
|
||||||
|
adv = cummulative_rewards - vf_estimates
|
||||||
|
if fw == "torch":
|
||||||
|
adv = adv.detach().cpu().numpy()
|
||||||
|
adv_squared = np.mean(np.square(adv))
|
||||||
|
c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
|
||||||
|
c = np.sqrt(c_2)
|
||||||
|
exp_advs = np.exp(config["beta"] * (adv / c))
|
||||||
|
logp = policy.dist_class(model_out,
|
||||||
|
model).logp(tensor_batch["actions"])
|
||||||
|
if fw == "torch":
|
||||||
|
logp = logp.detach().cpu().numpy()
|
||||||
|
# Calculate all expected loss components.
|
||||||
|
expected_vf_loss = 0.5 * adv_squared
|
||||||
|
expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
|
||||||
|
expected_loss = \
|
||||||
|
expected_pol_loss + config["vf_coeff"] * expected_vf_loss
|
||||||
|
|
||||||
|
# Calculate the algorithm's loss (to check against our own
|
||||||
|
# calculation above).
|
||||||
|
postprocessed_batch = policy.postprocess_trajectory(batch)
|
||||||
|
loss_func = marwil.marwil_tf_policy.marwil_loss if fw != "torch" \
|
||||||
|
else marwil.marwil_torch_policy.marwil_loss
|
||||||
|
loss_out = loss_func(policy, model, policy.dist_class,
|
||||||
|
policy._lazy_tensor_dict(postprocessed_batch))
|
||||||
|
|
||||||
|
# Check all components.
|
||||||
|
if fw == "torch":
|
||||||
|
check(policy.v_loss, expected_vf_loss, decimals=4)
|
||||||
|
check(policy.p_loss, expected_pol_loss, decimals=4)
|
||||||
|
else:
|
||||||
|
check(policy.loss.v_loss, expected_vf_loss, decimals=4)
|
||||||
|
check(policy.loss.p_loss, expected_pol_loss, decimals=4)
|
||||||
|
check(loss_out, expected_loss, decimals=3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
|
|
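
As a side note, the cummulative_rewards computed in the test above via compute_advantages(batch, 0.0, config["gamma"], 1.0, False, False)["advantages"] are plain discounted cumulative rewards (use_gae=False, use_critic=False, so no value-function baseline is subtracted). A small, self-contained sketch of that quantity, for illustration only (not the RLlib implementation):

import numpy as np

def discounted_returns(rewards, gamma):
    # R_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(discounted_returns([1.0, 1.0, 1.0], 0.99))  # -> [2.9701 1.99   1.    ]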