mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
#!/usr/bin/env python
|
|
|
|
import os
|
|
import ray
|
|
|
|
from ray.rllib.algorithms.registry import get_algorithm_class
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
# try_import_tf() returns three values, bound here as tf1, tf and tfv.
# The functions visible in this file only use tf1 (Graph, Session,
# saved_model, train, trainable_variables).
tf1, tf, tfv = try_import_tf()
|
|
|
|
# Initialize Ray with num_cpus=10. This call sits outside the
# `if __name__ == "__main__"` guard, so it also runs when the file is imported.
ray.init(num_cpus=10)
|
|
|
|
|
|
def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    """Train an algorithm on "CartPole-v0" and export its policy twice.

    Args:
        algo_name: Name handed to ``get_algorithm_class`` to look up the
            algorithm class.
        num_steps: How many times ``train()`` is called on the algorithm.
        model_dir: Target directory passed to ``export_policy_model``.
        ckpt_dir: Target directory passed to ``export_policy_checkpoint``.
        prefix: Value passed as ``filename_prefix`` for the checkpoint export.
    """
    algorithm_cls = get_algorithm_class(algo_name)
    algorithm = algorithm_cls(config={}, env="CartPole-v0")

    for _iteration in range(num_steps):
        algorithm.train()

    # TensorFlow checkpoint: meant for later fine-tuning.
    algorithm.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)

    # TensorFlow SavedModel: meant for online serving.
    algorithm.export_policy_model(model_dir)
|
|
|
|
|
|
def restore_saved_model(export_dir):
    """Load the SavedModel stored in ``export_dir`` and print its signature.

    The model is loaded with the SERVING tag into a fresh graph and session,
    and the signature definition stored under the default serving key is
    printed.

    Args:
        export_dir: Directory containing the exported SavedModel.
    """
    signature_constants = tf1.saved_model.signature_constants
    default_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    graph = tf1.Graph()
    # One combined `with` enters the graph first, then the session bound to it.
    with graph.as_default(), tf1.Session(graph=graph) as session:
        loaded_meta_graph = tf1.saved_model.load(
            session,
            [tf1.saved_model.tag_constants.SERVING],
            export_dir,
        )
        print("Model restored!")
        print("Signature Def Information:")
        print(loaded_meta_graph.signature_def[default_key])
        print("You can inspect the model using TensorFlow SavedModel CLI.")
        print("https://www.tensorflow.org/guide/saved_model")
|
|
|
|
|
|
def restore_checkpoint(export_dir, prefix):
    """Restore a TensorFlow checkpoint and print its trainable variables.

    The meta graph is imported from ``<prefix>.meta`` inside ``export_dir``,
    the variable values are restored from the checkpoint files sharing that
    prefix, and each trainable variable's name and value are printed.

    Args:
        export_dir: Directory containing the exported checkpoint files.
        prefix: Filename prefix of the checkpoint (without ``.meta``).
    """
    meta_path = os.path.join(export_dir, "%s.meta" % prefix)
    ckpt_path = os.path.join(export_dir, prefix)

    # Use the session as a context manager so it is closed (and its resources
    # released) on exit, including when the import or restore raises. The
    # original created the session and never closed it.
    with tf1.Session() as sess:
        saver = tf1.train.import_meta_graph(meta_path)
        saver.restore(sess, ckpt_path)
        print("Checkpoint restored!")
        print("Variables Information:")
        for v in tf1.trainable_variables():
            value = sess.run(v)
            print(v.name, value)
|
|
|
|
|
|
if __name__ == "__main__":
    # Example run: train DQN for a few iterations, export the policy as both
    # a SavedModel and a checkpoint, then load each export back.
    algo = "DQN"
    num_steps = 3
    prefix = "model.ckpt"

    # Both export directories live under the same user temp directory.
    temp_root = ray._private.utils.get_user_temp_dir()
    model_dir = os.path.join(temp_root, "model_export_dir")
    ckpt_dir = os.path.join(temp_root, "ckpt_export_dir")

    train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix)
    restore_saved_model(model_dir)
    restore_checkpoint(ckpt_dir, prefix)
|