ray/rllib/agents/ppo/tests/test_ppo.py
Avnish Narayan 026bf01071
[RLlib] Upgrade gym version to 0.21 and deprecate pendulum-v0. (#19535)
* Fix QMix, SAC, and MADDPG too.

* Unpin gym and deprecate pendulum v0

Many tests in RLlib depended on Pendulum-v0;
however, in gym 0.21, Pendulum-v0 was deprecated
in favor of Pendulum-v1. This may change reward
thresholds, so we may have to rerun all of the
Pendulum-v1 benchmarks or use another environment
instead. The same applies to FrozenLake-v0 and
FrozenLake-v1.

Lastly, all of the RLlib tests have
been moved to Python 3.7.

* Add gym installation based on python version.

Pin python<= 3.6 to gym 0.19 due to install
issues with atari roms in gym 0.20

* Reformatting

* Fixing tests

* Move atari-py install conditional to req.txt

* migrate to new ale install method

* Fix QMix, SAC, and MADDPG too.

* Unpin gym and deprecate pendulum v0

Many tests in RLlib depended on Pendulum-v0;
however, in gym 0.21, Pendulum-v0 was deprecated
in favor of Pendulum-v1. This may change reward
thresholds, so we may have to rerun all of the
Pendulum-v1 benchmarks or use another environment
instead. The same applies to FrozenLake-v0 and
FrozenLake-v1.

Lastly, all of the RLlib tests have
been moved to Python 3.7.
* Add gym installation based on python version.

Pin python<= 3.6 to gym 0.19 due to install
issues with atari roms in gym 0.20

Move atari-py install conditional to req.txt

migrate to new ale install method

Make parametric_actions_cartpole return float32 actions/obs

Adding type conversions if obs/actions don't match space

Add utils to make elements match gym space dtypes

Co-authored-by: Jun Gong <jungong@anyscale.com>
Co-authored-by: sven1977 <svenmika1977@gmail.com>
2021-11-03 16:24:00 +01:00

415 lines
17 KiB
Python

import copy
import numpy as np
import unittest
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
ppo_surrogate_loss_tf
from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
ppo_surrogate_loss_torch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
LEARNER_STATS_KEY
from ray.rllib.utils.numpy import fc
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
check_train_results, framework_iterator
# Hand-written fake CartPole trajectory: one episode of 3 time steps
# (all rows share EPS_ID 0 and only the last row is flagged as done).
FAKE_BATCH = SampleBatch({
    # One 4-dimensional observation per time step.
    SampleBatch.OBS: np.array(
        [
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.6, 0.7, 0.8],
            [0.9, 1.0, 1.1, 1.2],
        ],
        dtype=np.float32,
    ),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
    SampleBatch.DONES: np.array([False, False, True]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    # Two distribution inputs per time step.
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [
            [-2.0, 0.5],
            [-3.0, -0.3],
            [-0.1, 2.5],
        ],
        dtype=np.float32,
    ),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
})
class MyCallbacks(DefaultCallbacks):
    """Callbacks that verify the lr/entropy schedules after each train call.

    After every training result, the reported ``cur_lr`` and
    ``entropy_coeff`` stats are compared against hard-coded expectations and
    each policy's optimizer learning rate is compared against the policy's
    current (scheduled) learning rate.
    """

    @staticmethod
    def _check_lr_torch(policy, policy_id):
        """Assert all torch optimizer param groups use ``policy.cur_lr``."""
        for optimizer in policy._optimizers:
            for group in optimizer.param_groups:
                assert group["lr"] == policy.cur_lr, "LR scheduling error!"

    @staticmethod
    def _check_lr_tf(policy, policy_id):
        """Assert the tf optimizer's lr equals the policy's current lr.

        If the policy has a session, both values are fetched through it;
        otherwise they are read via ``.numpy()``.
        """
        scheduled_lr = policy.cur_lr
        session = policy.get_session()
        if not session:
            scheduled_lr = scheduled_lr.numpy()
            optimizer_lr = policy._optimizer.lr.numpy()
        else:
            scheduled_lr = session.run(scheduled_lr)
            optimizer_lr = session.run(policy._optimizer._lr)
        assert scheduled_lr == optimizer_lr, "LR scheduling error!"

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        """Check the reported schedule values, then each policy's lr."""
        learner_info = result["info"][LEARNER_INFO]
        learner_stats = learner_info[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]
        first_iteration = trainer.iteration == 1

        # Learning rate: 5e-5 in the first iteration, 0.0 afterwards.
        expected_lr = 5e-5 if first_iteration else 0.0
        check(learner_stats["cur_lr"], expected_lr)

        # Entropy coeff: 0.1 in the first iteration, 0.05 afterwards.
        expected_entropy_coeff = 0.1 if first_iteration else 0.05
        check(learner_stats["entropy_coeff"], expected_entropy_coeff)

        # Pick the framework-specific optimizer check.
        if trainer.config["framework"] == "torch":
            lr_checker = self._check_lr_torch
        else:
            lr_checker = self._check_lr_tf
        trainer.workers.foreach_policy(lr_checker)
class TestPPO(unittest.TestCase):
    """Compilation, exploration, and loss-math tests for the PPO trainer.

    Every trainer created by a test is stopped in a ``finally`` clause so
    that a failing check does not leave the trainer running.
    """

    @classmethod
    def setUpClass(cls):
        """Start Ray once for all tests in this class."""
        ray.init()

    @classmethod
    def tearDownClass(cls):
        """Shut Ray down after all tests in this class have run."""
        ray.shutdown()

    @staticmethod
    def _log_std_vars(policy, fw):
        """Return the policy model's variables that mention "log_std".

        For torch, parameters are matched by name; for all other frameworks,
        trainable variables are matched by their string representation.
        """
        if fw == "torch":
            return [
                v for (n, v) in policy.model.named_parameters()
                if "log_std" in n
            ]
        return [
            v for v in policy.model.trainable_variables()
            if "log_std" in str(v)
        ]

    def test_ppo_compilation_and_schedule_mixins(self):
        """Test whether a PPOTrainer can be built with all frameworks."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        # For checking lr-schedule correctness.
        config["callbacks"] = MyCallbacks

        config["num_workers"] = 1
        config["num_sgd_iter"] = 2
        # Settings in case we use an LSTM.
        config["model"]["lstm_cell_size"] = 10
        config["model"]["max_seq_len"] = 20
        # Use default-native keras models whenever possible.
        # config["model"]["_use_default_native_models"] = True

        # Setup lr- and entropy schedules for testing.
        config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]]
        # Set entropy_coeff to a faulty value to prove that it'll get
        # overridden by the schedule below (which is expected).
        config["entropy_coeff"] = 100.0
        config["entropy_coeff_schedule"] = [[0, 0.1], [256, 0.0]]

        config["train_batch_size"] = 128
        # Test with compression.
        config["compress_observations"] = True
        num_iterations = 2

        for fw in framework_iterator(config):
            for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
                print("Env={}".format(env))
                for lstm in [True, False]:
                    print("LSTM={}".format(lstm))
                    config["model"]["use_lstm"] = lstm
                    config["model"]["lstm_use_prev_action"] = lstm
                    config["model"]["lstm_use_prev_reward"] = lstm

                    trainer = ppo.PPOTrainer(config=config, env=env)
                    # Always stop the trainer, even if a check below fails.
                    try:
                        policy = trainer.get_policy()
                        entropy_coeff = policy.entropy_coeff
                        lr = policy.cur_lr
                        if fw == "tf":
                            # Static-graph tf: fetch the actual values.
                            entropy_coeff, lr = policy.get_session().run(
                                [entropy_coeff, lr])
                        # Before training, both schedules are at their
                        # initial values.
                        check(entropy_coeff, 0.1)
                        check(lr, config["lr"])

                        for _ in range(num_iterations):
                            results = trainer.train()
                            check_train_results(results)
                            print(results)

                        check_compute_single_action(
                            trainer,
                            include_prev_action_reward=True,
                            include_state=lstm)
                    finally:
                        trainer.stop()

    def test_ppo_exploration_setup(self):
        """Tests, whether PPO runs with different exploration setups."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["env_config"] = {"is_slippery": False, "map_name": "4x4"}
        obs = np.array(0)

        # Test against all frameworks.
        for fw in framework_iterator(config):
            # Default Agent should be setup with StochasticSampling.
            trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v1")
            # Always stop the trainer, even if a check below fails.
            try:
                # explore=False, always expect the same (deterministic)
                # action.
                a_ = trainer.compute_single_action(
                    obs,
                    explore=False,
                    prev_action=np.array(2),
                    prev_reward=np.array(1.0))
                # Test whether this is really the argmax action over the
                # logits (skipped for static-graph tf).
                if fw != "tf":
                    last_out = trainer.get_policy().model.last_output()
                    if fw == "torch":
                        last_out = last_out.detach().cpu()
                    check(a_, np.argmax(last_out.numpy(), 1)[0])
                for _ in range(50):
                    a = trainer.compute_single_action(
                        obs,
                        explore=False,
                        prev_action=np.array(2),
                        prev_reward=np.array(1.0))
                    check(a, a_)

                # With explore=True (default), expect stochastic actions.
                actions = [
                    trainer.compute_single_action(
                        obs,
                        prev_action=np.array(2),
                        prev_reward=np.array(1.0)) for _ in range(300)
                ]
                check(np.mean(actions), 1.5, atol=0.2)
            finally:
                trainer.stop()

    def test_ppo_free_log_std(self):
        """Tests the free log std option works."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["model"]["free_log_std"] = True
        config["model"]["vf_share_layers"] = True

        # The yielded session is not needed here; values are fetched through
        # `policy.get_session()` instead.
        for fw, _ in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            # Always stop the trainer, even if a check below fails.
            try:
                policy = trainer.get_policy()

                # Check the free log std var is created.
                matching = self._log_std_vars(policy, fw)
                assert len(matching) == 1, matching
                log_std_var = matching[0]

                def get_value():
                    """Return the first element of the log-std variable."""
                    if fw == "tf":
                        return policy.get_session().run(log_std_var)[0]
                    elif fw == "torch":
                        return log_std_var.detach().cpu().numpy()[0]
                    else:
                        return log_std_var.numpy()[0]

                # Check the variable is initially zero.
                init_std = get_value()
                assert init_std == 0.0, init_std

                batch = compute_gae_for_sample_batch(policy,
                                                     FAKE_BATCH.copy())
                if fw == "torch":
                    batch = policy._lazy_tensor_dict(batch)
                policy.learn_on_batch(batch)

                # Check the variable is updated.
                post_std = get_value()
                assert post_std != 0.0, post_std
            finally:
                trainer.stop()

    def test_ppo_loss_function(self):
        """Tests the PPO loss function math."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["model"]["vf_share_layers"] = True

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            # Always stop the trainer, even if a check below fails.
            try:
                policy = trainer.get_policy()

                # Check no free log std var by default.
                matching = self._log_std_vars(policy, fw)
                assert len(matching) == 0, matching

                # Post-process the fake batch and attach the results to the
                # train_batch dict. Expected value targets:
                # [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5]
                # = [0.50005, -0.505, 0.5]
                train_batch = compute_gae_for_sample_batch(
                    policy, FAKE_BATCH.copy())
                if fw == "torch":
                    train_batch = policy._lazy_tensor_dict(train_batch)

                # Check the value targets.
                check(train_batch[Postprocessing.VALUE_TARGETS],
                      [0.50005, -0.505, 0.5])

                # Calculate actual PPO loss (the static-graph tf loss is
                # evaluated through the session further below instead).
                if fw in ["tf2", "tfe"]:
                    ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                          train_batch)
                elif fw == "torch":
                    ppo_surrogate_loss_torch(policy, policy.model,
                                             TorchCategorical, train_batch)

                if fw != "torch":
                    model_vars = policy.model.variables()
                else:
                    model_vars = list(policy.model.parameters())
                if fw == "tf":
                    model_vars = policy.get_session().run(model_vars)

                # Re-compute the model outputs from the raw weights. For
                # torch, the indices of the shared layer and the logits
                # layer are swapped relative to the tf frameworks.
                expected_shared_out = fc(
                    train_batch[SampleBatch.CUR_OBS],
                    model_vars[0 if fw != "torch" else 2],
                    model_vars[1 if fw != "torch" else 3],
                    framework=fw)
                expected_logits = fc(
                    expected_shared_out,
                    model_vars[2 if fw != "torch" else 0],
                    model_vars[3 if fw != "torch" else 1],
                    framework=fw)
                expected_value_outs = fc(
                    expected_shared_out,
                    model_vars[4],
                    model_vars[5],
                    framework=fw)

                kl, entropy, pg_loss, vf_loss, overall_loss = \
                    self._ppo_loss_helper(
                        policy,
                        policy.model,
                        Categorical if fw != "torch" else TorchCategorical,
                        train_batch,
                        expected_logits,
                        expected_value_outs,
                        sess=sess)

                if sess:
                    policy_sess = policy.get_session()
                    k, e, pl, v, tl = policy_sess.run(
                        [
                            policy._mean_kl_loss,
                            policy._mean_entropy,
                            policy._mean_policy_loss,
                            policy._mean_vf_loss,
                            policy._total_loss,
                        ],
                        feed_dict=policy._get_loss_inputs_dict(
                            train_batch, shuffle=False))
                    check(k, kl)
                    check(e, entropy)
                    check(pl, np.mean(-pg_loss))
                    check(v, np.mean(vf_loss), decimals=4)
                    check(tl, overall_loss, decimals=4)
                elif fw == "torch":
                    tower_stats = policy.model.tower_stats
                    check(tower_stats["mean_kl_loss"], kl)
                    check(tower_stats["mean_entropy"], entropy)
                    check(tower_stats["mean_policy_loss"], np.mean(-pg_loss))
                    check(
                        tower_stats["mean_vf_loss"],
                        np.mean(vf_loss),
                        decimals=4)
                    check(
                        tower_stats["total_loss"], overall_loss, decimals=4)
                else:
                    check(policy._mean_kl_loss, kl)
                    check(policy._mean_entropy, entropy)
                    check(policy._mean_policy_loss, np.mean(-pg_loss))
                    check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
                    check(policy._total_loss, overall_loss, decimals=4)
            finally:
                trainer.stop()

    def _ppo_loss_helper(self,
                         policy,
                         model,
                         dist_class,
                         train_batch,
                         logits,
                         vf_outs,
                         sess=None):
        """Calculate the expected PPO loss components using numpy.

        Args:
            policy: The policy whose config (``clip_param``,
                ``vf_clip_param``, ``vf_loss_coeff``) and ``kl_coeff`` /
                ``entropy_coeff`` are used.
            model: The policy's model; only used to decide whether the torch
                code path is taken.
            dist_class: Action distribution class, built from ``logits`` and
                from the batch's ``ACTION_DIST_INPUTS``.
            train_batch: The (postprocessed) train batch.
            logits: The expected current action-distribution inputs.
            vf_outs: The expected value-function outputs.
            sess: Optional tf session; if given, tensors are evaluated
                through it.

        Returns:
            Tuple ``(kl, entropy, pg_loss, vf_loss, overall_loss)``: ``kl``,
            ``entropy`` and ``overall_loss`` are means, whereas ``pg_loss``
            (the un-negated surrogate objective) and ``vf_loss`` are the
            element-wise results and are not reduced.
        """
        # Calculate expected PPO loss results.
        dist = dist_class(logits, policy.model)
        dist_prev = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                               policy.model)
        expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
        if isinstance(model, TorchModelV2):
            # Remove the batch's get-interceptor before the numpy math
            # below; this must happen after the torch calls above.
            train_batch.set_get_interceptor(None)
            expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
                                  train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist)-loss component.
            kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
            # Entropy-loss component.
            entropy = np.mean(dist.entropy().detach().cpu().numpy())
        else:
            if sess:
                expected_logp = sess.run(expected_logp)
            expected_rho = np.exp(expected_logp -
                                  train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist)-loss component.
            kl = dist_prev.kl(dist)
            if sess:
                kl = sess.run(kl)
            kl = np.mean(kl)
            # Entropy-loss component.
            entropy = dist.entropy()
            if sess:
                entropy = sess.run(entropy)
            entropy = np.mean(entropy)

        # Policy loss component: the smaller of the plain and the clipped
        # surrogate terms.
        advantages = train_batch[Postprocessing.ADVANTAGES]
        clip_param = policy.config["clip_param"]
        pg_loss = np.minimum(
            advantages * expected_rho,
            advantages * np.clip(expected_rho, 1 - clip_param,
                                 1 + clip_param))

        # Value function loss component: the larger of the unclipped and
        # the clipped squared errors.
        value_targets = train_batch[Postprocessing.VALUE_TARGETS]
        vf_preds = train_batch[SampleBatch.VF_PREDS]
        vf_clip_param = policy.config["vf_clip_param"]
        vf_loss1 = np.power(vf_outs - value_targets, 2.0)
        vf_clipped = vf_preds + np.clip(vf_outs - vf_preds, -vf_clip_param,
                                        vf_clip_param)
        vf_loss2 = np.power(vf_clipped - value_targets, 2.0)
        vf_loss = np.maximum(vf_loss1, vf_loss2)

        # Overall loss.
        if sess:
            policy_sess = policy.get_session()
            kl_coeff, entropy_coeff = policy_sess.run(
                [policy.kl_coeff, policy.entropy_coeff])
        else:
            kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff
        overall_loss = np.mean(-pg_loss + kl_coeff * kl +
                               policy.config["vf_loss_coeff"] * vf_loss -
                               entropy_coeff * entropy)
        return kl, entropy, pg_loss, vf_loss, overall_loss
# Script entry point: run this file's tests verbosely through pytest and
# use pytest's result as the process exit status.
if __name__ == "__main__":
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)