#!/usr/bin/env python
"""Train an RLlib algorithm on CartPole, export its policy, and reload it.

The script exports the trained policy both as a TensorFlow checkpoint and as
a TensorFlow SavedModel, then restores each export and prints what it finds.
"""

import os

import ray
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

ray.init(num_cpus=10)


def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    """Train an algorithm on CartPole-v0 and export its policy.

    Args:
        algo_name: Name passed to ``get_algorithm_class`` to look up the
            algorithm class.
        num_steps: Number of times ``train()`` is called.
        model_dir: Directory the SavedModel is exported to.
        ckpt_dir: Directory the checkpoint is exported to.
        prefix: Filename prefix used for the exported checkpoint.
    """
    cls = get_algorithm_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    # Export tensorflow SavedModel for online serving
    alg.export_policy_model(model_dir)


def restore_saved_model(export_dir):
    """Load the SavedModel in ``export_dir`` and print its serving signature.

    The model is loaded into a fresh graph and a session that is closed when
    the function returns.

    Args:
        export_dir: Directory containing the exported SavedModel.
    """
    signature_key = (
        tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    )
    g = tf1.Graph()
    with g.as_default():
        with tf1.Session(graph=g) as sess:
            meta_graph_def = tf1.saved_model.load(
                sess, [tf1.saved_model.tag_constants.SERVING], export_dir
            )
            print("Model restored!")
            print("Signature Def Information:")
            print(meta_graph_def.signature_def[signature_key])
            print("You can inspect the model using TensorFlow SavedModel CLI.")
            print("https://www.tensorflow.org/guide/saved_model")


def restore_checkpoint(export_dir, prefix):
    """Restore the checkpoint in ``export_dir`` and print trainable variables.

    Args:
        export_dir: Directory containing the exported checkpoint files.
        prefix: Filename prefix of the checkpoint; the meta graph is read
            from ``<prefix>.meta`` in ``export_dir``.
    """
    meta_file = "%s.meta" % prefix
    # Use the session as a context manager so it is closed (and its resources
    # released) even if importing or restoring raises.
    with tf1.Session() as sess:
        saver = tf1.train.import_meta_graph(os.path.join(export_dir, meta_file))
        saver.restore(sess, os.path.join(export_dir, prefix))
        print("Checkpoint restored!")
        print("Variables Information:")
        for v in tf1.trainable_variables():
            value = sess.run(v)
            print(v.name, value)


if __name__ == "__main__":
    algo = "DQN"
    # Both export directories live under the same Ray user temp directory.
    temp_dir = ray._private.utils.get_user_temp_dir()
    model_dir = os.path.join(temp_dir, "model_export_dir")
    ckpt_dir = os.path.join(temp_dir, "ckpt_export_dir")
    prefix = "model.ckpt"
    num_steps = 3
    train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix)
    restore_saved_model(model_dir)
    restore_checkpoint(ckpt_dir, prefix)