import numpy as np
import re
import unittest
from tempfile import TemporaryDirectory

import ray
import ray.rllib.agents.ddpg as ddpg
from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as loss_torch
from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


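# These tests exercise RLlib's DDPGTrainer across the tf and torch frameworks:
# trainer compilation/training, checkpoint save/restore, exploration behavior
# (random pre-run and OU noise), and a manual re-computation of the DDPG losses
# used to cross-check the two backends against each other.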
class TestDDPG(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_ddpg_compilation(self):
        """Test whether a DDPGTrainer can be built with both frameworks."""
        config = ddpg.DEFAULT_CONFIG.copy()
        config["seed"] = 42
        config["num_workers"] = 1
        config["num_envs_per_worker"] = 2
        config["learning_starts"] = 0
        config["exploration_config"]["random_timesteps"] = 100

        num_iterations = 1

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer)
            # Ensure apply_gradient_fn is being called and updating global_step.
            if config["framework"] == "tf":
                a = trainer.get_policy().global_step.eval(
                    trainer.get_policy().get_session()
                )
            else:
                a = trainer.get_policy().global_step
            check(a, 500)
            trainer.stop()

    def test_ddpg_checkpoint_save_and_restore(self):
        """Test whether a DDPGTrainer can save and load checkpoints."""
        config = ddpg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 1
        config["num_envs_per_worker"] = 2
        config["learning_starts"] = 0
        config["exploration_config"]["random_timesteps"] = 100

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
            trainer.train()
            with TemporaryDirectory() as temp_dir:
                checkpoint = trainer.save(temp_dir)
                trainer.restore(checkpoint)
            trainer.stop()

    def test_ddpg_exploration_and_with_random_prerun(self):
        """Tests DDPG's Exploration (w/ random actions for n timesteps)."""
        core_config = ddpg.DEFAULT_CONFIG.copy()
        core_config["num_workers"] = 0  # Run locally.
        obs = np.array([0.0, 0.1, -0.1])
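        # A fixed observation within Pendulum-v1's 3-dim observation space,
        # reused for every compute_single_action() call below.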

        # Test against all frameworks.
        for _ in framework_iterator(core_config):
            config = core_config.copy()
            config["seed"] = 42
            # Default OUNoise setup.
            trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
            # Setting explore=False should always return the same action.
            a_ = trainer.compute_single_action(obs, explore=False)
            self.assertEqual(trainer.get_policy().global_timestep, 1)
            for i in range(50):
                a = trainer.compute_single_action(obs, explore=False)
                self.assertEqual(trainer.get_policy().global_timestep, i + 2)
                check(a, a_)
            # explore=None (default: explore) should return different actions.
            actions = []
            for i in range(50):
                actions.append(trainer.compute_single_action(obs))
                self.assertEqual(trainer.get_policy().global_timestep, i + 52)
            check(np.std(actions), 0.0, false=True)
            trainer.stop()

            # Check randomness at beginning.
config["exploration_config"] = {
|
|
|
|
# Act randomly at beginning ...
|
|
|
|
"random_timesteps": 50,
|
|
|
|
# Then act very closely to deterministic actions thereafter.
|
|
|
|
"ou_base_scale": 0.001,
|
|
|
|
"initial_scale": 0.001,
|
|
|
|
"final_scale": 0.001,
|
|
|
|
}
            trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
            # ts=0 (get a deterministic action as per explore=False).
            deterministic_action = trainer.compute_single_action(obs, explore=False)
            self.assertEqual(trainer.get_policy().global_timestep, 1)
            # ts=1-49 (in random window).
            random_a = []
            for i in range(1, 50):
                random_a.append(trainer.compute_single_action(obs, explore=True))
                self.assertEqual(trainer.get_policy().global_timestep, i + 1)
                check(random_a[-1], deterministic_action, false=True)
            self.assertTrue(np.std(random_a) > 0.5)

            # ts > 50 (a=deterministic_action + scale * N[0,1])
            for i in range(50):
                a = trainer.compute_single_action(obs, explore=True)
                self.assertEqual(trainer.get_policy().global_timestep, i + 51)
                check(a, deterministic_action, rtol=0.1)

            # ts >> 50 (BUT: explore=False -> expect deterministic action).
            for i in range(50):
                a = trainer.compute_single_action(obs, explore=False)
                self.assertEqual(trainer.get_policy().global_timestep, i + 101)
                check(a, deterministic_action)
            trainer.stop()

    def test_ddpg_loss_function(self):
        """Tests DDPG loss function results across all frameworks."""
        config = ddpg.DEFAULT_CONFIG.copy()
        # Run locally.
        config["seed"] = 42
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = True
        config["use_huber"] = True
        config["huber_threshold"] = 1.0
        config["gamma"] = 0.99
        # Keep this small (larger values seem to introduce errors).
        config["l2_reg"] = 1e-10
        config["prioritized_replay"] = False
        # Use very simple nets.
        config["actor_hiddens"] = [10]
        config["critic_hiddens"] = [10]
        # Make sure timing differences do not affect trainer.train().
        config["min_time_s_per_reporting"] = 0
        config["timesteps_per_iteration"] = 100

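        # Map tf variable names to the corresponding torch state_dict keys so
        # fixed weights can be copied from the tf policy into the torch models
        # (tf "kernel"s become torch ".weight"s, transposed in
        # `_translate_weights_to_torch`; "bias"es map to ".bias").
        # Entries with "_1" / "sequential_2|3" belong to the target networks.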
        map_ = {
            # Normal net.
            "default_policy/actor_hidden_0/kernel": "policy_model.action_0."
            "_model.0.weight",
            "default_policy/actor_hidden_0/bias": "policy_model.action_0."
            "_model.0.bias",
            "default_policy/actor_out/kernel": "policy_model.action_out."
            "_model.0.weight",
            "default_policy/actor_out/bias": "policy_model.action_out._model.0.bias",
            "default_policy/sequential/q_hidden_0/kernel": "q_model.q_hidden_0"
            "._model.0.weight",
            "default_policy/sequential/q_hidden_0/bias": "q_model.q_hidden_0."
            "_model.0.bias",
            "default_policy/sequential/q_out/kernel": "q_model.q_out._model."
            "0.weight",
            "default_policy/sequential/q_out/bias": "q_model.q_out._model.0.bias",
            # -- twin.
            "default_policy/sequential_1/twin_q_hidden_0/kernel": "twin_"
            "q_model.twin_q_hidden_0._model.0.weight",
            "default_policy/sequential_1/twin_q_hidden_0/bias": "twin_"
            "q_model.twin_q_hidden_0._model.0.bias",
            "default_policy/sequential_1/twin_q_out/kernel": "twin_"
            "q_model.twin_q_out._model.0.weight",
            "default_policy/sequential_1/twin_q_out/bias": "twin_"
            "q_model.twin_q_out._model.0.bias",
            # Target net.
            "default_policy/actor_hidden_0_1/kernel": "policy_model.action_0."
            "_model.0.weight",
            "default_policy/actor_hidden_0_1/bias": "policy_model.action_0."
            "_model.0.bias",
            "default_policy/actor_out_1/kernel": "policy_model.action_out."
            "_model.0.weight",
            "default_policy/actor_out_1/bias": "policy_model.action_out._model"
            ".0.bias",
            "default_policy/sequential_2/q_hidden_0/kernel": "q_model."
            "q_hidden_0._model.0.weight",
            "default_policy/sequential_2/q_hidden_0/bias": "q_model."
            "q_hidden_0._model.0.bias",
            "default_policy/sequential_2/q_out/kernel": "q_model."
            "q_out._model.0.weight",
            "default_policy/sequential_2/q_out/bias": "q_model.q_out._model.0.bias",
            # -- twin.
            "default_policy/sequential_3/twin_q_hidden_0/kernel": "twin_"
            "q_model.twin_q_hidden_0._model.0.weight",
            "default_policy/sequential_3/twin_q_hidden_0/bias": "twin_"
            "q_model.twin_q_hidden_0._model.0.bias",
            "default_policy/sequential_3/twin_q_out/kernel": "twin_"
            "q_model.twin_q_out._model.0.weight",
            "default_policy/sequential_3/twin_q_out/bias": "twin_"
            "q_model.twin_q_out._model.0.bias",
        }

        env = SimpleEnv
        batch_size = 100
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 1))

        # Batch of size=n.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_t = None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []

        for fw, sess in framework_iterator(
            config, frameworks=("tf", "torch"), session=True
        ):
            # Generate Trainer and get its default Policy object.
            trainer = ddpg.DDPGTrainer(config=config, env=env)
            policy = trainer.get_policy()
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                assert fw == "tf"  # Start with the tf vars-dict.
                weights_dict = policy.get_weights()
            else:
                assert fw == "torch"  # Then transfer that to torch Model.
                model_dict = self._translate_weights_to_torch(weights_dict, map_)
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "torch":
                # Actually convert to torch tensors.
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}

            # Only compute the expected values once; they should be the same
            # for all frameworks anyway.
            if expect_c is None:
                expect_c, expect_a, expect_t = self._ddpg_loss_helper(
                    input_,
                    weights_dict,
                    sorted(weights_dict.keys()),
                    fw,
                    gamma=config["gamma"],
                    huber_threshold=config["huber_threshold"],
                    l2_reg=config["l2_reg"],
                    sess=sess,
                )

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, t=td-error.
            if fw == "tf":
                c, a, t, tf_c_grads, tf_a_grads = p_sess.run(
                    [
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.td_error,
                        policy._critic_optimizer.compute_gradients(
                            policy.critic_loss, policy.model.q_variables()
                        ),
                        policy._actor_optimizer.compute_gradients(
                            policy.actor_loss, policy.model.policy_variables()
                        ),
                    ],
                    feed_dict=policy._get_loss_inputs_dict(input_, shuffle=False),
                )
                # Check pure loss values.
                check(c, expect_c)
                check(a, expect_a)
                check(t, expect_t)

                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, t = (
                    policy.get_tower_stats("critic_loss")[0],
                    policy.get_tower_stats("actor_loss")[0],
                    policy.get_tower_stats("td_error")[0],
                )
                # Check pure loss values.
                check(c, expect_c)
                check(a, expect_a)
                check(t, expect_t)

                # Test actor gradients.
                policy._actor_optimizer.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None for v in policy.model.policy_variables())
                a.backward()
                # `actor_loss` depends on Q-net vars
                # (but not twin-Q-net vars!).
                assert not any(v.grad is None for v in policy.model.q_variables()[:4])
                assert all(v.grad is None for v in policy.model.q_variables()[4:])
                assert not all(
                    torch.mean(v.grad) == 0 for v in policy.model.policy_variables()
                )
                assert not all(
                    torch.min(v.grad) == 0 for v in policy.model.policy_variables()
                )
                # Compare with tf ones.
                torch_a_grads = [v.grad for v in policy.model.policy_variables()]
                for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g.cpu()))
                    else:
                        check(tf_g, torch_g)

                # Test critic gradients.
                policy._critic_optimizer.zero_grad()
                assert all(
                    v.grad is None or torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables()
                )
                assert all(
                    v.grad is None or torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables()
                )
                c.backward()
                assert not all(
                    torch.mean(v.grad) == 0 for v in policy.model.q_variables()
                )
                assert not all(
                    torch.min(v.grad) == 0 for v in policy.model.q_variables()
                )
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g.cpu()))
                    else:
                        check(tf_g, torch_g)
                # Compare (unchanged(!) actor grads) with tf ones.
                torch_a_grads = [v.grad for v in policy.model.policy_variables()]
                for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g.cpu()))
                    else:
                        check(tf_g, torch_g)

            # Store this framework's losses in prev_fw_loss to compare with
            # next framework's outputs.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(t, prev_fw_loss[2])

            prev_fw_loss = (c, a, t)

            # Update weights from our batch (n times).
            for update_iteration in range(6):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(
                            updated_weights["default_policy/actor_hidden_0/kernel"],
                            tf_updated_weights[-1][
                                "default_policy/actor_hidden_0/kernel"
                            ],
                            false=True,
                        )
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    # Compare updated model and target weights.
                    for tf_key in tf_weights.keys():
                        tf_var = tf_weights[tf_key]
                        # Target model.
                        if re.search(
                            "actor_out_1|actor_hidden_0_1|sequential_[23]", tf_key
                        ):
                            torch_var = policy.target_model.state_dict()[map_[tf_key]]
                        # Model.
                        else:
                            torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(tf_var, np.transpose(torch_var.cpu()), atol=0.1)
                        else:
                            check(tf_var, torch_var, atol=0.1)

            trainer.stop()

    def _get_batch_helper(self, obs_size, actions, batch_size):
        return SampleBatch(
            {
                SampleBatch.CUR_OBS: np.random.random(size=obs_size),
                SampleBatch.ACTIONS: actions,
                SampleBatch.REWARDS: np.random.random(size=(batch_size,)),
                SampleBatch.DONES: np.random.choice([True, False], size=(batch_size,)),
                SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
                "weights": np.ones(shape=(batch_size,)),
            }
        )

    def _ddpg_loss_helper(
        self, train_batch, weights, ks, fw, gamma, huber_threshold, l2_reg, sess
    ):
        """Emulates DDPG loss functions for tf and torch."""
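        # The helper reproduces the (twin-)DDPG objectives computed below:
        #   y = r + gamma * (1 - done)
        #       * min(Q'(s', pi'(s')), twinQ'(s', pi'(s')))
        #   critic_loss = mean(huber(Q(s, a) - y) + huber(twinQ(s, a) - y))
        #                 (+ l2 reg on critic kernels)
        #   actor_loss = -mean(Q(s, pi(s))) (+ l2 reg on actor kernels)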
        model_out_t = train_batch[SampleBatch.CUR_OBS]
        target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
        # get_policy_output
        policy_t = sigmoid(
            2.0
            * fc(
                relu(fc(model_out_t, weights[ks[1]], weights[ks[0]], framework=fw)),
                weights[ks[5]],
                weights[ks[4]],
                framework=fw,
            )
        )
        # Get policy output for t+1 (target model).
        policy_tp1 = sigmoid(
            2.0
            * fc(
                relu(
                    fc(
                        target_model_out_tp1,
                        weights[ks[3]],
                        weights[ks[2]],
                        framework=fw,
                    )
                ),
                weights[ks[7]],
                weights[ks[6]],
                framework=fw,
            )
        )
        # Assume no smooth target policy.
        policy_tp1_smoothed = policy_tp1

        # Q-values for the actually selected actions.
        # get_q_values
        q_t = fc(
            relu(
                fc(
                    np.concatenate([model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
                    weights[ks[9]],
                    weights[ks[8]],
                    framework=fw,
                )
            ),
            weights[ks[11]],
            weights[ks[10]],
            framework=fw,
        )
        twin_q_t = fc(
            relu(
                fc(
                    np.concatenate([model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
                    weights[ks[13]],
                    weights[ks[12]],
                    framework=fw,
                )
            ),
            weights[ks[15]],
            weights[ks[14]],
            framework=fw,
        )

        # Q-values for current policy in given current state.
        # get_q_values
        q_t_det_policy = fc(
            relu(
                fc(
                    np.concatenate([model_out_t, policy_t], -1),
                    weights[ks[9]],
                    weights[ks[8]],
                    framework=fw,
                )
            ),
            weights[ks[11]],
            weights[ks[10]],
            framework=fw,
        )

        # Target q network evaluation.
        # target_model.get_q_values
        q_tp1 = fc(
            relu(
                fc(
                    np.concatenate([target_model_out_tp1, policy_tp1_smoothed], -1),
                    weights[ks[17]],
                    weights[ks[16]],
                    framework=fw,
                )
            ),
            weights[ks[19]],
            weights[ks[18]],
            framework=fw,
        )
        twin_q_tp1 = fc(
            relu(
                fc(
                    np.concatenate([target_model_out_tp1, policy_tp1_smoothed], -1),
                    weights[ks[21]],
                    weights[ks[20]],
                    framework=fw,
                )
            ),
            weights[ks[23]],
            weights[ks[22]],
            framework=fw,
        )

        q_t_selected = np.squeeze(q_t, axis=-1)
        twin_q_t_selected = np.squeeze(twin_q_t, axis=-1)
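        # Clipped double-Q: use the smaller of the two target Q estimates to
        # curb overestimation.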
        q_tp1 = np.minimum(q_tp1, twin_q_tp1)
        q_tp1_best = np.squeeze(q_tp1, axis=-1)

        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        if fw == "torch":
            dones = dones.float().numpy()
            rewards = rewards.numpy()

        q_tp1_best_masked = (1.0 - dones) * q_tp1_best
        q_t_selected_target = rewards + gamma * q_tp1_best_masked

        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        errors = huber_loss(td_error, huber_threshold) + huber_loss(
            twin_td_error, huber_threshold
        )

        critic_loss = np.mean(errors)
        actor_loss = -np.mean(q_t_det_policy)
        # Add l2-regularization if required.
        for name, var in weights.items():
            if re.match("default_policy/actor_(hidden_0|out)/kernel", name):
                actor_loss += l2_reg * l2_loss(var)
            elif re.match("default_policy/sequential(_1)?/\\w+/kernel", name):
                critic_loss += l2_reg * l2_loss(var)

        return critic_loss, actor_loss, td_error

    def _translate_weights_to_torch(self, weights_dict, map_):
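        # tf stores dense-layer kernels as (in, out) while torch.nn.Linear
        # stores weights as (out, in), hence the transpose for "kernel" entries.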
        model_dict = {
            map_[k]: convert_to_torch_tensor(
                np.transpose(v) if re.search("kernel", k) else v
            )
            for k, v in weights_dict.items()
            if re.search("default_policy/(actor_(hidden_0|out)|sequential(_1)?)/", k)
        }
        model_dict[
            "policy_model.action_out_squashed.low_action"
        ] = convert_to_torch_tensor(np.array([0.0]))
        model_dict[
            "policy_model.action_out_squashed.action_range"
        ] = convert_to_torch_tensor(np.array([1.0]))
        return model_dict


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))