from gym import Env
from gym.spaces import Box, Dict, Discrete, Tuple
import numpy as np
import re
import unittest

import ray
import ray.rllib.agents.sac as sac
from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as loss_torch
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.examples.models.batch_norm_model import (
    KerasBatchNormModel,
    TorchBatchNormModel,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Dirichlet
from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, huber_loss, relu
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray import tune

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class SimpleEnv(Env):
    def __init__(self, config):
        self._skip_env_checking = True
        if config.get("simplex_actions", False):
            self.action_space = Simplex((2,))
        else:
            self.action_space = Box(0.0, 1.0, (1,))
        self.observation_space = Box(0.0, 1.0, (1,))
        self.max_steps = config.get("max_steps", 100)
        self.state = None
        self.steps = None

    def reset(self):
        self.state = self.observation_space.sample()
        self.steps = 0
        return self.state

    def step(self, action):
        self.steps += 1
        # Reward is 1.0 - |max(action) - state|.
        [r] = 1.0 - np.abs(np.max(action) - self.state)
        d = self.steps >= self.max_steps
        self.state = self.observation_space.sample()
        return self.state, r, d, {}

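# SimpleEnv rewards actions that match the current (uniformly sampled) state,
# which presumably makes it a lightweight, dependency-free stand-in environment
# for the loss-function checks in TestSAC below.

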
class TestSAC(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        np.random.seed(42)
        torch.manual_seed(42)
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_sac_compilation(self):
        """Tests whether an SACTrainer can be built with all frameworks."""
        config = sac.DEFAULT_CONFIG.copy()
        config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
        config["num_workers"] = 0  # Run locally.
        config["n_step"] = 3
        config["twin_q"] = True
        config["learning_starts"] = 0
        config["prioritized_replay"] = True
        config["rollout_fragment_length"] = 10
        config["train_batch_size"] = 10
        # If we used the default buffer size (1e6), the buffer would take up
        # 169.445 GB of memory, which is beyond travis-ci's current (Mar 19, 2021)
        # available system memory (8.34816 GB).
        config["buffer_size"] = 40000
        # Test with a saved replay buffer.
        config["store_buffer_in_checkpoints"] = True
        num_iterations = 1

        ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
        ModelCatalog.register_custom_model("batch_norm_torch", TorchBatchNormModel)

        image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
        simple_space = Box(-1.0, 1.0, shape=(3,))

        tune.register_env(
            "random_dict_env",
            lambda _: RandomEnv(
                {
                    "observation_space": Dict(
                        {
                            "a": simple_space,
                            "b": Discrete(2),
                            "c": image_space,
                        }
                    ),
                    "action_space": Box(-1.0, 1.0, shape=(1,)),
                }
            ),
        )
        tune.register_env(
            "random_tuple_env",
            lambda _: RandomEnv(
                {
                    "observation_space": Tuple(
                        [simple_space, Discrete(2), image_space]
                    ),
                    "action_space": Box(-1.0, 1.0, shape=(1,)),
                }
            ),
        )

        for fw in framework_iterator(config, with_eager_tracing=True):
            # Test for different env types (discrete w/ and w/o image, + cont).
            for env in [
                "random_dict_env",
                "random_tuple_env",
                # "MsPacmanNoFrameskip-v4",
                "CartPole-v0",
            ]:
                print("Env={}".format(env))
                # For CartPole, test making the Q-model a custom one;
                # otherwise, use the default model.
                config["Q_model"]["custom_model"] = (
                    "batch_norm{}".format("_torch" if fw == "torch" else "")
                    if env == "CartPole-v0"
                    else None
                )
                trainer = sac.SACTrainer(config=config, env=env)
                for i in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)
                check_compute_single_action(trainer)

                # Test whether the replay buffer is saved along with a
                # checkpoint (no point in doing this for all frameworks,
                # since it is framework-agnostic).
                if fw == "tf" and env == "CartPole-v0":
                    checkpoint = trainer.save()
                    new_trainer = sac.SACTrainer(config, env=env)
                    new_trainer.restore(checkpoint)
                    # Get some data from the buffer and compare.
                    data = trainer.local_replay_buffer.replay_buffers[
                        "default_policy"
                    ]._storage[: 42 + 42]
                    new_data = new_trainer.local_replay_buffer.replay_buffers[
                        "default_policy"
                    ]._storage[: 42 + 42]
                    check(data, new_data)
                    new_trainer.stop()

                trainer.stop()

    def test_sac_loss_function(self):
        """Tests SAC loss function results across all frameworks."""
        config = sac.DEFAULT_CONFIG.copy()
        # Run locally.
        config["seed"] = 42
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = False
        config["gamma"] = 0.99
        # Switch on deterministic loss so we can compare the loss values.
        config["_deterministic_loss"] = True
        # Use very simple nets.
        config["Q_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_hiddens"] = [10]
        # Make sure timing differences do not affect trainer.train().
        config["min_time_s_per_reporting"] = 0
        # Test SAC with Simplex action space.
        config["env_config"] = {"simplex_actions": True}

        map_ = {
            # Action net.
            "default_policy/fc_1/kernel": "action_model._hidden_layers.0."
            "_model.0.weight",
            "default_policy/fc_1/bias": "action_model._hidden_layers.0."
            "_model.0.bias",
            "default_policy/fc_out/kernel": "action_model._logits._model.0.weight",
            "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out/bias": "action_model."
            "_value_branch._model.0.bias",
            # Q-net.
            "default_policy/fc_1_1/kernel": "q_net._hidden_layers.0._model.0.weight",
            "default_policy/fc_1_1/bias": "q_net._hidden_layers.0._model.0.bias",
            "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_1/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_1/bias": "q_net._value_branch._model.0.bias",
            "default_policy/log_alpha": "log_alpha",
            # Target action-net.
            "default_policy/fc_1_2/kernel": "action_model."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_2/bias": "action_model."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_2/kernel": "action_model._logits._model.0.weight",
            "default_policy/fc_out_2/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out_2/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out_2/bias": "action_model."
            "_value_branch._model.0.bias",
            # Target Q-net.
            "default_policy/fc_1_3/kernel": "q_net._hidden_layers.0._model.0.weight",
            "default_policy/fc_1_3/bias": "q_net._hidden_layers.0._model.0.bias",
            "default_policy/fc_out_3/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_3/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_3/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_3/bias": "q_net._value_branch._model.0.bias",
            "default_policy/log_alpha_1": "log_alpha",
        }

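        # Note on the mapping above: tf dense kernels are stored as
        # (in_features, out_features) while torch nn.Linear weights are
        # (out_features, in_features); hence the np.transpose() calls in
        # _translate_weights_to_torch() and in the gradient/weight comparisons
        # further below whenever tf and torch tensors are checked against
        # each other.
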
        env = SimpleEnv
        batch_size = 100
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 2))

        # Batch of size=n.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_e, expect_t = None, None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []
        for fw, sess in framework_iterator(
            config, frameworks=("tf", "torch"), session=True
        ):
            # Generate Trainer and get its default Policy object.
            trainer = sac.SACTrainer(config=config, env=env)
            policy = trainer.get_policy()
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                # Start with the tf vars-dict.
                assert fw in ["tf2", "tf", "tfe"]
                weights_dict = policy.get_weights()
                if fw == "tfe":
                    log_alpha = weights_dict[10]
                    weights_dict = self._translate_tfe_weights(weights_dict, map_)
            else:
                assert fw == "torch"  # Then transfer that to the torch model.
                model_dict = self._translate_weights_to_torch(weights_dict, map_)
                # Have to add this here (not a parameter in tf, but must be
                # one in torch, so it gets properly copied to the GPU(s)).
                model_dict["target_entropy"] = policy.model.target_entropy
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "tf":
                log_alpha = weights_dict["default_policy/log_alpha"]
            elif fw == "torch":
                # Actually convert to torch tensors (by accessing everything).
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}
                log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

            # Only run the expectation once; it should be the same anyway
            # for all frameworks.
            if expect_c is None:
                expect_c, expect_a, expect_e, expect_t = self._sac_loss_helper(
                    input_,
                    weights_dict,
                    sorted(weights_dict.keys()),
                    log_alpha,
                    fw,
                    gamma=config["gamma"],
                    sess=sess,
                )

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, e=entropy, t=td-error.
            if fw == "tf":
                c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = p_sess.run(
                    [
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.alpha_loss,
                        policy.td_error,
                        policy.optimizer().compute_gradients(
                            policy.critic_loss[0],
                            [
                                v
                                for v in policy.model.q_variables()
                                if "value_" not in v.name
                            ],
                        ),
                        policy.optimizer().compute_gradients(
                            policy.actor_loss,
                            [
                                v
                                for v in policy.model.policy_variables()
                                if "value_" not in v.name
                            ],
                        ),
                        policy.optimizer().compute_gradients(
                            policy.alpha_loss, policy.model.log_alpha
                        ),
                    ],
                    feed_dict=policy._get_loss_inputs_dict(input_, shuffle=False),
                )
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]
                tf_e_grads = [g for g, v in tf_e_grads]

            elif fw == "tfe":
                with tf.GradientTape() as tape:
                    tf_loss(policy, policy.model, None, input_)
                c, a, e, t = (
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.alpha_loss,
                    policy.td_error,
                )
                vars = tape.watched_variables()
                tf_c_grads = tape.gradient(c[0], vars[6:10])
                tf_a_grads = tape.gradient(a, vars[2:6])
                tf_e_grads = tape.gradient(e, vars[10])

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, e, t = (
                    policy.get_tower_stats("critic_loss")[0],
                    policy.get_tower_stats("actor_loss")[0],
                    policy.get_tower_stats("alpha_loss")[0],
                    policy.get_tower_stats("td_error")[0],
                )

                # Test actor gradients.
                policy.actor_optim.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                a.backward()
                # `actor_loss` depends on Q-net vars (but these grads must
                # be ignored and overridden in critic_loss.backward!).
                assert not all(
                    torch.mean(v.grad) == 0 for v in policy.model.policy_variables()
                )
                assert not all(
                    torch.min(v.grad) == 0 for v in policy.model.policy_variables()
                )
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_a_grads = [
                    v.grad
                    for v in policy.model.policy_variables()
                    if v.grad is not None
                ]
                check(tf_a_grads[2], np.transpose(torch_a_grads[0].detach().cpu()))

                # Test critic gradients.
                policy.critic_optims[0].zero_grad()
                assert all(
                    torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables()
                    if v.grad is not None
                )
                assert all(
                    torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables()
                    if v.grad is not None
                )
                assert policy.model.log_alpha.grad is None
                c[0].backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables()
                    if v.grad is not None
                )
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.q_variables()
                    if v.grad is not None
                )
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                check(tf_c_grads[0], np.transpose(torch_c_grads[2].detach().cpu()))
                # Compare (unchanged(!) actor grads) with tf ones.
                torch_a_grads = [v.grad for v in policy.model.policy_variables()]
                check(tf_a_grads[2], np.transpose(torch_a_grads[0].detach().cpu()))

                # Test alpha gradient.
                policy.alpha_optim.zero_grad()
                assert policy.model.log_alpha.grad is None
                e.backward()
                assert policy.model.log_alpha.grad is not None
                check(policy.model.log_alpha.grad, tf_e_grads)

            check(c, expect_c)
            check(a, expect_a)
            check(e, expect_e)
            check(t, expect_t)

            # Store this framework's losses in prev_fw_loss to compare with
            # next framework's outputs.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(e, prev_fw_loss[2])
                check(t, prev_fw_loss[3])

            prev_fw_loss = (c, a, e, t)

            # Update weights from our batch (n times).
            for update_iteration in range(5):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(
                            updated_weights["default_policy/fc_1/kernel"],
                            tf_updated_weights[-1]["default_policy/fc_1/kernel"],
                            false=True,
                        )
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    # Compare updated model.
                    for tf_key in sorted(tf_weights.keys()):
                        if re.search("_[23]|alpha", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003,
                            )
                        else:
                            check(tf_var, torch_var, atol=0.003)
                    # And alpha.
                    check(
                        policy.model.log_alpha, tf_weights["default_policy/log_alpha"]
                    )
                    # Compare target nets.
                    for tf_key in sorted(tf_weights.keys()):
                        if not re.search("_[23]", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.target_model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003,
                            )
                        else:
                            check(tf_var, torch_var, atol=0.003)
            trainer.stop()

    def _get_batch_helper(self, obs_size, actions, batch_size):
        return SampleBatch(
            {
                SampleBatch.CUR_OBS: np.random.random(size=obs_size),
                SampleBatch.ACTIONS: actions,
                SampleBatch.REWARDS: np.random.random(size=(batch_size,)),
                SampleBatch.DONES: np.random.choice([True, False], size=(batch_size,)),
                SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
                "weights": np.random.random(size=(batch_size,)),
            }
        )

    def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma, sess):
        """Emulates SAC loss functions for tf and torch."""
        # ks:
        # 0=log_alpha
        # 1=target log-alpha (not used)

        # 2=action hidden bias
        # 3=action hidden kernel
        # 4=action out bias
        # 5=action out kernel

        # 6=Q hidden bias
        # 7=Q hidden kernel
        # 8=Q out bias
        # 9=Q out kernel

        # 14=target Q hidden bias
        # 15=target Q hidden kernel
        # 16=target Q out bias
        # 17=target Q out kernel
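        # The computations below mirror the SAC losses (single, non-twin Q here):
        #   td_target   = r + gamma * (1 - done)
        #                   * (Q_target(s', a'_pi) - alpha * logp(a'|s'))
        #   critic_loss = mean(importance_weights * huber(td_target - Q(s, a)))
        #   actor_loss  = mean(alpha * logp(a|s) - Q(s, a_pi))
        #   alpha_loss  = -mean(log_alpha * (logp(a|s) + target_entropy))
        # using deterministic (Dirichlet) action samples, matching
        # config["_deterministic_loss"] = True set in the test above.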
        alpha = np.exp(log_alpha)
        # cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian
        cls = TorchDirichlet if fw == "torch" else Dirichlet
        model_out_t = train_batch[SampleBatch.CUR_OBS]
        model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
        target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]

        # get_policy_output
        action_dist_t = cls(
            fc(
                relu(fc(model_out_t, weights[ks[1]], weights[ks[0]], framework=fw)),
                weights[ks[9]],
                weights[ks[8]],
            ),
            None,
        )
        policy_t = action_dist_t.deterministic_sample()
        log_pis_t = action_dist_t.logp(policy_t)
        if sess:
            log_pis_t = sess.run(log_pis_t)
            policy_t = sess.run(policy_t)
        log_pis_t = np.expand_dims(log_pis_t, -1)

        # Get policy output for t+1.
        action_dist_tp1 = cls(
            fc(
                relu(fc(model_out_tp1, weights[ks[1]], weights[ks[0]], framework=fw)),
                weights[ks[9]],
                weights[ks[8]],
            ),
            None,
        )
        policy_tp1 = action_dist_tp1.deterministic_sample()
        log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
        if sess:
            log_pis_tp1 = sess.run(log_pis_tp1)
            policy_tp1 = sess.run(policy_tp1)
        log_pis_tp1 = np.expand_dims(log_pis_tp1, -1)

        # Q-values for the actually selected actions.
        # get_q_values
        q_t = fc(
            relu(
                fc(
                    np.concatenate([model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
                    weights[ks[3]],
                    weights[ks[2]],
                    framework=fw,
                )
            ),
            weights[ks[11]],
            weights[ks[10]],
            framework=fw,
        )

        # Q-values for the current policy in the current state.
        # get_q_values
        q_t_det_policy = fc(
            relu(
                fc(
                    np.concatenate([model_out_t, policy_t], -1),
                    weights[ks[3]],
                    weights[ks[2]],
                    framework=fw,
                )
            ),
            weights[ks[11]],
            weights[ks[10]],
            framework=fw,
        )

        # Target Q-network evaluation.
        # target_model.get_q_values
        if fw == "tf":
            q_tp1 = fc(
                relu(
                    fc(
                        np.concatenate([target_model_out_tp1, policy_tp1], -1),
                        weights[ks[7]],
                        weights[ks[6]],
                        framework=fw,
                    )
                ),
                weights[ks[15]],
                weights[ks[14]],
                framework=fw,
            )
        else:
            assert fw == "tfe"
            q_tp1 = fc(
                relu(
                    fc(
                        np.concatenate([target_model_out_tp1, policy_tp1], -1),
                        weights[ks[7]],
                        weights[ks[6]],
                        framework=fw,
                    )
                ),
                weights[ks[9]],
                weights[ks[8]],
                framework=fw,
            )

        q_t_selected = np.squeeze(q_t, axis=-1)
        q_tp1 -= alpha * log_pis_tp1
        q_tp1_best = np.squeeze(q_tp1, axis=-1)
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        if fw == "torch":
            dones = dones.float().numpy()
            rewards = rewards.numpy()
        q_tp1_best_masked = (1.0 - dones) * q_tp1_best
        q_t_selected_target = rewards + gamma * q_tp1_best_masked
        base_td_error = np.abs(q_t_selected - q_t_selected_target)
        td_error = base_td_error
        critic_loss = [
            np.mean(
                train_batch["weights"] * huber_loss(q_t_selected_target - q_t_selected)
            )
        ]
        target_entropy = -np.prod((1,))
        alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
        actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)

        return critic_loss, actor_loss, alpha_loss, td_error

    def _translate_weights_to_torch(self, weights_dict, map_):
        model_dict = {
            map_[k]: convert_to_torch_tensor(
                np.transpose(v)
                if re.search("kernel", k)
                else np.array([v])
                if re.search("log_alpha", k)
                else v
            )
            for i, (k, v) in enumerate(weights_dict.items())
            if i < 13
        }

        return model_dict

    def _translate_tfe_weights(self, weights_dict, map_):
        model_dict = {
            "default_policy/log_alpha": None,
            "default_policy/log_alpha_target": None,
            "default_policy/sequential/action_1/kernel": weights_dict[2],
            "default_policy/sequential/action_1/bias": weights_dict[3],
            "default_policy/sequential/action_out/kernel": weights_dict[4],
            "default_policy/sequential/action_out/bias": weights_dict[5],
            "default_policy/sequential_1/q_hidden_0/kernel": weights_dict[6],
            "default_policy/sequential_1/q_hidden_0/bias": weights_dict[7],
            "default_policy/sequential_1/q_out/kernel": weights_dict[8],
            "default_policy/sequential_1/q_out/bias": weights_dict[9],
            "default_policy/value_out/kernel": weights_dict[0],
            "default_policy/value_out/bias": weights_dict[1],
        }
        return model_dict


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))