2017-08-27 18:56:52 -07:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2017-09-28 13:12:06 -07:00
|
|
|
import numpy as np
|
2020-03-12 04:39:47 +01:00
|
|
|
import unittest
|
2017-08-27 18:56:52 -07:00
|
|
|
|
2020-03-12 04:39:47 +01:00
|
|
|
import ray
|
2021-02-08 12:05:16 +01:00
|
|
|
from ray.rllib.agents.registry import get_trainer_class
|
2020-06-05 21:07:02 +02:00
|
|
|
from ray.rllib.utils.test_utils import check, framework_iterator
|
2017-09-28 13:12:06 -07:00
|
|
|
|
|
|
|
|
|
|
|
def get_mean_action(alg, obs):
    """Return the mean of 2000 actions that `alg` computes for `obs`.

    Each `alg.compute_action(obs)` result is converted with `float()` before
    averaging, so the action is expected to be a single scalar value.
    """
    samples = [float(alg.compute_action(obs)) for _ in range(2000)]
    return np.mean(samples)
|
|
|
|
|
|
|
|
|
2017-10-13 16:18:16 -07:00
|
|
|
# Per-algorithm settings layered on top of ``"explore": False``, which every
# tested algorithm shares.
_ALGO_OVERRIDES = {
    "A3C": {
        "num_workers": 1,
    },
    "APEX_DDPG": {
        "observation_filter": "MeanStdFilter",
        "num_workers": 2,
        "min_iter_time_s": 1,
        "optimizer": {
            "num_replay_buffer_shards": 1,
        },
    },
    "ARS": {
        "num_rollouts": 10,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "DDPG": {
        "timesteps_per_iteration": 100,
    },
    "DQN": {},
    "ES": {
        "episodes_per_batch": 10,
        "train_batch_size": 100,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "PPO": {
        "num_sgd_iter": 5,
        "train_batch_size": 1000,
        "num_workers": 2,
    },
    "SAC": {},
}

# Trainer config per algorithm name, as used by `ckpt_restore_test`.
# "explore" is placed first in every entry, followed by the overrides.
CONFIGS = {
    algo: {"explore": False, **overrides}
    for algo, overrides in _ALGO_OVERRIDES.items()
}
|
|
|
|
|
2017-11-19 00:36:43 -08:00
|
|
|
|
2021-04-30 12:33:12 +02:00
|
|
|
def _env_for(alg_name):
    """Return the gym env id on which `alg_name` is checkpoint-tested.

    DDPG- and SAC-based algorithms run on "Pendulum-v0"; all others on
    "CartPole-v0".
    """
    if "DDPG" in alg_name or "SAC" in alg_name:
        return "Pendulum-v0"
    return "CartPole-v0"


def _check_optimizer_state(fw, optim_state, restored_state):
    """Assert that restored optimizer variables equal the original ones.

    Args:
        fw: Framework string yielded by `framework_iterator`.
        optim_state: The "_optimizer_variables" entry of the trained
            policy's state.
        restored_state: The same entry taken from the restored policy.

    Raises:
        AssertionError: If the states differ (or, for torch, if the number
            of optimizer state dicts differs).
    """
    # Tf -> Compare states 1:1.
    if fw in ["tf2", "tf", "tfe"]:
        check(restored_state, optim_state)
        return
    # For torch, optimizers have state_dicts with keys=params, which are
    # different for the two models (ignore these different keys, but compare
    # all values nevertheless).
    if len(restored_state) != len(optim_state):
        raise AssertionError(
            "optimizer count mismatch [restored={} original={}]".format(
                len(restored_state), len(optim_state)))
    for restored, original in zip(restored_state, optim_state):
        check(
            list(restored["state"].values()),
            list(original["state"].values()))


def _check_mean_actions(alg_name, alg1, alg2, obs_space):
    """Assert both trainers compute (on average) the same action.

    A random observation is drawn uniformly from [0, 1) per dimension of
    `obs_space`, clipped into its bounds, and the mean actions of both
    trainers (see `get_mean_action`) must agree within 0.1.

    Raises:
        AssertionError: If the mean actions differ by more than 0.1.
    """
    obs = np.clip(
        np.random.uniform(size=obs_space.shape), obs_space.low,
        obs_space.high)
    a1 = get_mean_action(alg1, obs)
    a2 = get_mean_action(alg2, obs)
    print("Checking computed actions", alg1, obs, a1, a2)
    if abs(a1 - a2) > .1:
        raise AssertionError("algo={} [a1={} a2={}]".format(
            alg_name, a1, a2))


def ckpt_restore_test(alg_name, tfe=False, object_store=False):
    """Train, checkpoint and restore `alg_name`, then verify the restore.

    For every framework (torch, tf, plus "tfe" if requested) and storage
    mode, two trainers are built from the same config. The first is trained
    for one iteration and saved via `save()` (or `save_to_object()`), then
    restored into the second. The restored optimizer variables, if any, and
    the mean computed actions of both trainers must then agree.

    Args:
        alg_name: Key into CONFIGS and name passed to `get_trainer_class`.
        tfe: Also run under the "tfe" framework.
        object_store: Also test `save_to_object`/`restore_from_object` in
            addition to `save`/`restore`.

    Raises:
        AssertionError: If the restored trainer diverges from the original.
    """
    # Work on a copy: `framework_iterator` receives this dict and
    # (NOTE(review): presumably) writes its "framework" key in place; that
    # should not leak into the module-level CONFIGS shared by all tests.
    config = dict(CONFIGS[alg_name])
    frameworks = (["tfe"] if tfe else []) + ["torch", "tf"]
    env = _env_for(alg_name)
    for fw in framework_iterator(config, frameworks=frameworks):
        for use_object_store in ([False, True] if object_store else [False]):
            print("use_object_store={}".format(use_object_store))
            cls = get_trainer_class(alg_name)
            alg1 = cls(config=config, env=env)
            # Stop both trainers even if a check fails, so their actors do
            # not linger in the Ray session shared by the whole TestCase.
            try:
                alg2 = cls(config=config, env=env)
                try:
                    policy1 = alg1.get_policy()
                    res = alg1.train()
                    print("current status: " + str(res))

                    # Check optimizer state as well.
                    optim_state = policy1.get_state().get(
                        "_optimizer_variables")

                    # Sync the models.
                    if use_object_store:
                        alg2.restore_from_object(alg1.save_to_object())
                    else:
                        alg2.restore(alg1.save())

                    # Compare optimizer state with re-loaded one.
                    if optim_state:
                        s2 = alg2.get_policy().get_state().get(
                            "_optimizer_variables")
                        _check_optimizer_state(fw, optim_state, s2)

                    _check_mean_actions(alg_name, alg1, alg2,
                                        policy1.observation_space)
                finally:
                    alg2.stop()
            finally:
                alg1.stop()
|
2017-11-19 00:36:43 -08:00
|
|
|
|
|
|
|
|
2021-04-30 12:33:12 +02:00
|
|
|
class TestCheckpointRestorePG(unittest.TestCase):
    """Checkpoint/restore round-trips for the A3C and PPO trainers."""

    @classmethod
    def setUpClass(cls):
        # A single Ray session (5 CPUs) serves every test in this class.
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_a3c_checkpoint_restore(self):
        ckpt_restore_test(alg_name="A3C")

    def test_ppo_checkpoint_restore(self):
        # PPO is additionally round-tripped through the object store.
        ckpt_restore_test(alg_name="PPO", object_store=True)
|
|
|
|
|
|
|
|
|
|
|
|
class TestCheckpointRestoreOffPolicy(unittest.TestCase):
    """Checkpoint/restore round-trips for APEX-DDPG, DDPG, DQN and SAC."""

    @classmethod
    def setUpClass(cls):
        # A single Ray session (5 CPUs) serves every test in this class.
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_apex_ddpg_checkpoint_restore(self):
        ckpt_restore_test(alg_name="APEX_DDPG")

    def test_ddpg_checkpoint_restore(self):
        ckpt_restore_test(alg_name="DDPG")

    def test_dqn_checkpoint_restore(self):
        # DQN is additionally round-tripped through the object store.
        ckpt_restore_test(alg_name="DQN", object_store=True)

    def test_sac_checkpoint_restore(self):
        ckpt_restore_test(alg_name="SAC")
|
2020-03-12 04:39:47 +01:00
|
|
|
|
|
|
|
|
2021-04-30 12:33:12 +02:00
|
|
|
class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
    """Checkpoint/restore round-trips for the ARS and ES trainers."""

    @classmethod
    def setUpClass(cls):
        # A single Ray session (5 CPUs) serves every test in this class.
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ars_checkpoint_restore(self):
        ckpt_restore_test(alg_name="ARS")

    def test_es_checkpoint_restore(self):
        ckpt_restore_test(alg_name="ES")
|
|
|
|
|
|
|
|
|
2017-11-19 00:36:43 -08:00
|
|
|
if __name__ == "__main__":
    import pytest
    import sys

    # An optional first CLI argument names a single TestCase class to run;
    # without it, every unittest.TestCase class in this file is run.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    if class_ is None:
        target = __file__
    else:
        target = "{}::{}".format(__file__, class_)
    sys.exit(pytest.main(["-v", target]))
|