2017-08-27 18:56:52 -07:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2017-09-28 13:12:06 -07:00
|
|
|
import copy
import unittest

import numpy as np

import ray
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.utils.test_utils import check, framework_iterator
|
2017-09-28 13:12:06 -07:00
|
|
|
|
|
|
|
|
|
|
|
def get_mean_action(alg, obs, num_samples=2000):
    """Return the mean of repeated single-action computations for `obs`.

    Args:
        alg: Object exposing `compute_single_action(obs)`; each returned
            action must be convertible with `float()`.
        obs: The observation passed unchanged to every call.
        num_samples: How many actions to compute and average. Defaults to
            2000, the value that used to be hard-coded.

    Returns:
        The arithmetic mean (via `np.mean`) of the sampled actions.

    Raises:
        ValueError: If `num_samples` is smaller than 1 (the mean of zero
            samples would silently be NaN).
    """
    if num_samples < 1:
        raise ValueError(
            "num_samples must be >= 1, got {}".format(num_samples)
        )
    out = [float(alg.compute_single_action(obs)) for _ in range(num_samples)]
    return np.mean(out)
|
|
|
|
|
|
|
|
|
2017-10-13 16:18:16 -07:00
|
|
|
# Per-algorithm config overrides, keyed by the algorithm name that is passed
# to `ckpt_restore_test`. Every entry sets `explore=False` (presumably so the
# action comparison after restoring is not dominated by exploration noise).
CONFIGS = {
    "A3C": {
        "explore": False,
        "num_workers": 1,
    },
    "APEX_DDPG": {
        "explore": False,
        "observation_filter": "MeanStdFilter",
        "num_workers": 2,
        "min_time_s_per_iteration": 1,
        "optimizer": {
            "num_replay_buffer_shards": 1,
        },
        "num_steps_sampled_before_learning_starts": 0,
    },
    "ARS": {
        "explore": False,
        "num_rollouts": 10,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "DDPG": {
        "explore": False,
        "min_sample_timesteps_per_iteration": 100,
        "num_steps_sampled_before_learning_starts": 0,
    },
    "DQN": {
        "explore": False,
        "num_steps_sampled_before_learning_starts": 0,
    },
    "ES": {
        "explore": False,
        "episodes_per_batch": 10,
        "train_batch_size": 100,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "PPO": {
        "explore": False,
        "num_sgd_iter": 5,
        "train_batch_size": 1000,
        "num_workers": 2,
    },
    "SimpleQ": {
        "explore": False,
        "num_steps_sampled_before_learning_starts": 0,
    },
    "SAC": {
        "explore": False,
        "num_steps_sampled_before_learning_starts": 0,
    },
}
|
|
|
|
|
2017-11-19 00:36:43 -08:00
|
|
|
|
2021-08-31 12:21:49 +02:00
|
|
|
def ckpt_restore_test(alg_name, tfe=False, object_store=False, replay_buffer=False):
    """Train an algorithm, checkpoint it, restore a twin and compare both.

    For every framework (and, optionally, for both checkpoint transports) two
    instances of the algorithm are built. The first one is trained for one
    iteration and saved; the second one is restored from that checkpoint.
    Afterwards the optimizer state, optionally a slice of the replay buffer,
    and the mean action for a random observation are compared.

    Args:
        alg_name: Key into `CONFIGS`, also handed to `get_algorithm_class`.
        tfe: If True, run under "tf2" in addition to "torch" and "tf".
        object_store: If True, also exercise the
            `save_to_object()` / `restore_from_object()` path, not only
            `save()` / `restore()`.
        replay_buffer: If True, set `store_buffer_in_checkpoints` in the
            config and compare replay buffer contents after restoring.

    Raises:
        AssertionError: If the mean actions of both instances differ by more
            than 0.1. The `check` helper is presumably raising on state
            mismatches as well.
    """
    # Deep copy: the nested "optimizer" dict must not be shared with the
    # module-level CONFIGS while this function and `framework_iterator`
    # modify the config.
    config = copy.deepcopy(CONFIGS[alg_name])
    # If required, store replay buffer data in checkpoints as well.
    if replay_buffer:
        config["store_buffer_in_checkpoints"] = True

    # DDPG and SAC variants are run on Pendulum, everything else on CartPole.
    # `obs_dim` is the length of the random observation sampled further down.
    if "DDPG" in alg_name or "SAC" in alg_name:
        env_name, obs_dim = "Pendulum-v1", 3
    else:
        env_name, obs_dim = "CartPole-v0", 4

    frameworks = (["tf2"] if tfe else []) + ["torch", "tf"]
    for fw in framework_iterator(config, frameworks=frameworks):
        for use_object_store in [False, True] if object_store else [False]:
            print("use_object_store={}".format(use_object_store))
            cls = get_algorithm_class(alg_name)
            alg1 = cls(config=config, env=env_name)
            alg2 = None
            # Make sure both algorithms are stopped even when a check below
            # fails; otherwise their workers would stay alive.
            try:
                alg2 = cls(config=config, env=env_name)

                policy1 = alg1.get_policy()

                res = alg1.train()
                print("current status: " + str(res))

                # Check optimizer state as well.
                optim_state = policy1.get_state().get("_optimizer_variables")

                # Sync the models
                if use_object_store:
                    alg2.restore_from_object(alg1.save_to_object())
                else:
                    alg2.restore(alg1.save())

                # Compare optimizer state with re-loaded one.
                if optim_state:
                    s2 = alg2.get_policy().get_state().get("_optimizer_variables")
                    # Tf -> Compare states 1:1.
                    if fw in ["tf2", "tf", "tfe"]:
                        check(s2, optim_state)
                    # For torch, optimizers have state_dicts with keys=params,
                    # which are different for the two models (ignore these
                    # different keys, but compare all values nevertheless).
                    else:
                        for i, s2_ in enumerate(s2):
                            check(
                                list(s2_["state"].values()),
                                list(optim_state[i]["state"].values()),
                            )

                # Compare buffer content with restored one.
                if replay_buffer:
                    data = alg1.local_replay_buffer.replay_buffers[
                        "default_policy"
                    ]._storage[42 : 42 + 42]
                    new_data = alg2.local_replay_buffer.replay_buffers[
                        "default_policy"
                    ]._storage[42 : 42 + 42]
                    check(data, new_data)

                # Sample one random observation inside the space's bounds and
                # compare the mean actions both algorithms compute for it.
                obs = np.clip(
                    np.random.uniform(size=obs_dim),
                    policy1.observation_space.low,
                    policy1.observation_space.high,
                )
                a1 = get_mean_action(alg1, obs)
                a2 = get_mean_action(alg2, obs)
                print("Checking computed actions", alg1, obs, a1, a2)
                if abs(a1 - a2) > 0.1:
                    raise AssertionError(
                        "algo={} [a1={} a2={}]".format(alg_name, a1, a2)
                    )
            finally:
                # Stop both algos.
                alg1.stop()
                if alg2 is not None:
                    alg2.stop()
|
2017-11-19 00:36:43 -08:00
|
|
|
|
|
|
|
|
2021-04-30 12:33:12 +02:00
|
|
|
class TestCheckpointRestorePG(unittest.TestCase):
    """Checkpoint/restore tests for the A3C and PPO algorithms."""

    @classmethod
    def setUpClass(cls):
        """Start one ray instance (5 CPUs) shared by all tests of the class."""
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        """Shut ray down once all tests of the class have run."""
        ray.shutdown()

    def test_a3c_checkpoint_restore(self):
        """A3C: save()/restore() only."""
        ckpt_restore_test("A3C")

    def test_ppo_checkpoint_restore(self):
        """PPO: save()/restore() plus the object-store path."""
        ckpt_restore_test("PPO", object_store=True)
|
|
|
|
|
|
|
|
|
|
|
|
class TestCheckpointRestoreOffPolicy(unittest.TestCase):
    """Checkpoint/restore tests for APEX_DDPG, DDPG, DQN, SAC and SimpleQ."""

    @classmethod
    def setUpClass(cls):
        """Start one ray instance (5 CPUs) shared by all tests of the class."""
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        """Shut ray down once all tests of the class have run."""
        ray.shutdown()

    def test_apex_ddpg_checkpoint_restore(self):
        """APEX_DDPG: save()/restore() only, no replay buffer comparison."""
        ckpt_restore_test("APEX_DDPG")

    def test_ddpg_checkpoint_restore(self):
        """DDPG: also compares the replay buffer after restoring."""
        ckpt_restore_test("DDPG", replay_buffer=True)

    def test_dqn_checkpoint_restore(self):
        """DQN: object-store path and replay buffer comparison."""
        ckpt_restore_test("DQN", object_store=True, replay_buffer=True)

    def test_sac_checkpoint_restore(self):
        """SAC: also compares the replay buffer after restoring."""
        ckpt_restore_test("SAC", replay_buffer=True)

    def test_simpleq_checkpoint_restore(self):
        """SimpleQ: also compares the replay buffer after restoring."""
        ckpt_restore_test("SimpleQ", replay_buffer=True)
|
2020-03-12 04:39:47 +01:00
|
|
|
|
|
|
|
|
2021-04-30 12:33:12 +02:00
|
|
|
class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
    """Checkpoint/restore tests for the ARS and ES algorithms."""

    @classmethod
    def setUpClass(cls):
        """Start one ray instance (5 CPUs) shared by all tests of the class."""
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        """Shut ray down once all tests of the class have run."""
        ray.shutdown()

    def test_ars_checkpoint_restore(self):
        """ARS: save()/restore() only."""
        ckpt_restore_test("ARS")

    def test_es_checkpoint_restore(self):
        """ES: save()/restore() only."""
        ckpt_restore_test("ES")
|
|
|
|
|
|
|
|
|
2017-11-19 00:36:43 -08:00
|
|
|
if __name__ == "__main__":
    import pytest
    import sys

    # One can specify the specific TestCase class to run.
    # None for all unittest.TestCase classes in this file.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    # Hand pytest's exit code back to the shell.
    sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))
|