2017-08-27 18:56:52 -07:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2018-12-22 16:35:25 +08:00
|
|
|
import os
|
|
|
|
import shutil
|
2019-04-07 16:11:50 -07:00
|
|
|
import gym
|
2017-09-28 13:12:06 -07:00
|
|
|
import numpy as np
|
2017-08-27 18:56:52 -07:00
|
|
|
import ray
|
|
|
|
|
2018-12-21 03:44:34 +09:00
|
|
|
from ray.rllib.agents.registry import get_agent_class
|
2019-02-01 09:07:27 +08:00
|
|
|
from ray.tune.trial import ExportFormat
|
2017-09-28 13:12:06 -07:00
|
|
|
|
|
|
|
|
|
|
|
def get_mean_action(alg, obs):
    """Return the agent's action for ``obs``, averaged over 2000 queries.

    Averaging smooths out any residual stochasticity so two trainers can be
    compared with a simple tolerance check.
    """
    samples = [float(alg.compute_action(obs)) for _ in range(2000)]
    return np.mean(samples)
|
|
|
|
|
|
|
|
|
2019-10-31 15:16:02 -07:00
|
|
|
# Start a local Ray cluster; 10 CPUs covers the multi-worker configs below.
ray.init(num_cpus=10, object_store_memory=1e9)
|
2017-10-13 16:18:16 -07:00
|
|
|
|
|
|
|
# Per-algorithm trainer configs, scaled down so each training iteration is
# fast. Every entry sets "explore": False so action computation is
# deterministic and the checkpoint-restore comparison is meaningful.
CONFIGS = {
    "SAC": {
        "explore": False,
    },
    "ES": {
        "explore": False,
        "episodes_per_batch": 10,
        "train_batch_size": 100,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
    "DQN": {
        "explore": False,
    },
    "APEX_DDPG": {
        "explore": False,
        "observation_filter": "MeanStdFilter",
        "num_workers": 2,
        "min_iter_time_s": 1,
        "optimizer": {
            "num_replay_buffer_shards": 1,
        },
    },
    "DDPG": {
        "explore": False,
        "timesteps_per_iteration": 100,
    },
    "PPO": {
        "explore": False,
        "num_sgd_iter": 5,
        "train_batch_size": 1000,
        "num_workers": 2,
    },
    "A3C": {
        "explore": False,
        "num_workers": 1,
    },
    "ARS": {
        "explore": False,
        "num_rollouts": 10,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
    },
}
|
|
|
|
|
2017-11-19 00:36:43 -08:00
|
|
|
|
2018-12-22 16:35:25 +08:00
|
|
|
def test_ckpt_restore(use_object_store, alg_name, failures):
    """Train two trainers, sync one from the other's checkpoint, and verify
    both compute (nearly) identical actions afterwards.

    Args:
        use_object_store: If True, round-trip the checkpoint through the
            object store (save_to_object/restore_from_object) instead of disk.
        alg_name: Registered algorithm name, e.g. "PPO".
        failures: List collecting (alg_name, [a1, a2]) tuples on mismatch.
    """
    cls = get_agent_class(alg_name)
    # Fix: index CONFIGS with the alg_name parameter; the original indexed
    # CONFIGS[name], silently capturing the caller's global loop variable.
    # Continuous-action algorithms need a continuous-control env.
    if "DDPG" in alg_name or "SAC" in alg_name:
        alg1 = cls(config=CONFIGS[alg_name], env="Pendulum-v0")
        alg2 = cls(config=CONFIGS[alg_name], env="Pendulum-v0")
        env = gym.make("Pendulum-v0")
    else:
        alg1 = cls(config=CONFIGS[alg_name], env="CartPole-v0")
        alg2 = cls(config=CONFIGS[alg_name], env="CartPole-v0")
        env = gym.make("CartPole-v0")

    # Only alg1 is trained; alg2 should match it after the restore below.
    for _ in range(2):
        res = alg1.train()
        print("current status: " + str(res))

    # Sync the models
    if use_object_store:
        alg2.restore_from_object(alg1.save_to_object())
    else:
        alg2.restore(alg1.save())

    # Probe with random in-bounds observations (size 3 for Pendulum,
    # size 4 for CartPole) and compare mean actions within a tolerance.
    for _ in range(5):
        if "DDPG" in alg_name or "SAC" in alg_name:
            obs = np.clip(
                np.random.uniform(size=3),
                env.observation_space.low,
                env.observation_space.high)
        else:
            obs = np.clip(
                np.random.uniform(size=4),
                env.observation_space.low,
                env.observation_space.high)
        a1 = get_mean_action(alg1, obs)
        a2 = get_mean_action(alg2, obs)
        print("Checking computed actions", alg1, obs, a1, a2)
        if abs(a1 - a2) > .1:
            failures.append((alg_name, [a1, a2]))
|
2017-11-19 00:36:43 -08:00
|
|
|
|
|
|
|
|
2018-12-22 16:35:25 +08:00
|
|
|
def test_export(algo_name, failures):
    """Export a trained policy as a TF SavedModel and as a TF checkpoint and
    verify the expected files were written.

    Args:
        algo_name: Registered algorithm name, e.g. "PPO".
        failures: List collecting names of algorithms whose export failed.
    """

    def valid_tf_model(model_dir):
        # A SavedModel dir holds saved_model.pb plus a non-empty variables/.
        return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
            and os.listdir(os.path.join(model_dir, "variables"))

    def valid_tf_checkpoint(checkpoint_dir):
        # A TF checkpoint consists of meta/index files plus the
        # "checkpoint" state file.
        return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
            and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
            and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))

    cls = get_agent_class(algo_name)
    # Fix: index CONFIGS with the algo_name parameter; the original indexed
    # CONFIGS[name], silently capturing the caller's global loop variable.
    if "DDPG" in algo_name or "SAC" in algo_name:
        algo = cls(config=CONFIGS[algo_name], env="Pendulum-v0")
    else:
        algo = cls(config=CONFIGS[algo_name], env="CartPole-v0")

    for _ in range(3):
        res = algo.train()
        print("current status: " + str(res))

    export_dir = "/tmp/export_dir_%s" % algo_name
    print("Exporting model ", algo_name, export_dir)
    algo.export_policy_model(export_dir)
    if not valid_tf_model(export_dir):
        failures.append(algo_name)
    shutil.rmtree(export_dir)

    print("Exporting checkpoint", algo_name, export_dir)
    algo.export_policy_checkpoint(export_dir)
    if not valid_tf_checkpoint(export_dir):
        failures.append(algo_name)
    shutil.rmtree(export_dir)

    # Export both formats in one call via the generic export_model API.
    print("Exporting default policy", algo_name, export_dir)
    algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
                      export_dir)
    if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
            or not valid_tf_checkpoint(os.path.join(export_dir,
                                                    ExportFormat.CHECKPOINT)):
        failures.append(algo_name)
    shutil.rmtree(export_dir)
|
|
|
|
|
2018-12-22 16:35:25 +08:00
|
|
|
|
2017-11-19 00:36:43 -08:00
|
|
|
if __name__ == "__main__":
|
2018-06-09 00:21:35 -07:00
|
|
|
failures = []
|
2017-11-19 00:36:43 -08:00
|
|
|
for use_object_store in [False, True]:
|
2019-08-01 23:37:36 -07:00
|
|
|
for name in [
|
|
|
|
"SAC", "ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"
|
|
|
|
]:
|
2018-12-22 16:35:25 +08:00
|
|
|
test_ckpt_restore(use_object_store, name, failures)
|
2017-11-19 00:36:43 -08:00
|
|
|
|
2018-06-09 00:21:35 -07:00
|
|
|
assert not failures, failures
|
2017-11-19 00:36:43 -08:00
|
|
|
print("All checkpoint restore tests passed!")
|
2018-12-22 16:35:25 +08:00
|
|
|
|
|
|
|
failures = []
|
2019-08-01 23:37:36 -07:00
|
|
|
for name in ["SAC", "DQN", "DDPG", "PPO", "A3C"]:
|
2018-12-22 16:35:25 +08:00
|
|
|
test_export(name, failures)
|
|
|
|
assert not failures, failures
|
|
|
|
print("All export tests passed!")
|