mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Fix QMix, SAC, and MADDPA too. * Unpin gym and deprecate pendulum v0 Many tests in rllib depended on pendulum v0, however in gym 0.21, pendulum v0 was deprecated in favor of pendulum v1. This may change reward thresholds, so will have to potentially rerun all of the pendulum v1 benchmarks, or use another environment in favor. The same applies to frozen lake v0 and frozen lake v1 Lastly, all of the RLlib tests and have been moved to python 3.7 * Add gym installation based on python version. Pin python<= 3.6 to gym 0.19 due to install issues with atari roms in gym 0.20 * Reformatting * Fixing tests * Move atari-py install conditional to req.txt * migrate to new ale install method * Fix QMix, SAC, and MADDPA too. * Unpin gym and deprecate pendulum v0 Many tests in rllib depended on pendulum v0, however in gym 0.21, pendulum v0 was deprecated in favor of pendulum v1. This may change reward thresholds, so will have to potentially rerun all of the pendulum v1 benchmarks, or use another environment in favor. The same applies to frozen lake v0 and frozen lake v1 Lastly, all of the RLlib tests and have been moved to python 3.7 * Add gym installation based on python version. Pin python<= 3.6 to gym 0.19 due to install issues with atari roms in gym 0.20 Move atari-py install conditional to req.txt migrate to new ale install method Make parametric_actions_cartpole return float32 actions/obs Adding type conversions if obs/actions don't match space Add utils to make elements match gym space dtypes Co-authored-by: Jun Gong <jungong@anyscale.com> Co-authored-by: sven1977 <svenmika1977@gmail.com>
216 lines
6.5 KiB
Python
216 lines
6.5 KiB
Python
#!/usr/bin/env python
|
|
|
|
import numpy as np
|
|
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.agents.registry import get_trainer_class
|
|
from ray.rllib.utils.test_utils import check, framework_iterator
|
|
|
|
|
|
def get_mean_action(alg, obs):
    """Return the mean of 2000 actions computed by ``alg`` for ``obs``.

    With exploration disabled (as in all CONFIGS below) the actions are
    deterministic, so the mean is a stable fingerprint of the policy.
    """
    samples = [float(alg.compute_action(obs)) for _ in range(2000)]
    return np.mean(samples)
|
|
|
|
|
|
# Per-algorithm Trainer configs used by `ckpt_restore_test()` below.
# Every config sets "explore": False so that the pre- and post-restore
# Trainers compute deterministic actions that can be compared 1:1.
CONFIGS = {
    "A3C": {
        "explore": False,
        "num_workers": 1,
    },
    "APEX_DDPG": {
        "explore": False,
        "observation_filter": "MeanStdFilter",
        "num_workers": 2,
        "min_iter_time_s": 1,
        "optimizer": {
            "num_replay_buffer_shards": 1,
        },
    },
    "ARS": {
        "explore": False,
        "num_rollouts": 10,
        "num_workers": 2,
        # NOTE(review): small noise table, presumably to keep the test
        # lightweight — confirm against ARS defaults.
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "DDPG": {
        "explore": False,
        "timesteps_per_iteration": 100,
    },
    "DQN": {
        "explore": False,
    },
    "ES": {
        "explore": False,
        "episodes_per_batch": 10,
        "train_batch_size": 100,
        "num_workers": 2,
        # NOTE(review): small noise table, presumably to keep the test
        # lightweight — confirm against ES defaults.
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "PPO": {
        "explore": False,
        "num_sgd_iter": 5,
        "train_batch_size": 1000,
        "num_workers": 2,
    },
    "SimpleQ": {
        "explore": False,
    },
    "SAC": {
        "explore": False,
    },
}
|
|
|
|
|
|
def ckpt_restore_test(alg_name,
                      tfe=False,
                      object_store=False,
                      replay_buffer=False):
    """Check that an algorithm's Trainer restores exactly from a checkpoint.

    Trains one Trainer for a single iteration, checkpoints it (to disk,
    and optionally to the object store), restores a second Trainer from
    that checkpoint, then compares optimizer state, (optionally) replay
    buffer contents, and the mean computed action between the two.

    Args:
        alg_name: Key into the module-level CONFIGS dict.
        tfe: Whether to additionally run the "tfe" framework.
        object_store: Whether to also exercise the in-memory
            save_to_object()/restore_from_object() code path.
        replay_buffer: Whether to store the replay buffer in the
            checkpoint and compare its contents after restoring.

    Raises:
        AssertionError: If the two Trainers' mean actions diverge by
            more than 0.1, or any state comparison (`check`) fails.
    """
    config = CONFIGS[alg_name].copy()
    # If required, store replay buffer data in checkpoints as well.
    if replay_buffer:
        config["store_buffer_in_checkpoints"] = True

    frameworks = (["tfe"] if tfe else []) + ["torch", "tf"]
    for fw in framework_iterator(config, frameworks=frameworks):
        for use_object_store in ([False, True] if object_store else [False]):
            print("use_object_store={}".format(use_object_store))
            cls = get_trainer_class(alg_name)
            # Continuous-control algos need a continuous-action env.
            if "DDPG" in alg_name or "SAC" in alg_name:
                env = "Pendulum-v1"
            else:
                env = "CartPole-v0"
            alg1 = cls(config=config, env=env)
            alg2 = cls(config=config, env=env)
            # Ensure both Trainers are stopped even if a comparison below
            # raises (previously they leaked on failure).
            try:
                policy1 = alg1.get_policy()

                res = alg1.train()
                print("current status: " + str(res))

                # Check optimizer state as well.
                optim_state = policy1.get_state().get("_optimizer_variables")

                # Sync the models (via object store or on-disk checkpoint).
                if use_object_store:
                    alg2.restore_from_object(alg1.save_to_object())
                else:
                    alg2.restore(alg1.save())

                # Compare optimizer state with re-loaded one.
                if optim_state:
                    s2 = alg2.get_policy().get_state().get(
                        "_optimizer_variables")
                    # Tf -> Compare states 1:1.
                    if fw in ["tf2", "tf", "tfe"]:
                        check(s2, optim_state)
                    # For torch, optimizers have state_dicts with keys=params,
                    # which are different for the two models (ignore these
                    # different keys, but compare all values nevertheless).
                    else:
                        for i, s2_ in enumerate(s2):
                            check(
                                list(s2_["state"].values()),
                                list(optim_state[i]["state"].values()))

                # Compare buffer content with restored one (arbitrary slice
                # of 42 samples starting at index 42).
                if replay_buffer:
                    data = alg1.local_replay_buffer.replay_buffers[
                        "default_policy"]._storage[42:42 + 42]
                    new_data = alg2.local_replay_buffer.replay_buffers[
                        "default_policy"]._storage[42:42 + 42]
                    check(data, new_data)

                # Sample a random observation inside the policy's space and
                # make sure both Trainers compute (nearly) the same action.
                obs_dim = (3 if "DDPG" in alg_name or "SAC" in alg_name
                           else 4)
                obs = np.clip(
                    np.random.uniform(size=obs_dim),
                    policy1.observation_space.low,
                    policy1.observation_space.high)
                a1 = get_mean_action(alg1, obs)
                a2 = get_mean_action(alg2, obs)
                print("Checking computed actions", alg1, obs, a1, a2)
                if abs(a1 - a2) > .1:
                    raise AssertionError("algo={} [a1={} a2={}]".format(
                        alg_name, a1, a2))
            finally:
                # Stop both Trainers.
                alg1.stop()
                alg2.stop()
|
|
|
|
|
|
class TestCheckpointRestorePG(unittest.TestCase):
    """Checkpoint/restore round-trip tests for policy-gradient algos."""

    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ppo_checkpoint_restore(self):
        ckpt_restore_test("PPO", object_store=True)

    def test_a3c_checkpoint_restore(self):
        ckpt_restore_test("A3C")
|
|
|
|
|
|
class TestCheckpointRestoreOffPolicy(unittest.TestCase):
    """Checkpoint/restore round-trip tests for off-policy (replay) algos."""

    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_simpleq_checkpoint_restore(self):
        ckpt_restore_test("SimpleQ", replay_buffer=True)

    def test_sac_checkpoint_restore(self):
        ckpt_restore_test("SAC", replay_buffer=True)

    def test_dqn_checkpoint_restore(self):
        ckpt_restore_test("DQN", object_store=True, replay_buffer=True)

    def test_ddpg_checkpoint_restore(self):
        ckpt_restore_test("DDPG", replay_buffer=True)

    def test_apex_ddpg_checkpoint_restore(self):
        ckpt_restore_test("APEX_DDPG")
|
|
|
|
|
|
class TestCheckpointRestoreEvolutionAlgos(unittest.TestCase):
    """Checkpoint/restore round-trip tests for evolution-strategy algos."""

    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_es_checkpoint_restore(self):
        ckpt_restore_test("ES")

    def test_ars_checkpoint_restore(self):
        ckpt_restore_test("ARS")
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # One can specify the specific TestCase class to run.
    # None for all unittest.TestCase classes in this file.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    target = __file__ if class_ is None else __file__ + "::" + class_
    sys.exit(pytest.main(["-v", target]))
|