ray/rllib/examples/export/cartpole_dqn_export.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

63 lines
2.1 KiB
Python

#!/usr/bin/env python
import os
import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.utils.framework import try_import_tf
# try_import_tf() returns three values; of these, only `tf1` is referenced by
# the code visible in this script (Graph, Session, saved_model and train APIs).
tf1, tf, tfv = try_import_tf()
# Module-level call: Ray is initialised as soon as this file is executed or
# imported, with num_cpus=10 passed through to ray.init().
ray.init(num_cpus=10)
def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    """Train a trainer on CartPole-v0 and export its policy in two formats.

    Args:
        algo_name: Name handed to ``get_trainer_class`` to look up the
            trainer class.
        num_steps: How many times ``train()`` is called before exporting.
        model_dir: Target directory passed to ``export_policy_model``.
        ckpt_dir: Target directory passed to ``export_policy_checkpoint``.
        prefix: Filename prefix used for the exported checkpoint files.
    """
    # Look up the trainer class and build it with an empty config override.
    trainer = get_trainer_class(algo_name)(config={}, env="CartPole-v0")

    for _iteration in range(num_steps):
        trainer.train()

    # Export tensorflow checkpoint for fine-tuning
    trainer.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)

    # Export tensorflow SavedModel for online serving
    trainer.export_policy_model(model_dir)
def restore_saved_model(export_dir):
    """Load the SavedModel stored in ``export_dir`` and print its signature.

    The model is loaded into a freshly created graph and a session bound to
    that graph, so nothing is added to the process-wide default graph.

    Args:
        export_dir: Directory containing the exported SavedModel.
    """
    serving_tags = [tf1.saved_model.tag_constants.SERVING]
    default_key = (
        tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    )

    graph = tf1.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as session:
        loaded_meta_graph = tf1.saved_model.load(
            session, serving_tags, export_dir
        )
        print("Model restored!")
        print("Signature Def Information:")
        # Only the default serving signature is shown.
        print(loaded_meta_graph.signature_def[default_key])
        print("You can inspect the model using TensorFlow SavedModel CLI.")
        print("https://www.tensorflow.org/guide/saved_model")
def restore_checkpoint(export_dir, prefix):
    """Restore an exported TensorFlow checkpoint and print its variables.

    The meta graph ``<prefix>.meta`` found in ``export_dir`` is imported into
    a dedicated graph, the checkpoint values are restored into a session on
    that graph, and every trainable variable is printed as ``name value``.

    Args:
        export_dir: Directory holding the checkpoint files.
        prefix: Checkpoint filename prefix (the ``.meta`` file is expected
            at ``<export_dir>/<prefix>.meta``).
    """
    checkpoint_path = os.path.join(export_dir, prefix)
    meta_path = os.path.join(export_dir, "%s.meta" % prefix)

    # Use a private graph (mirroring restore_saved_model) so the imported
    # variables do not mix with whatever already lives in the default graph,
    # and manage the session with `with` so it is closed even if restoring
    # or evaluating a variable raises.
    graph = tf1.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        saver = tf1.train.import_meta_graph(meta_path)
        saver.restore(sess, checkpoint_path)
        print("Checkpoint restored!")
        print("Variables Information:")
        for variable in tf1.trainable_variables():
            print(variable.name, sess.run(variable))
if __name__ == "__main__":
    # Both export directories are created under the directory returned by
    # Ray's get_user_temp_dir() helper.
    export_root = ray._private.utils.get_user_temp_dir()
    saved_model_path = os.path.join(export_root, "model_export_dir")
    checkpoint_path = os.path.join(export_root, "ckpt_export_dir")
    checkpoint_prefix = "model.ckpt"

    # Train DQN for three iterations, export, then read both exports back.
    train_and_export(
        algo_name="DQN",
        num_steps=3,
        model_dir=saved_model_path,
        ckpt_dir=checkpoint_path,
        prefix=checkpoint_prefix,
    )
    restore_saved_model(saved_model_path)
    restore_checkpoint(checkpoint_path, checkpoint_prefix)