2018-12-22 16:35:25 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2018-12-27 07:43:06 +08:00
|
|
|
import os
|
2018-12-22 16:35:25 +08:00
|
|
|
import ray
|
|
|
|
|
2018-12-27 07:43:06 +08:00
|
|
|
from ray.rllib.agents.registry import get_agent_class
|
2019-05-16 22:12:07 -07:00
|
|
|
from ray.rllib.utils import try_import_tf
|
|
|
|
|
|
|
|
# Bind the result of ray.rllib.utils.try_import_tf() to `tf`; the functions
# below use it as the TensorFlow module (tf.Graph, tf.Session,
# tf.saved_model, tf.train).
# NOTE(review): presumably this is None when TensorFlow is not installed --
# confirm against ray.rllib.utils before relying on it.
tf = try_import_tf()
|
2018-12-22 16:35:25 +08:00
|
|
|
|
|
|
|
# Initialize Ray with num_cpus=10. This statement sits at module top level,
# outside the `if __name__ == "__main__"` guard, so it also runs on import.
ray.init(num_cpus=10)
|
|
|
|
|
|
|
|
|
2018-12-27 07:43:06 +08:00
|
|
|
def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    """Train an agent on CartPole-v0, then export its policy twice.

    Args:
        algo_name: Name handed to ``get_agent_class`` to look up the agent
            class to instantiate.
        num_steps: Number of times ``train()`` is called on the agent.
        model_dir: Directory passed to ``export_policy_model``.
        ckpt_dir: Directory passed to ``export_policy_checkpoint``.
        prefix: Value passed as ``filename_prefix`` to
            ``export_policy_checkpoint``.
    """
    # Look up the agent class and build it with an empty config.
    trainer = get_agent_class(algo_name)(config={}, env="CartPole-v0")

    for _iteration in range(num_steps):
        trainer.train()

    # TensorFlow checkpoint: intended for fine-tuning.
    trainer.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)

    # TensorFlow SavedModel: intended for online serving.
    trainer.export_policy_model(model_dir)
|
2018-12-22 16:35:25 +08:00
|
|
|
|
|
|
|
|
2018-12-27 07:43:06 +08:00
|
|
|
def restore_saved_model(export_dir):
    """Load the SavedModel stored in ``export_dir`` and print its signature.

    The model is loaded with the SERVING tag through a session bound to a
    newly created ``tf.Graph``. The signature definition registered under
    the default serving key is printed, followed by a pointer to the
    SavedModel CLI guide.

    Args:
        export_dir: Directory passed to ``tf.saved_model.load``.
    """
    serving_key = (
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    graph = tf.Graph()
    # Enter the graph first, then open the session on that same graph.
    with graph.as_default(), tf.Session(graph=graph) as session:
        serving_tags = [tf.saved_model.tag_constants.SERVING]
        loaded_meta_graph = tf.saved_model.load(
            session, serving_tags, export_dir)

        print("Model restored!")
        print("Signature Def Information:")
        print(loaded_meta_graph.signature_def[serving_key])
        print("You can inspect the model using TensorFlow SavedModel CLI.")
        print("https://www.tensorflow.org/guide/saved_model")
|
|
|
|
|
|
|
|
|
2018-12-27 07:43:06 +08:00
|
|
|
def restore_checkpoint(export_dir, prefix):
    """Restore an exported TensorFlow checkpoint and print its variables.

    The meta graph ``<export_dir>/<prefix>.meta`` is imported and the
    checkpoint ``<export_dir>/<prefix>`` is restored into it. Every
    trainable variable of the imported graph is then evaluated and printed
    together with its name.

    The import happens in a dedicated ``tf.Graph`` (mirroring
    ``restore_saved_model``) instead of the process-wide default graph, so
    repeated calls do not collide with nodes imported earlier. The session
    is used as a context manager so it is closed when the function returns
    or raises, rather than being leaked.

    Args:
        export_dir: Directory containing the checkpoint files.
        prefix: Checkpoint filename prefix; the meta graph is expected at
            ``prefix + ".meta"`` inside ``export_dir``.
    """
    ckpt_path = os.path.join(export_dir, prefix)

    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(ckpt_path + ".meta")
        saver.restore(sess, ckpt_path)

        print("Checkpoint restored!")
        print("Variables Information:")
        # Looked up inside graph.as_default(), so this lists the variables
        # of the graph that was just imported.
        for var in tf.trainable_variables():
            print(var.name, sess.run(var))
|
|
|
|
|
|
|
|
|
2018-12-22 16:35:25 +08:00
|
|
|
if __name__ == "__main__":
    # Script entry point: train DQN for a few iterations, export the policy
    # as a SavedModel and as a checkpoint under the user temp directory,
    # then load both exports back and print what was restored.
    algo = "DQN"
    num_steps = 3
    prefix = "model.ckpt"

    tmp_root = ray.utils.get_user_temp_dir()
    model_dir = os.path.join(tmp_root, "model_export_dir")
    ckpt_dir = os.path.join(tmp_root, "ckpt_export_dir")

    train_and_export(
        algo_name=algo,
        num_steps=num_steps,
        model_dir=model_dir,
        ckpt_dir=ckpt_dir,
        prefix=prefix,
    )
    restore_saved_model(model_dir)
    restore_checkpoint(ckpt_dir, prefix)
|