mirror of
https://github.com/vale981/ray
synced 2025-03-10 05:16:49 -04:00
55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
"""This example script shows how to load a connector-enabled policy
and use it in a serving/inference setting.
"""
|
|
|
|
import argparse
|
|
import gym
|
|
|
|
from ray.rllib.utils.policy import (
|
|
load_policies_from_checkpoint,
|
|
local_policy_inference,
|
|
)
|
|
|
|
|
|
# Command-line interface: which checkpoint to restore and which policy in it.
parser = argparse.ArgumentParser()
# This should be a checkpoint created with connectors enabled.
parser.add_argument(
    "--checkpoint_file",
    help="Path to an RLlib checkpoint file.",
)
parser.add_argument(
    "--policy_id",
    default="default_policy",
    help="ID of policy to load.",
)
args = parser.parse_args()

# Validate explicitly instead of with `assert`: assertions are stripped when
# Python runs with -O, which would let the script continue with no checkpoint.
# parser.error() prints the usage line plus this message and exits.
if not args.checkpoint_file:
    parser.error("Must specify flag --checkpoint_file.")
|
|
|
|
|
|
def run():
    """Restore a connector-enabled policy and roll it out for one episode.

    Loads policy ``args.policy_id`` from ``args.checkpoint_file``, then steps a
    ``CartPole-v0`` environment until it reports ``done``, printing the step
    number, observation and chosen action at every step. The environment is
    always closed, even if inference or stepping raises.
    """
    # Restore policy.
    policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
    policy = policies[args.policy_id]

    # Run CartPole.
    env = gym.make("CartPole-v0")
    try:
        obs = env.reset()
        done = False
        step = 0
        while not done:
            step += 1

            # Use local_policy_inference() to run inference, so we do not have to
            # provide policy states or extra fetch dictionaries.
            # NOTE(review): "env_1" / "agent_1" appear to be fixed placeholder
            # IDs used only to key the connectors' per-agent state — confirm.
            policy_outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
            assert len(policy_outputs) == 1
            action, _, _ = policy_outputs[0]
            print(f"step {step}", obs, action)

            # Step environment forward one more step.
            # NOTE(review): `action` is indexed with [0], presumably because the
            # connector pipeline returns a batched action — confirm.
            obs, _, done, _ = env.step(action[0])
    finally:
        # Release the environment's resources regardless of how the loop ends.
        env.close()


if __name__ == "__main__":
    run()
|