import numpy as np
import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import (
    ppo_surrogate_loss as ppo_surrogate_loss_tf,
)
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import (
    compute_gae_for_sample_batch,
    Postprocessing,
)
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.numpy import fc
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


# Fake CartPole episode of 3 time steps.
FAKE_BATCH = SampleBatch(
    {
        SampleBatch.OBS: np.array(
            [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
            dtype=np.float32,
        ),
        SampleBatch.ACTIONS: np.array([0, 1, 1]),
        SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
        SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
        SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, 0.5], dtype=np.float32),
        SampleBatch.DONES: np.array([False, False, True]),
        SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
        SampleBatch.ACTION_DIST_INPUTS: np.array(
            [[-2.0, 0.5], [-3.0, -0.3], [-0.1, 2.5]], dtype=np.float32
        ),
        SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
        SampleBatch.EPS_ID: np.array([0, 0, 0]),
        SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
    }
)
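# Note: OBS mimics CartPole's 4-dim observation; ACTION_DIST_INPUTS holds one
# logit per discrete action (CartPole has two). The episode terminates at the
# third step (DONES[-1] is True).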


class MyCallbacks(DefaultCallbacks):
    @staticmethod
    def _check_lr_torch(policy, policy_id):
        for opt in policy._optimizers:
            for p in opt.param_groups:
                assert p["lr"] == policy.cur_lr, "LR scheduling error!"

    @staticmethod
    def _check_lr_tf(policy, policy_id):
        lr = policy.cur_lr
        sess = policy.get_session()
        if sess:
            lr = sess.run(lr)
            optim_lr = sess.run(policy._optimizer._lr)
        else:
            lr = lr.numpy()
            optim_lr = policy._optimizer.lr.numpy()
        assert lr == optim_lr, "LR scheduling error!"

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]
        # Learning rate should go to 0 after 1 iter.
        check(stats["cur_lr"], 5e-5 if trainer.iteration == 1 else 0.0)
        # Entropy coeff should be 0.1 on iter 1, then decay to 0.05 (the
        # halfway point of its schedule, which ends at 0.0 at ts=256).
        check(stats["entropy_coeff"], 0.1 if trainer.iteration == 1 else 0.05)

        trainer.workers.foreach_policy(
            self._check_lr_torch
            if trainer.config["framework"] == "torch"
            else self._check_lr_tf
        )


class TestPPO(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ppo_compilation_and_schedule_mixins(self):
        """Test whether a PPOTrainer can be built with all frameworks."""

        # Build a PPOConfig object.
        config = (
            ppo.PPOConfig()
            .training(
                num_sgd_iter=2,
                # Setup lr schedule for testing.
                lr_schedule=[[0, 5e-5], [128, 0.0]],
                # Set entropy_coeff to a faulty value to prove that it'll get
                # overridden by the schedule below (which is expected).
                entropy_coeff=100.0,
                entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
                train_batch_size=128,
                model=dict(
                    # Settings in case we use an LSTM.
                    lstm_cell_size=10,
                    max_seq_len=20,
                ),
            )
            .rollouts(
                num_rollout_workers=1,
                # Test with compression.
                compress_observations=True,
            )
            # For checking lr-schedule correctness.
            .callbacks(MyCallbacks)
        )
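        # With train_batch_size=128, each training iteration consumes roughly
        # 128 env timesteps. The lr schedule above thus reaches 0.0 after the
        # first iteration, and the entropy coeff reaches ~0.05 (halfway to 0.0
        # at ts=256), which is what MyCallbacks asserts.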

        num_iterations = 2

        for fw in framework_iterator(config, with_eager_tracing=True):
            for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
                print("Env={}".format(env))
                for lstm in [True, False]:
                    print("LSTM={}".format(lstm))
                    config.training(
                        model=dict(
                            use_lstm=lstm,
                            lstm_use_prev_action=lstm,
                            lstm_use_prev_reward=lstm,
                        )
                    )

                    trainer = config.build(env=env)
                    policy = trainer.get_policy()
                    entropy_coeff = policy.entropy_coeff
                    lr = policy.cur_lr
                    if fw == "tf":
                        entropy_coeff, lr = policy.get_session().run(
                            [entropy_coeff, lr]
                        )
                    check(entropy_coeff, 0.1)
                    check(lr, config.lr)

                    for i in range(num_iterations):
                        results = trainer.train()
                        check_train_results(results)
                        print(results)

                    check_compute_single_action(
                        trainer, include_prev_action_reward=True, include_state=lstm
                    )
                    trainer.stop()

    def test_ppo_exploration_setup(self):
        """Tests whether PPO runs with different exploration setups."""
        config = (
            ppo.PPOConfig()
            .environment(
                env_config={"is_slippery": False, "map_name": "4x4"},
            )
            .rollouts(
                # Run locally.
                num_rollout_workers=0,
            )
        )
        obs = np.array(0)

        # Test against all frameworks.
        for fw in framework_iterator(config):
            # Default Agent should be setup with StochasticSampling.
            trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v1")
            # explore=False, always expect the same (deterministic) action.
            a_ = trainer.compute_single_action(
                obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0)
            )
            # Test whether this is really the argmax action over the logits.
            if fw != "tf":
                last_out = trainer.get_policy().model.last_output()
                if fw == "torch":
                    check(a_, np.argmax(last_out.detach().cpu().numpy(), 1)[0])
                else:
                    check(a_, np.argmax(last_out.numpy(), 1)[0])
            for _ in range(50):
                a = trainer.compute_single_action(
                    obs,
                    explore=False,
                    prev_action=np.array(2),
                    prev_reward=np.array(1.0),
                )
                check(a, a_)

            # With explore=True (default), expect stochastic actions.
            actions = []
            for _ in range(300):
                actions.append(
                    trainer.compute_single_action(
                        obs, prev_action=np.array(2), prev_reward=np.array(1.0)
                    )
                )
            check(np.mean(actions), 1.5, atol=0.2)
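            # FrozenLake-v1 has 4 discrete actions (0-3), so roughly uniform
            # stochastic sampling should produce a mean action close to
            # (0 + 1 + 2 + 3) / 4 = 1.5, as checked above.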
            trainer.stop()

    def test_ppo_free_log_std(self):
        """Tests that the free_log_std model option works."""
        config = (
            ppo.PPOConfig()
            .rollouts(
                num_rollout_workers=0,
            )
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10],
                    fcnet_activation="linear",
                    free_log_std=True,
                    vf_share_layers=True,
                ),
            )
        )

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check the free log std var is created.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters() if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables() if "log_std" in str(v)
                ]
            assert len(matching) == 1, matching
            log_std_var = matching[0]
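            # With free_log_std, the model emits only the distribution means;
            # the log-stds live in a single, state-independent trainable
            # variable, which is why exactly one parameter matches above.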

            def get_value():
                if fw == "tf":
                    return policy.get_session().run(log_std_var)[0]
                elif fw == "torch":
                    return log_std_var.detach().cpu().numpy()[0]
                else:
                    return log_std_var.numpy()[0]

            # Check the variable is initially zero.
            init_std = get_value()
            assert init_std == 0.0, init_std
            batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
            if fw == "torch":
                batch = policy._lazy_tensor_dict(batch)
            policy.learn_on_batch(batch)

            # Check the variable is updated.
            post_std = get_value()
            assert post_std != 0.0, post_std
            trainer.stop()

    def test_ppo_legacy_config(self):
        """Tests whether the old PPO config dict is still functional."""
        ppo_config = ppo.DEFAULT_CONFIG
        # Expect warning.
        print(f"Accessing learning-rate from legacy config dict: {ppo_config['lr']}")
        # Build Trainer.
        ppo_trainer = ppo.PPOTrainer(config=ppo_config, env="CartPole-v1")
        print(ppo_trainer.train())

    def test_ppo_loss_function(self):
        """Tests the PPO loss function math."""
        config = (
            ppo.PPOConfig()
            .rollouts(
                num_rollout_workers=0,
            )
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10],
                    fcnet_activation="linear",
                    vf_share_layers=True,
                ),
            )
        )

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check that there is no free log std var by default.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters() if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables() if "log_std" in str(v)
                ]
            assert len(matching) == 0, matching

            # Post-process (calculate simple (non-GAE) advantages) and attach
            # them to the train_batch dict.
            # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5]
            #   = [0.50005, -0.505, 0.5]
            train_batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
            if fw == "torch":
                train_batch = policy._lazy_tensor_dict(train_batch)

            # Check the computed value targets.
            check(train_batch[Postprocessing.VALUE_TARGETS], [0.50005, -0.505, 0.5])
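            # Sanity check of the expected numbers: with gamma=0.99 and a
            # terminal last step (no bootstrapping), the discounted returns of
            # the rewards [1.0, -1.0, 0.5] are:
            #   G2 = 0.5
            #   G1 = -1.0 + 0.99 * 0.5      = -0.505
            #   G0 =  1.0 + 0.99 * (-0.505) =  0.50005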

            # Calculate actual PPO loss.
            if fw in ["tf2", "tfe"]:
                ppo_surrogate_loss_tf(policy, policy.model, Categorical, train_batch)
            elif fw == "torch":
                PPOTorchPolicy.loss(
                    policy, policy.model, policy.dist_class, train_batch
                )

            vars = (
                policy.model.variables()
                if fw != "torch"
                else list(policy.model.parameters())
            )
            if fw == "tf":
                vars = policy.get_session().run(vars)
            expected_shared_out = fc(
                train_batch[SampleBatch.CUR_OBS],
                vars[0 if fw != "torch" else 2],
                vars[1 if fw != "torch" else 3],
                framework=fw,
            )
            expected_logits = fc(
                expected_shared_out,
                vars[2 if fw != "torch" else 0],
                vars[3 if fw != "torch" else 1],
                framework=fw,
            )
            expected_value_outs = fc(
                expected_shared_out, vars[4], vars[5], framework=fw
            )
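            # The index swaps above reflect the frameworks' differing
            # parameter orderings: judging by the indices used, torch's
            # model.parameters() yields the logits layer's weight/bias at 0/1
            # and the shared hidden layer's at 2/3, while tf lists the hidden
            # layer first. The value branch sits at indices 4/5 in both cases.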

            kl, entropy, pg_loss, vf_loss, overall_loss = self._ppo_loss_helper(
                policy,
                policy.model,
                Categorical if fw != "torch" else TorchCategorical,
                train_batch,
                expected_logits,
                expected_value_outs,
                sess=sess,
            )
            if sess:
                policy_sess = policy.get_session()
                k, e, pl, v, tl = policy_sess.run(
                    [
                        policy._mean_kl_loss,
                        policy._mean_entropy,
                        policy._mean_policy_loss,
                        policy._mean_vf_loss,
                        policy._total_loss,
                    ],
                    feed_dict=policy._get_loss_inputs_dict(train_batch, shuffle=False),
                )
                check(k, kl)
                check(e, entropy)
                check(pl, np.mean(-pg_loss))
                check(v, np.mean(vf_loss), decimals=4)
                check(tl, overall_loss, decimals=4)
            elif fw == "torch":
                check(policy.model.tower_stats["mean_kl_loss"], kl)
                check(policy.model.tower_stats["mean_entropy"], entropy)
                check(policy.model.tower_stats["mean_policy_loss"], np.mean(-pg_loss))
                check(
                    policy.model.tower_stats["mean_vf_loss"],
                    np.mean(vf_loss),
                    decimals=4,
                )
                check(policy.model.tower_stats["total_loss"], overall_loss, decimals=4)
            else:
                check(policy._mean_kl_loss, kl)
                check(policy._mean_entropy, entropy)
                check(policy._mean_policy_loss, np.mean(-pg_loss))
                check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
                check(policy._total_loss, overall_loss, decimals=4)
            trainer.stop()

    def _ppo_loss_helper(
        self, policy, model, dist_class, train_batch, logits, vf_outs, sess=None
    ):
        """Calculates the expected PPO loss components in numpy, given a
        Policy, a Model, an action distribution class, a sample batch,
        logits, and value function outputs.
        """
        # Calculate expected PPO loss results.
        dist = dist_class(logits, policy.model)
        dist_prev = dist_class(
            train_batch[SampleBatch.ACTION_DIST_INPUTS], policy.model
        )
        expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
        if isinstance(model, TorchModelV2):
            train_batch.set_get_interceptor(None)
            expected_rho = np.exp(
                expected_logp.detach().cpu().numpy()
                - train_batch[SampleBatch.ACTION_LOGP]
            )
            # KL(prev vs current action dist) loss component.
            kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
            # Entropy loss component.
            entropy = np.mean(dist.entropy().detach().cpu().numpy())
        else:
            if sess:
                expected_logp = sess.run(expected_logp)
            expected_rho = np.exp(expected_logp - train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist) loss component.
            kl = dist_prev.kl(dist)
            if sess:
                kl = sess.run(kl)
            kl = np.mean(kl)
            # Entropy loss component.
            entropy = dist.entropy()
            if sess:
                entropy = sess.run(entropy)
            entropy = np.mean(entropy)

        # Policy loss component.
        pg_loss = np.minimum(
            train_batch[Postprocessing.ADVANTAGES] * expected_rho,
            train_batch[Postprocessing.ADVANTAGES]
            * np.clip(
                expected_rho,
                1 - policy.config["clip_param"],
                1 + policy.config["clip_param"],
            ),
        )
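        # The above is PPO's clipped surrogate objective,
        #   L^CLIP = min(rho * A, clip(rho, 1 - eps, 1 + eps) * A),
        # where rho = pi_new(a|s) / pi_old(a|s) (computed from the log-prob
        # difference above) and eps = clip_param.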

        # Value function loss component.
        vf_loss1 = np.power(vf_outs - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = train_batch[SampleBatch.VF_PREDS] + np.clip(
            vf_outs - train_batch[SampleBatch.VF_PREDS],
            -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"],
        )
        vf_loss2 = np.power(vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = np.maximum(vf_loss1, vf_loss2)
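        # The clipped VF loss takes the elementwise max of the unclipped
        # squared error and the error of a prediction allowed to move at most
        # vf_clip_param away from the previous value estimate, which
        # discourages large value-function updates.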

        # Overall loss.
        if sess:
            policy_sess = policy.get_session()
            kl_coeff, entropy_coeff = policy_sess.run(
                [policy.kl_coeff, policy.entropy_coeff]
            )
        else:
            kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff
        overall_loss = np.mean(
            -pg_loss
            + kl_coeff * kl
            + policy.config["vf_loss_coeff"] * vf_loss
            - entropy_coeff * entropy
        )
        return kl, entropy, pg_loss, vf_loss, overall_loss


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))