import copy
import numpy as np
import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
    ppo_surrogate_loss_tf
from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
    ppo_surrogate_loss_torch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
    Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
    LEARNER_STATS_KEY
from ray.rllib.utils.numpy import fc
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
    check_train_results, framework_iterator

# Fake CartPole episode of n time steps.
FAKE_BATCH = SampleBatch({
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, .5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, .5], dtype=np.float32),
    SampleBatch.DONES: np.array([False, False, True]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2., 0.5], [-3., -0.3], [-0.1, 2.5]], dtype=np.float32),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
})
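

# Illustrative sketch (added for clarity; not used by the tests below):
# assuming simple discounted returns (gamma=0.99) over a terminating episode
# (DONES[-1] is True), the value targets checked in `test_ppo_loss_function`
# can be derived by hand from FAKE_BATCH's rewards as follows.
def _expected_value_targets(rewards=(1.0, -1.0, 0.5), gamma=0.99):
    """Computes the discounted returns [R_0, R_1, R_2] back-to-front."""
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.insert(0, running)
    return returns  # -> [0.50005, -0.505, 0.5]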


class MyCallbacks(DefaultCallbacks):
    @staticmethod
    def _check_lr_torch(policy, policy_id):
        for j, opt in enumerate(policy._optimizers):
            for p in opt.param_groups:
                assert p["lr"] == policy.cur_lr, "LR scheduling error!"

    @staticmethod
    def _check_lr_tf(policy, policy_id):
        lr = policy.cur_lr
        sess = policy.get_session()
        if sess:
            lr = sess.run(lr)
            optim_lr = sess.run(policy._optimizer._lr)
        else:
            lr = lr.numpy()
            optim_lr = policy._optimizer.lr.numpy()
        assert lr == optim_lr, "LR scheduling error!"

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
            LEARNER_STATS_KEY]
        # Learning rate should go to 0 after 1 iter.
        check(stats["cur_lr"], 5e-5 if trainer.iteration == 1 else 0.0)
        # Entropy coeff goes to 0.05, then 0.0 (per iter).
        check(stats["entropy_coeff"], 0.1 if trainer.iteration == 1 else 0.05)

        trainer.workers.foreach_policy(self._check_lr_torch if trainer.config[
            "framework"] == "torch" else self._check_lr_tf)


class TestPPO(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ppo_compilation_and_schedule_mixins(self):
        """Test whether a PPOTrainer can be built with all frameworks.

        Also checks that the lr- and entropy-coeff schedules set up below
        are respected (see MyCallbacks above).
        """
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        # For checking lr-schedule correctness.
        config["callbacks"] = MyCallbacks

        config["num_workers"] = 1
        config["num_sgd_iter"] = 2
        # Settings in case we use an LSTM.
        config["model"]["lstm_cell_size"] = 10
        config["model"]["max_seq_len"] = 20
        # Use default-native keras models whenever possible.
        # config["model"]["_use_default_native_models"] = True

        # Set up lr- and entropy schedules for testing.
        config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]]
        # Set entropy_coeff to a faulty value to prove that it'll get
        # overridden by the schedule below (which is expected).
        config["entropy_coeff"] = 100.0
        config["entropy_coeff_schedule"] = [[0, 0.1], [256, 0.0]]

        config["train_batch_size"] = 128
        # Test with compression.
        config["compress_observations"] = True
        num_iterations = 2

        for fw in framework_iterator(config, with_eager_tracing=True):
for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
|
2020-06-27 20:50:01 +02:00
|
|
|
print("Env={}".format(env))
|
2021-04-30 19:26:30 +02:00
|
|
|
for lstm in [True, False]:
|
2020-06-27 20:50:01 +02:00
|
|
|
print("LSTM={}".format(lstm))
|
|
|
|
config["model"]["use_lstm"] = lstm
|
2020-11-25 20:27:46 +01:00
|
|
|
config["model"]["lstm_use_prev_action"] = lstm
|
|
|
|
config["model"]["lstm_use_prev_reward"] = lstm
|
2021-04-27 10:44:54 +02:00
|
|
|
|
2020-06-27 20:50:01 +02:00
|
|
|
trainer = ppo.PPOTrainer(config=config, env=env)
|
2021-05-20 18:15:10 +02:00
|
|
|
policy = trainer.get_policy()
|
|
|
|
entropy_coeff = trainer.get_policy().entropy_coeff
|
|
|
|
lr = policy.cur_lr
|
|
|
|
if fw == "tf":
|
|
|
|
entropy_coeff, lr = policy.get_session().run(
|
|
|
|
[entropy_coeff, lr])
|
|
|
|
check(entropy_coeff, 0.1)
|
|
|
|
check(lr, config["lr"])
|
|
|
|
|
2020-06-27 20:50:01 +02:00
|
|
|
for i in range(num_iterations):
|
2021-09-30 16:39:05 +02:00
|
|
|
results = trainer.train()
|
|
|
|
check_train_results(results)
|
|
|
|
print(results)
|
2021-05-20 18:15:10 +02:00
|
|
|
|
2020-06-27 20:50:01 +02:00
|
|
|
check_compute_single_action(
|
|
|
|
trainer,
|
|
|
|
include_prev_action_reward=True,
|
|
|
|
include_state=lstm)
|
|
|
|
trainer.stop()
|
2020-01-21 08:06:50 +01:00
|
|
|
|
2020-02-19 21:18:45 +01:00
|
|
|
def test_ppo_exploration_setup(self):
        """Tests whether PPO runs with different exploration setups."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["env_config"] = {"is_slippery": False, "map_name": "4x4"}
        obs = np.array(0)

        # Test against all frameworks.
        for fw in framework_iterator(config):
            # Default Agent should be set up with StochasticSampling.
trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v1")
|
2020-02-19 21:18:45 +01:00
|
|
|
# explore=False, always expect the same (deterministic) action.
|
2021-06-30 12:32:11 +02:00
|
|
|
a_ = trainer.compute_single_action(
|
2020-02-19 21:18:45 +01:00
|
|
|
obs,
|
|
|
|
explore=False,
|
|
|
|
prev_action=np.array(2),
|
|
|
|
prev_reward=np.array(1.0))
|
|
|
|
# Test whether this is really the argmax action over the logits.
|
|
|
|
if fw != "tf":
|
|
|
|
last_out = trainer.get_policy().model.last_output()
|
2020-11-12 16:27:34 +01:00
|
|
|
if fw == "torch":
|
|
|
|
check(a_, np.argmax(last_out.detach().cpu().numpy(), 1)[0])
|
|
|
|
else:
|
|
|
|
check(a_, np.argmax(last_out.numpy(), 1)[0])
|
2020-02-19 21:18:45 +01:00
|
|
|
for _ in range(50):
|
2021-06-30 12:32:11 +02:00
|
|
|
a = trainer.compute_single_action(
|
2020-02-19 21:18:45 +01:00
|
|
|
obs,
|
|
|
|
explore=False,
|
|
|
|
prev_action=np.array(2),
|
|
|
|
prev_reward=np.array(1.0))
|
|
|
|
check(a, a_)
|
|
|
|
|
|
|
|
# With explore=True (default), expect stochastic actions.
|
|
|
|
actions = []
|
|
|
|
for _ in range(300):
|
|
|
|
actions.append(
|
2021-06-30 12:32:11 +02:00
|
|
|
trainer.compute_single_action(
|
2020-02-19 21:18:45 +01:00
|
|
|
obs,
|
|
|
|
prev_action=np.array(2),
|
|
|
|
prev_reward=np.array(1.0)))
|
|
|
|
check(np.mean(actions), 1.5, atol=0.2)
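            # FrozenLake has 4 discrete actions (0-3), so a roughly uniform
            # stochastic policy should yield a mean action of ~1.5.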
            trainer.stop()

    def test_ppo_free_log_std(self):
        """Tests that the `free_log_std` option works."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["model"]["free_log_std"] = True
        config["model"]["vf_share_layers"] = True

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check the free log std var is created.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters()
                    if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables()
                    if "log_std" in str(v)
                ]
            assert len(matching) == 1, matching
            log_std_var = matching[0]

            def get_value():
                if fw == "tf":
                    return policy.get_session().run(log_std_var)[0]
                elif fw == "torch":
                    return log_std_var.detach().cpu().numpy()[0]
                else:
                    return log_std_var.numpy()[0]

            # Check the variable is initially zero.
            init_std = get_value()
            assert init_std == 0.0, init_std
            batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
            if fw == "torch":
                batch = policy._lazy_tensor_dict(batch)
            policy.learn_on_batch(batch)

            # Check the variable is updated.
            post_std = get_value()
            assert post_std != 0.0, post_std
            trainer.stop()

    def test_ppo_loss_function(self):
        """Tests the PPO loss function math."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["model"]["vf_share_layers"] = True

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check no free log std var by default.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters()
                    if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables()
                    if "log_std" in str(v)
                ]
            assert len(matching) == 0, matching

            # Post-process (calculate simple (non-GAE) advantages) and attach
            # to train_batch dict.
            # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
            # [0.50005, -0.505, 0.5]
            train_batch = compute_gae_for_sample_batch(policy,
                                                       FAKE_BATCH.copy())
            if fw == "torch":
                train_batch = policy._lazy_tensor_dict(train_batch)

            # Check the resulting value targets.
            check(train_batch[Postprocessing.VALUE_TARGETS],
                  [0.50005, -0.505, 0.5])
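            # (These are the discounted returns of FAKE_BATCH's rewards; see
            # the `_expected_value_targets` sketch at the top of this file.)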

            # Calculate actual PPO loss.
            if fw in ["tf2", "tfe"]:
                ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                      train_batch)
            elif fw == "torch":
                ppo_surrogate_loss_torch(policy, policy.model,
                                         TorchCategorical, train_batch)

            vars = policy.model.variables() if fw != "torch" else \
                list(policy.model.parameters())
            if fw == "tf":
                vars = policy.get_session().run(vars)
            expected_shared_out = fc(
                train_batch[SampleBatch.CUR_OBS],
                vars[0 if fw != "torch" else 2],
                vars[1 if fw != "torch" else 3],
                framework=fw)
            expected_logits = fc(
                expected_shared_out,
                vars[2 if fw != "torch" else 0],
                vars[3 if fw != "torch" else 1],
                framework=fw)
            expected_value_outs = fc(
                expected_shared_out, vars[4], vars[5], framework=fw)
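            # Note: the index juggling above accounts for the different
            # ordering of weights in `model.variables()` (tf) vs.
            # `list(model.parameters())` (torch).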

            kl, entropy, pg_loss, vf_loss, overall_loss = \
                self._ppo_loss_helper(
                    policy, policy.model,
                    Categorical if fw != "torch" else TorchCategorical,
                    train_batch,
                    expected_logits, expected_value_outs,
                    sess=sess
                )
            if sess:
                policy_sess = policy.get_session()
                k, e, pl, v, tl = policy_sess.run(
                    [
                        policy._mean_kl_loss,
                        policy._mean_entropy,
                        policy._mean_policy_loss,
                        policy._mean_vf_loss,
                        policy._total_loss,
                    ],
                    feed_dict=policy._get_loss_inputs_dict(
                        train_batch, shuffle=False))
                check(k, kl)
                check(e, entropy)
                check(pl, np.mean(-pg_loss))
                check(v, np.mean(vf_loss), decimals=4)
                check(tl, overall_loss, decimals=4)
            elif fw == "torch":
                check(policy.model.tower_stats["mean_kl_loss"], kl)
                check(policy.model.tower_stats["mean_entropy"], entropy)
                check(policy.model.tower_stats["mean_policy_loss"],
                      np.mean(-pg_loss))
                check(
                    policy.model.tower_stats["mean_vf_loss"],
                    np.mean(vf_loss),
                    decimals=4)
                check(
                    policy.model.tower_stats["total_loss"],
                    overall_loss,
                    decimals=4)
            else:
                check(policy._mean_kl_loss, kl)
                check(policy._mean_entropy, entropy)
                check(policy._mean_policy_loss, np.mean(-pg_loss))
                check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
                check(policy._total_loss, overall_loss, decimals=4)
            trainer.stop()

    def _ppo_loss_helper(self,
                         policy,
                         model,
                         dist_class,
                         train_batch,
                         logits,
                         vf_outs,
                         sess=None):
        """
        Calculates the expected PPO loss (components) given Policy,
        Model, distribution, some batch, logits & vf outputs, using numpy.
        """
        # Calculate expected PPO loss results.
        dist = dist_class(logits, policy.model)
        dist_prev = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                               policy.model)
        expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
        if isinstance(model, TorchModelV2):
            train_batch.set_get_interceptor(None)
            expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
                                  train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist)-loss component.
            kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
            # Entropy-loss component.
            entropy = np.mean(dist.entropy().detach().cpu().numpy())
        else:
            if sess:
                expected_logp = sess.run(expected_logp)
            expected_rho = np.exp(expected_logp -
                                  train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist)-loss component.
            kl = dist_prev.kl(dist)
            if sess:
                kl = sess.run(kl)
            kl = np.mean(kl)
            # Entropy-loss component.
            entropy = dist.entropy()
            if sess:
                entropy = sess.run(entropy)
            entropy = np.mean(entropy)

        # Policy loss component.
        pg_loss = np.minimum(
            train_batch[Postprocessing.ADVANTAGES] * expected_rho,
            train_batch[Postprocessing.ADVANTAGES] * np.clip(
                expected_rho, 1 - policy.config["clip_param"],
                1 + policy.config["clip_param"]))

        # Value function loss component.
        vf_loss1 = np.power(
            vf_outs - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = train_batch[SampleBatch.VF_PREDS] + np.clip(
            vf_outs - train_batch[SampleBatch.VF_PREDS],
            -policy.config["vf_clip_param"], policy.config["vf_clip_param"])
        vf_loss2 = np.power(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = np.maximum(vf_loss1, vf_loss2)

        # Overall loss.
        if sess:
            policy_sess = policy.get_session()
            kl_coeff, entropy_coeff = policy_sess.run(
                [policy.kl_coeff, policy.entropy_coeff])
        else:
            kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff
        overall_loss = np.mean(-pg_loss + kl_coeff * kl +
                               policy.config["vf_loss_coeff"] * vf_loss -
                               entropy_coeff * entropy)
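        # This mirrors the total loss computed by the PPO policies above:
        # the batch mean of -pg_loss + kl_coeff * kl
        # + vf_loss_coeff * vf_loss - entropy_coeff * entropy.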
        return kl, entropy, pg_loss, vf_loss, overall_loss


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))