# ray/rllib/agents/sac/tests/test_sac.py

from gym import Env
from gym.spaces import Box, Dict, Discrete, Tuple
import numpy as np
import re
import unittest
import ray
import ray.rllib.agents.sac as sac
from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
loss_torch
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \
TorchBatchNormModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Dirichlet
from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, huber_loss, relu
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
check_train_results, framework_iterator
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
from ray import tune
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
class SimpleEnv(Env):
def __init__(self, config):
if config.get("simplex_actions", False):
self.action_space = Simplex((2, ))
else:
self.action_space = Box(0.0, 1.0, (1, ))
self.observation_space = Box(0.0, 1.0, (1, ))
self.max_steps = config.get("max_steps", 100)
self.state = None
self.steps = None
def reset(self):
self.state = self.observation_space.sample()
self.steps = 0
return self.state
def step(self, action):
self.steps += 1
# Reward is 1.0 - abs(max(action) - state).
[r] = 1.0 - np.abs(np.max(action) - self.state)
d = self.steps >= self.max_steps
self.state = self.observation_space.sample()
return self.state, r, d, {}
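# A quick, illustrative walk-through of the reward above (a sketch only, not
# part of the test suite; the numbers are made up):
#
#   env = SimpleEnv({"max_steps": 3})
#   obs = env.reset()                       # say obs == [0.3]
#   _, r, _, _ = env.step(np.array([0.8]))  # r == 1.0 - |0.8 - 0.3| == 0.5
#
# So the optimal policy simply echoes the observed state back as its action.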
class TestSAC(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
np.random.seed(42)
torch.manual_seed(42)
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_sac_compilation(self):
"""Tests whether an SACTrainer can be built with all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
config["num_workers"] = 0 # Run locally.
config["n_step"] = 3
config["twin_q"] = True
config["learning_starts"] = 0
config["prioritized_replay"] = True
config["rollout_fragment_length"] = 10
config["train_batch_size"] = 10
# If we use the default buffer size (1e6), the buffer will take up
# 169.445 GB of memory, which exceeds the system memory available
# on travis-ci as of Mar 19, 2021 (8.34816 GB).
config["buffer_size"] = 40000
# Test with saved replay buffer.
config["store_buffer_in_checkpoints"] = True
num_iterations = 1
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
ModelCatalog.register_custom_model("batch_norm_torch",
TorchBatchNormModel)
image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
simple_space = Box(-1.0, 1.0, shape=(3, ))
tune.register_env(
"random_dict_env", lambda _: RandomEnv({
"observation_space": Dict({
"a": simple_space,
"b": Discrete(2),
"c": image_space, }),
"action_space": Box(-1.0, 1.0, shape=(1, )), }))
tune.register_env(
"random_tuple_env", lambda _: RandomEnv({
"observation_space": Tuple([
simple_space, Discrete(2), image_space]),
"action_space": Box(-1.0, 1.0, shape=(1, )), }))
for fw in framework_iterator(config, with_eager_tracing=True):
# Test for different env types (discrete w/ and w/o image, + cont).
for env in [
"random_dict_env",
"random_tuple_env",
# "MsPacmanNoFrameskip-v4",
"CartPole-v0",
]:
print("Env={}".format(env))
# Test making the Q-model a custom one for CartPole; otherwise,
# use the default model.
config["Q_model"]["custom_model"] = "batch_norm{}".format(
"_torch"
if fw == "torch" else "") if env == "CartPole-v0" else None
trainer = sac.SACTrainer(config=config, env=env)
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
check_compute_single_action(trainer)
# Test whether the replay buffer is saved along with a checkpoint
# (no point in doing this for all frameworks since it is
# framework-agnostic).
if fw == "tf" and env == "CartPole-v0":
checkpoint = trainer.save()
new_trainer = sac.SACTrainer(config, env=env)
new_trainer.restore(checkpoint)
# Get some data from the buffer and compare.
data = trainer.local_replay_buffer.replay_buffers[
"default_policy"]._storage[:42 + 42]
new_data = new_trainer.local_replay_buffer.replay_buffers[
"default_policy"]._storage[:42 + 42]
check(data, new_data)
new_trainer.stop()
trainer.stop()
def test_sac_loss_function(self):
"""Tests SAC loss function results across all frameworks."""
config = sac.DEFAULT_CONFIG.copy()
# Run locally.
config["num_workers"] = 0
config["learning_starts"] = 0
config["twin_q"] = False
config["gamma"] = 0.99
# Switch on deterministic loss so we can compare the loss values.
config["_deterministic_loss"] = True
# Use very simple nets.
config["Q_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_hiddens"] = [10]
# Make sure timing differences do not affect trainer.train().
config["min_iter_time_s"] = 0
# Test SAC with Simplex action space.
config["env_config"] = {"simplex_actions": True}
map_ = {
# Action net.
"default_policy/fc_1/kernel": "action_model._hidden_layers.0."
"_model.0.weight",
"default_policy/fc_1/bias": "action_model._hidden_layers.0."
"_model.0.bias",
"default_policy/fc_out/kernel": "action_model."
"_logits._model.0.weight",
"default_policy/fc_out/bias": "action_model._logits._model.0.bias",
"default_policy/value_out/kernel": "action_model."
"_value_branch._model.0.weight",
"default_policy/value_out/bias": "action_model."
"_value_branch._model.0.bias",
# Q-net.
"default_policy/fc_1_1/kernel": "q_net."
"_hidden_layers.0._model.0.weight",
"default_policy/fc_1_1/bias": "q_net."
"_hidden_layers.0._model.0.bias",
"default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
"default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
"default_policy/value_out_1/kernel": "q_net."
"_value_branch._model.0.weight",
"default_policy/value_out_1/bias": "q_net."
"_value_branch._model.0.bias",
"default_policy/log_alpha": "log_alpha",
# Target action-net.
"default_policy/fc_1_2/kernel": "action_model."
"_hidden_layers.0._model.0.weight",
"default_policy/fc_1_2/bias": "action_model."
"_hidden_layers.0._model.0.bias",
"default_policy/fc_out_2/kernel": "action_model."
"_logits._model.0.weight",
"default_policy/fc_out_2/bias": "action_model."
"_logits._model.0.bias",
"default_policy/value_out_2/kernel": "action_model."
"_value_branch._model.0.weight",
"default_policy/value_out_2/bias": "action_model."
"_value_branch._model.0.bias",
# Target Q-net
"default_policy/fc_1_3/kernel": "q_net."
"_hidden_layers.0._model.0.weight",
"default_policy/fc_1_3/bias": "q_net."
"_hidden_layers.0._model.0.bias",
"default_policy/fc_out_3/kernel": "q_net."
"_logits._model.0.weight",
"default_policy/fc_out_3/bias": "q_net."
"_logits._model.0.bias",
"default_policy/value_out_3/kernel": "q_net."
"_value_branch._model.0.weight",
"default_policy/value_out_3/bias": "q_net."
"_value_branch._model.0.bias",
"default_policy/log_alpha_1": "log_alpha",
}
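# Note on the mapping above: Keras Dense layers store their kernels as
# [in_dim, out_dim], whereas torch.nn.Linear weights are [out_dim, in_dim];
# this is why _translate_weights_to_torch() below transposes every "kernel"
# entry. The target-net tf variables (suffixes _2/_3 and log_alpha_1) map to
# the same torch names because the same state-dict is loaded into both
# policy.model and policy.target_model.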
env = SimpleEnv
batch_size = 100
obs_size = (batch_size, 1)
actions = np.random.random(size=(batch_size, 2))
# Batch of size=n.
input_ = self._get_batch_helper(obs_size, actions, batch_size)
# Simply compare loss values AND grads of all frameworks with each
# other.
prev_fw_loss = weights_dict = None
expect_c, expect_a, expect_e, expect_t = None, None, None, None
# History of tf-updated NN-weights over n training steps.
tf_updated_weights = []
# History of input batches used.
tf_inputs = []
for fw, sess in framework_iterator(
config, frameworks=("tf", "torch"), session=True):
# Generate Trainer and get its default Policy object.
trainer = sac.SACTrainer(config=config, env=env)
policy = trainer.get_policy()
p_sess = None
if sess:
p_sess = policy.get_session()
# Set all weights (of all nets) to fixed values.
if weights_dict is None:
# Start with the tf vars-dict.
assert fw in ["tf2", "tf", "tfe"]
weights_dict = policy.get_weights()
if fw == "tfe":
log_alpha = weights_dict[10]
weights_dict = self._translate_tfe_weights(
weights_dict, map_)
else:
assert fw == "torch" # Then transfer that to torch Model.
model_dict = self._translate_weights_to_torch(
weights_dict, map_)
# Have to add this here (not a parameter in tf, but must be
# one in torch, so it gets properly copied to the GPU(s)).
model_dict["target_entropy"] = policy.model.target_entropy
policy.model.load_state_dict(model_dict)
policy.target_model.load_state_dict(model_dict)
if fw == "tf":
log_alpha = weights_dict["default_policy/log_alpha"]
elif fw == "torch":
# Actually convert to torch tensors (by accessing everything).
input_ = policy._lazy_tensor_dict(input_)
input_ = {k: input_[k] for k in input_.keys()}
log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]
# Only compute the expected values once; they should be the same
# for all frameworks anyway.
if expect_c is None:
expect_c, expect_a, expect_e, expect_t = \
self._sac_loss_helper(input_, weights_dict,
sorted(weights_dict.keys()),
log_alpha, fw,
gamma=config["gamma"], sess=sess)
# Get actual outs and compare to expectation AND previous
# framework. c=critic, a=actor, e=entropy, t=td-error.
if fw == "tf":
c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
p_sess.run([
policy.critic_loss,
policy.actor_loss,
policy.alpha_loss,
policy.td_error,
policy.optimizer().compute_gradients(
policy.critic_loss[0],
[v for v in policy.model.q_variables() if
"value_" not in v.name]),
policy.optimizer().compute_gradients(
policy.actor_loss,
[v for v in policy.model.policy_variables() if
"value_" not in v.name]),
policy.optimizer().compute_gradients(
policy.alpha_loss, policy.model.log_alpha)],
feed_dict=policy._get_loss_inputs_dict(
input_, shuffle=False))
tf_c_grads = [g for g, v in tf_c_grads]
tf_a_grads = [g for g, v in tf_a_grads]
tf_e_grads = [g for g, v in tf_e_grads]
elif fw == "tfe":
with tf.GradientTape() as tape:
tf_loss(policy, policy.model, None, input_)
c, a, e, t = policy.critic_loss, policy.actor_loss, \
policy.alpha_loss, policy.td_error
vars = tape.watched_variables()
tf_c_grads = tape.gradient(c[0], vars[6:10])
tf_a_grads = tape.gradient(a, vars[2:6])
tf_e_grads = tape.gradient(e, vars[10])
elif fw == "torch":
loss_torch(policy, policy.model, None, input_)
c, a, e, t = policy.get_tower_stats("critic_loss")[0], \
policy.get_tower_stats("actor_loss")[0], \
policy.get_tower_stats("alpha_loss")[0], \
policy.get_tower_stats("td_error")[0]
# Test actor gradients.
policy.actor_optim.zero_grad()
assert all(v.grad is None for v in policy.model.q_variables())
assert all(
v.grad is None for v in policy.model.policy_variables())
assert policy.model.log_alpha.grad is None
a.backward()
# `actor_loss` depends on Q-net vars (but these grads must
# be ignored and overridden in critic_loss.backward!).
assert not all(
torch.mean(v.grad) == 0
for v in policy.model.policy_variables())
assert not all(
torch.min(v.grad) == 0
for v in policy.model.policy_variables())
assert policy.model.log_alpha.grad is None
# Compare with tf ones.
torch_a_grads = [
v.grad for v in policy.model.policy_variables()
if v.grad is not None
]
check(tf_a_grads[2],
np.transpose(torch_a_grads[0].detach().cpu()))
# Test critic gradients.
policy.critic_optims[0].zero_grad()
assert all(
torch.mean(v.grad) == 0.0
for v in policy.model.q_variables() if v.grad is not None)
assert all(
torch.min(v.grad) == 0.0
for v in policy.model.q_variables() if v.grad is not None)
assert policy.model.log_alpha.grad is None
c[0].backward()
assert not all(
torch.mean(v.grad) == 0
for v in policy.model.q_variables() if v.grad is not None)
assert not all(
torch.min(v.grad) == 0 for v in policy.model.q_variables()
if v.grad is not None)
assert policy.model.log_alpha.grad is None
# Compare with tf ones.
torch_c_grads = [v.grad for v in policy.model.q_variables()]
check(tf_c_grads[0],
np.transpose(torch_c_grads[2].detach().cpu()))
# Compare (unchanged(!) actor grads) with tf ones.
torch_a_grads = [
v.grad for v in policy.model.policy_variables()
]
check(tf_a_grads[2],
np.transpose(torch_a_grads[0].detach().cpu()))
# Test alpha gradient.
policy.alpha_optim.zero_grad()
assert policy.model.log_alpha.grad is None
e.backward()
assert policy.model.log_alpha.grad is not None
check(policy.model.log_alpha.grad, tf_e_grads)
check(c, expect_c)
check(a, expect_a)
check(e, expect_e)
check(t, expect_t)
# Store this framework's losses in prev_fw_loss to compare with
# next framework's outputs.
if prev_fw_loss is not None:
check(c, prev_fw_loss[0])
check(a, prev_fw_loss[1])
check(e, prev_fw_loss[2])
check(t, prev_fw_loss[3])
prev_fw_loss = (c, a, e, t)
# Update weights from our batch (n times).
for update_iteration in range(5):
print("train iteration {}".format(update_iteration))
if fw == "tf":
in_ = self._get_batch_helper(obs_size, actions, batch_size)
tf_inputs.append(in_)
# Set a fake-batch to use
# (instead of sampling from replay buffer).
buf = LocalReplayBuffer.get_instance_for_testing()
buf._fake_batch = in_
trainer.train()
updated_weights = policy.get_weights()
# Net must have changed.
if tf_updated_weights:
check(
updated_weights["default_policy/fc_1/kernel"],
tf_updated_weights[-1][
"default_policy/fc_1/kernel"],
false=True)
tf_updated_weights.append(updated_weights)
# Compare with updated tf-weights. Must all be the same.
else:
tf_weights = tf_updated_weights[update_iteration]
in_ = tf_inputs[update_iteration]
# Set a fake-batch to use
# (instead of sampling from replay buffer).
buf = LocalReplayBuffer.get_instance_for_testing()
buf._fake_batch = in_
trainer.train()
# Compare updated model.
for tf_key in sorted(tf_weights.keys()):
if re.search("_[23]|alpha", tf_key):
continue
tf_var = tf_weights[tf_key]
torch_var = policy.model.state_dict()[map_[tf_key]]
if tf_var.shape != torch_var.shape:
check(
tf_var,
np.transpose(torch_var.detach().cpu()),
atol=0.003)
else:
check(tf_var, torch_var, atol=0.003)
# And alpha.
check(policy.model.log_alpha,
tf_weights["default_policy/log_alpha"])
# Compare target nets.
for tf_key in sorted(tf_weights.keys()):
if not re.search("_[23]", tf_key):
continue
tf_var = tf_weights[tf_key]
torch_var = policy.target_model.state_dict()[map_[
tf_key]]
if tf_var.shape != torch_var.shape:
check(
tf_var,
np.transpose(torch_var.detach().cpu()),
atol=0.003)
else:
check(tf_var, torch_var, atol=0.003)
trainer.stop()
def _get_batch_helper(self, obs_size, actions, batch_size):
return SampleBatch({
SampleBatch.CUR_OBS: np.random.random(size=obs_size),
SampleBatch.ACTIONS: actions,
SampleBatch.REWARDS: np.random.random(size=(batch_size, )),
SampleBatch.DONES: np.random.choice(
[True, False], size=(batch_size, )),
SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
"weights": np.random.random(size=(batch_size, )),
})
def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
sess):
"""Emulates SAC loss functions for tf and torch."""
# ks:
# 0=log_alpha
# 1=target log-alpha (not used)
# 2=action hidden bias
# 3=action hidden kernel
# 4=action out bias
# 5=action out kernel
# 6=Q hidden bias
# 7=Q hidden kernel
# 8=Q out bias
# 9=Q out kernel
# 14=target Q hidden bias
# 15=target Q hidden kernel
# 16=target Q out bias
# 17=target Q out kernel
alpha = np.exp(log_alpha)
# cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian
cls = TorchDirichlet if fw == "torch" else Dirichlet
model_out_t = train_batch[SampleBatch.CUR_OBS]
model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
# get_policy_output
action_dist_t = cls(
fc(
relu(
fc(model_out_t,
weights[ks[1]],
weights[ks[0]],
framework=fw)), weights[ks[9]], weights[ks[8]]), None)
policy_t = action_dist_t.deterministic_sample()
log_pis_t = action_dist_t.logp(policy_t)
if sess:
log_pis_t = sess.run(log_pis_t)
policy_t = sess.run(policy_t)
log_pis_t = np.expand_dims(log_pis_t, -1)
# Get policy output for t+1.
action_dist_tp1 = cls(
fc(
relu(
fc(model_out_tp1,
weights[ks[1]],
weights[ks[0]],
framework=fw)), weights[ks[9]], weights[ks[8]]), None)
policy_tp1 = action_dist_tp1.deterministic_sample()
log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
if sess:
log_pis_tp1 = sess.run(log_pis_tp1)
policy_tp1 = sess.run(policy_tp1)
log_pis_tp1 = np.expand_dims(log_pis_tp1, -1)
# Q-values for the actually selected actions.
# get_q_values
q_t = fc(
relu(
fc(np.concatenate(
[model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
weights[ks[3]],
weights[ks[2]],
framework=fw)),
weights[ks[11]],
weights[ks[10]],
framework=fw)
# Q-values for current policy in given current state.
# get_q_values
q_t_det_policy = fc(
relu(
fc(np.concatenate([model_out_t, policy_t], -1),
weights[ks[3]],
weights[ks[2]],
framework=fw)),
weights[ks[11]],
weights[ks[10]],
framework=fw)
# Target q network evaluation.
# target_model.get_q_values
if fw == "tf":
q_tp1 = fc(
relu(
fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
weights[ks[7]],
weights[ks[6]],
framework=fw)),
weights[ks[15]],
weights[ks[14]],
framework=fw)
else:
assert fw == "tfe"
q_tp1 = fc(
relu(
fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
weights[ks[7]],
weights[ks[6]],
framework=fw)),
weights[ks[9]],
weights[ks[8]],
framework=fw)
q_t_selected = np.squeeze(q_t, axis=-1)
q_tp1 -= alpha * log_pis_tp1
q_tp1_best = np.squeeze(q_tp1, axis=-1)
dones = train_batch[SampleBatch.DONES]
rewards = train_batch[SampleBatch.REWARDS]
if fw == "torch":
dones = dones.float().numpy()
rewards = rewards.numpy()
q_tp1_best_masked = (1.0 - dones) * q_tp1_best
q_t_selected_target = rewards + gamma * q_tp1_best_masked
base_td_error = np.abs(q_t_selected - q_t_selected_target)
td_error = base_td_error
critic_loss = [
np.mean(train_batch["weights"] *
huber_loss(q_t_selected_target - q_t_selected))
]
target_entropy = -np.prod((1, ))
alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)
return critic_loss, actor_loss, alpha_loss, td_error
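# The helpers below translate between weight containers:
# _translate_weights_to_torch() copies tf weights into the torch model
# (kernels transposed, see the note above the map_ dict; log_alpha wrapped
# into a 1-element array since torch stores it as a tensor parameter). The
# `i < 13` slice presumably keeps just the online action-net, online Q-net
# and log_alpha variables; the tf target-net copies are not needed because
# the same state-dict is loaded into policy.target_model.
# _translate_tfe_weights() re-keys the eager-mode weights list into
# graph-mode-style tf variable names.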
def _translate_weights_to_torch(self, weights_dict, map_):
model_dict = {
map_[k]: convert_to_torch_tensor(
np.transpose(v) if re.search("kernel", k) else np.array([v])
if re.search("log_alpha", k) else v)
for i, (k, v) in enumerate(weights_dict.items()) if i < 13
}
return model_dict
def _translate_tfe_weights(self, weights_dict, map_):
model_dict = {
"default_policy/log_alpha": None,
"default_policy/log_alpha_target": None,
"default_policy/sequential/action_1/kernel": weights_dict[2],
"default_policy/sequential/action_1/bias": weights_dict[3],
"default_policy/sequential/action_out/kernel": weights_dict[4],
"default_policy/sequential/action_out/bias": weights_dict[5],
"default_policy/sequential_1/q_hidden_0/kernel": weights_dict[6],
"default_policy/sequential_1/q_hidden_0/bias": weights_dict[7],
"default_policy/sequential_1/q_out/kernel": weights_dict[8],
"default_policy/sequential_1/q_out/bias": weights_dict[9],
"default_policy/value_out/kernel": weights_dict[0],
"default_policy/value_out/bias": weights_dict[1],
}
return model_dict
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))