2020-05-27 16:19:13 +02:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
2021-02-08 12:05:16 +01:00
|
|
|
from ray.rllib.agents.registry import get_trainer_class
|
2021-02-10 15:21:46 +01:00
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
2020-05-27 16:19:13 +02:00
|
|
|
from ray.tune.trial import ExportFormat
|
|
|
|
|
2021-02-10 15:21:46 +01:00
|
|
|
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
2020-05-27 16:19:13 +02:00
|
|
|
# Per-algorithm trainer config overrides, keyed by the algorithm name that
# export_test() looks up and passes on to get_trainer_class(). Every entry
# sets explore=False; the remaining keys are algorithm-specific settings.
CONFIGS = {
    "A3C": dict(
        explore=False,
        num_workers=1,
    ),
    "APEX_DDPG": dict(
        explore=False,
        observation_filter="MeanStdFilter",
        num_workers=2,
        min_iter_time_s=1,
        optimizer=dict(
            num_replay_buffer_shards=1,
        ),
    ),
    "ARS": dict(
        explore=False,
        num_rollouts=10,
        num_workers=2,
        noise_size=2_500_000,
        observation_filter="MeanStdFilter",
    ),
    "DDPG": dict(
        explore=False,
        timesteps_per_iteration=100,
    ),
    "DQN": dict(
        explore=False,
    ),
    "ES": dict(
        explore=False,
        episodes_per_batch=10,
        train_batch_size=100,
        num_workers=2,
        noise_size=2_500_000,
        observation_filter="MeanStdFilter",
    ),
    "PPO": dict(
        explore=False,
        num_sgd_iter=5,
        train_batch_size=1000,
        num_workers=2,
    ),
    "SAC": dict(
        explore=False,
    ),
}
|
|
|
|
|
|
|
|
|
2021-02-22 17:09:40 +01:00
|
|
|
def export_test(alg_name, failures, framework="tf"):
    """Train ``alg_name`` for one iteration and check its exported artifacts.

    The trainer's policy model is exported first. For the "tf" framework the
    policy checkpoint and the combined default-policy export (checkpoint +
    model) are exported and validated as well, and the exported model is
    loaded back through ``tf.saved_model.load``.

    Args:
        alg_name: Key into ``CONFIGS``; also passed to ``get_trainer_class``.
        failures: List that ``alg_name`` is appended to for every export
            whose output files are missing. Mutated in place.
        framework: Value stored under ``config["framework"]``. The on-disk
            validation only runs when this is "tf".

    Raises:
        AssertionError: If the re-loaded TF saved model is falsy.
    """

    def valid_tf_model(model_dir):
        """Return True if model_dir has saved_model.pb and a non-empty
        variables/ directory."""
        variables_dir = os.path.join(model_dir, "variables")
        # Guard with isdir(): a missing variables/ directory must count as an
        # invalid export instead of raising from os.listdir().
        return (os.path.exists(os.path.join(model_dir, "saved_model.pb"))
                and os.path.isdir(variables_dir)
                and bool(os.listdir(variables_dir)))

    def valid_tf_checkpoint(checkpoint_dir):
        """Return True if all expected checkpoint files are present."""
        return all(
            os.path.exists(os.path.join(checkpoint_dir, file_name))
            for file_name in ("model.meta", "model.index", "checkpoint"))

    cls = get_trainer_class(alg_name)
    config = CONFIGS[alg_name].copy()
    config["framework"] = framework
    # DDPG- and SAC-named algorithms are run on Pendulum, the rest on
    # CartPole.
    if "DDPG" in alg_name or "SAC" in alg_name:
        env = "Pendulum-v0"
    else:
        env = "CartPole-v0"
    algo = cls(config=config, env=env)

    export_dir = os.path.join(ray.utils.get_user_temp_dir(),
                              "export_dir_%s" % alg_name)
    try:
        res = algo.train()
        print("current status: " + str(res))

        print("Exporting model ", alg_name, export_dir)
        algo.export_policy_model(export_dir)
        if framework == "tf" and not valid_tf_model(export_dir):
            failures.append(alg_name)
        # Clear the directory so the next export starts from scratch.
        shutil.rmtree(export_dir)

        if framework == "tf":
            print("Exporting checkpoint", alg_name, export_dir)
            algo.export_policy_checkpoint(export_dir)
            if not valid_tf_checkpoint(export_dir):
                failures.append(alg_name)
            shutil.rmtree(export_dir)

            print("Exporting default policy", alg_name, export_dir)
            algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
                              export_dir)
            model_dir = os.path.join(export_dir, ExportFormat.MODEL)
            checkpoint_dir = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            if not (valid_tf_model(model_dir)
                    and valid_tf_checkpoint(checkpoint_dir)):
                failures.append(alg_name)

            # Test loading the exported model.
            model = tf.saved_model.load(model_dir)
            assert model
    finally:
        # Always remove whatever is left of the export directory and release
        # the trainer, even when an export step raised.
        shutil.rmtree(export_dir, ignore_errors=True)
        algo.stop()
|
2020-05-27 16:19:13 +02:00
|
|
|
|
|
|
|
|
|
|
|
class TestExport(unittest.TestCase):
|
|
|
|
@classmethod
|
|
|
|
def setUpClass(cls) -> None:
|
|
|
|
ray.init(
|
|
|
|
num_cpus=10, object_store_memory=1e9, ignore_reinit_error=True)
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def tearDownClass(cls) -> None:
|
|
|
|
ray.shutdown()
|
|
|
|
|
2021-02-22 17:09:40 +01:00
|
|
|
def test_export_ppo(self):
|
|
|
|
failures = []
|
|
|
|
export_test("PPO", failures, "torch")
|
|
|
|
export_test("PPO", failures, "tf")
|
|
|
|
assert not failures, failures
|
|
|
|
|
2020-05-27 16:19:13 +02:00
|
|
|
def test_export(self):
|
|
|
|
failures = []
|
2021-02-22 17:09:40 +01:00
|
|
|
for name in ["A3C", "DQN", "DDPG", "SAC"]:
|
2020-05-27 16:19:13 +02:00
|
|
|
export_test(name, failures)
|
|
|
|
assert not failures, failures
|
|
|
|
print("All export tests passed!")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely through pytest when executed as a
    # script, and use pytest's return value as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|